first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

294
core/schemas/account.go Normal file
View File

@@ -0,0 +1,294 @@
// Package schemas defines the core schemas and types used by the Bifrost system.
package schemas
import (
"context"
"fmt"
"slices"
"strings"
)
// KeyStatusType is the string-typed status recorded on a Key (see Key.Status).
type KeyStatusType string
// Known key status values.
const (
KeyStatusSuccess KeyStatusType = "success" // NOTE(review): presumably the key's last check succeeded — confirm against the code that sets it
KeyStatusListModelsFailed KeyStatusType = "list_models_failed" // NOTE(review): presumably listing models with this key failed — confirm against the code that sets it
)
// WhiteList is a list of values that are allowed to be used.
// Semantics:
// - "*" (alone) means all values are allowed.
// - Empty list means nothing is allowed.
// - Non-empty list (without "*") means only the listed values are allowed.
//
// Entries are matched case-insensitively (see Contains). Combining "*" with
// other entries, or listing the same value twice, is rejected by Validate.
//
// This type is used generically for any field that needs whitelist behavior
// (e.g., allowed models, allowed tools).
type WhiteList []string
// Contains reports whether value matches one of the whitelist entries.
// Entries are compared with Unicode case folding, so "GPT-4" matches "gpt-4".
// The "*" wildcard is not interpreted here: it only matches the literal value "*".
// Use IsAllowed for the wildcard-aware check.
func (wl WhiteList) Contains(value string) bool {
	position := slices.IndexFunc(wl, func(entry string) bool {
		return strings.EqualFold(entry, value)
	})
	return position >= 0
}
// IsAllowed reports whether value is permitted by the whitelist.
// It returns true when the whitelist is the lone wildcard "*" (everything is
// allowed) or when value matches one of the entries (see Contains).
// An empty whitelist allows nothing.
func (wl WhiteList) IsAllowed(value string) bool {
	if wl.IsUnrestricted() {
		return true
	}
	return wl.Contains(value)
}
// IsEmpty reports whether the whitelist has no entries.
// A nil whitelist counts as empty. Per the WhiteList semantics an empty list
// allows nothing.
func (wl WhiteList) IsEmpty() bool {
return len(wl) == 0
}
// IsUnrestricted reports whether the whitelist consists of exactly one entry
// and that entry is the wildcard "*", meaning all values are allowed.
// A list that mixes "*" with other entries is not treated as unrestricted.
func (wl WhiteList) IsUnrestricted() bool {
	if len(wl) != 1 {
		return false
	}
	return wl[0] == "*"
}
// IsRestricted reports whether the whitelist is anything other than the lone
// wildcard "*". This includes the empty list (nothing is allowed) as well as a
// list of explicit values (only those values are allowed). It is the exact
// negation of IsUnrestricted.
func (wl WhiteList) IsRestricted() bool {
	return len(wl) != 1 || wl[0] != "*"
}
// Validate checks that the whitelist is well-formed.
// It returns an error when the wildcard "*" appears together with other
// entries, or when two entries are equal after lower-casing. The duplicate
// error names the later of the two occurrences. A nil or empty list is valid.
func (wl WhiteList) Validate() error {
	if len(wl) > 1 && wl.Contains("*") {
		return fmt.Errorf("wildcard '*' cannot be used with other values in the whitelist")
	}

	// Track lower-cased entries so that "GPT-4" and "gpt-4" count as duplicates.
	known := make(map[string]struct{}, len(wl))
	for _, entry := range wl {
		key := strings.ToLower(entry)
		if _, dup := known[key]; dup {
			return fmt.Errorf("duplicate value '%s' in whitelist", entry)
		}
		known[key] = struct{}{}
	}
	return nil
}
// BlackList is a list of values that are denied.
// Semantics:
// - "*" (alone) means all values are blocked.
// - Empty list means nothing is blocked.
// - Non-empty list (without "*") means only the listed values are blocked.
//
// Entries are matched case-insensitively (see Contains). Combining "*" with
// other entries, or listing the same value twice, is rejected by Validate.
type BlackList []string
// Contains reports whether value matches one of the blacklist entries.
// Entries are compared with Unicode case folding. The "*" wildcard is not
// interpreted here; use IsBlocked for the wildcard-aware check.
func (bl BlackList) Contains(value string) bool {
	matches := func(entry string) bool {
		return strings.EqualFold(entry, value)
	}
	return slices.ContainsFunc(bl, matches)
}
// IsBlocked reports whether value is denied by the blacklist.
// It returns true when the blacklist is the lone wildcard "*" (everything is
// blocked) or when value matches one of the entries (see Contains).
// An empty blacklist blocks nothing.
func (bl BlackList) IsBlocked(value string) bool {
	if bl.IsBlockAll() {
		return true
	}
	return bl.Contains(value)
}
// IsEmpty reports whether the blacklist has no entries (nothing is blocked).
// A nil blacklist counts as empty.
func (bl BlackList) IsEmpty() bool {
return len(bl) == 0
}
// IsBlockAll reports whether the blacklist consists of exactly one entry and
// that entry is the wildcard "*", meaning all values are blocked.
// A list that mixes "*" with other entries is not treated as block-all.
func (bl BlackList) IsBlockAll() bool {
	if len(bl) != 1 {
		return false
	}
	return bl[0] == "*"
}
// Validate checks that the blacklist is well-formed.
// It returns an error when the wildcard "*" appears together with other
// entries, or when two entries are equal after lower-casing. The duplicate
// error names the later of the two occurrences. A nil or empty list is valid.
func (bl BlackList) Validate() error {
	if len(bl) > 1 && bl.Contains("*") {
		return fmt.Errorf("wildcard '*' cannot be used with other values in the blacklist")
	}

	// Track lower-cased entries so that case variants count as duplicates.
	known := make(map[string]struct{}, len(bl))
	for _, entry := range bl {
		key := strings.ToLower(entry)
		if _, dup := known[key]; dup {
			return fmt.Errorf("duplicate value '%s' in blacklist", entry)
		}
		known[key] = struct{}{}
	}
	return nil
}
// Key represents an API key and its associated configuration for a provider.
// It contains the key value, the allowed and blocked model lists, a weight for
// load balancing, optional model aliases, and optional provider-specific
// configuration blocks (Azure, Vertex, Bedrock, vLLM, Replicate, Ollama, SGLang).
type Key struct {
ID string `json:"id"` // The unique identifier for the key (used by bifrost to identify the key)
Name string `json:"name"` // The name of the key (used by users to identify the key, not used by bifrost)
Value EnvVar `json:"value"` // The actual API key value
Models WhiteList `json:"models"` // List of models this key can access
BlacklistedModels BlackList `json:"blacklisted_models"` // List of models this key cannot access
Weight float64 `json:"weight"` // Weight for load balancing between multiple keys
Aliases KeyAliases `json:"aliases,omitempty"` // Mapping of model identifiers to inference profiles
AzureKeyConfig *AzureKeyConfig `json:"azure_key_config,omitempty"` // Azure-specific key configuration
VertexKeyConfig *VertexKeyConfig `json:"vertex_key_config,omitempty"` // Vertex-specific key configuration
BedrockKeyConfig *BedrockKeyConfig `json:"bedrock_key_config,omitempty"` // AWS Bedrock-specific key configuration
VLLMKeyConfig *VLLMKeyConfig `json:"vllm_key_config,omitempty"` // vLLM-specific key configuration
ReplicateKeyConfig *ReplicateKeyConfig `json:"replicate_key_config,omitempty"` // Replicate-specific key configuration
OllamaKeyConfig *OllamaKeyConfig `json:"ollama_key_config,omitempty"` // Ollama-specific key configuration
SGLKeyConfig *SGLKeyConfig `json:"sgl_key_config,omitempty"` // SGLang-specific key configuration
Enabled *bool `json:"enabled,omitempty"` // Whether the key is active (default:true)
UseForBatchAPI *bool `json:"use_for_batch_api,omitempty"` // Whether this key can be used for batch API operations (default:false for new keys, migrated keys default to true)
ConfigHash string `json:"config_hash,omitempty"` // Hash of config.json version, used for change detection
Status KeyStatusType `json:"status,omitempty"` // Status of key
Description string `json:"description,omitempty"` // Description of key
}
// KeyAliases maps an alias source (a model identifier) to its alias target.
// Resolve looks a model up in this map, first exactly and then
// case-insensitively; Validate checks the entries for well-formedness.
type KeyAliases map[string]string
// Validate checks that every alias entry is well-formed.
//
// An entry is rejected when its source or target is empty or whitespace-only,
// when its source or target has leading or trailing whitespace, or when two
// sources differ only by case.
//
// Sources are visited in sorted order rather than in map iteration order, so
// when several entries are invalid the same error is returned on every call.
// A nil or empty map is valid.
func (ka KeyAliases) Validate() error {
	// Map iteration order is randomized; sort the sources so the reported
	// error is deterministic.
	sources := make([]string, 0, len(ka))
	for from := range ka {
		sources = append(sources, from)
	}
	slices.Sort(sources)

	// Lower-cased sources seen so far, used to detect case-insensitive duplicates.
	seen := make(map[string]struct{}, len(ka))
	for _, from := range sources {
		to := ka[from]
		if strings.TrimSpace(from) == "" {
			return fmt.Errorf("alias source cannot be empty")
		}
		if strings.TrimSpace(to) == "" {
			return fmt.Errorf("alias target for %q cannot be empty", from)
		}
		if strings.TrimSpace(from) != from {
			return fmt.Errorf("alias source %q cannot have leading or trailing whitespace", from)
		}
		if strings.TrimSpace(to) != to {
			return fmt.Errorf("alias target for %q cannot have leading or trailing whitespace", from)
		}
		normalized := strings.ToLower(from)
		if _, ok := seen[normalized]; ok {
			return fmt.Errorf("duplicate alias source %q (case-insensitive)", from)
		}
		seen[normalized] = struct{}{}
	}
	return nil
}
// Resolve returns the alias target for model, or model itself when no alias
// applies.
//
// An exact key match wins. Otherwise the map is scanned for a key equal to
// model under Unicode case folding, for consistency with WhiteList.Contains.
// If several keys differ only by case, which of them is picked by the fallback
// depends on map iteration order. A nil map resolves every model to itself.
func (ka KeyAliases) Resolve(model string) string {
	// Lookups and ranges on a nil map are safe, so no separate nil check is needed.
	if target, found := ka[model]; found {
		return target
	}
	for source, target := range ka {
		if strings.EqualFold(source, model) {
			return target
		}
	}
	return model
}
// AzureAuthType is the string-typed name of an Azure authentication method.
// NOTE(review): no field of this type is visible in this part of the file — confirm where it is consumed.
type AzureAuthType string
// Known Azure authentication methods.
const (
AzureAuthTypeClientSecret AzureAuthType = "client_secret"
AzureAuthTypeManagedIdentity AzureAuthType = "managed_identity"
)
// AzureKeyConfig represents the Azure-specific configuration.
// It contains Azure-specific settings required for service access and deployment management.
// Only Endpoint is always serialized; the other fields are optional.
type AzureKeyConfig struct {
Endpoint EnvVar `json:"endpoint"` // Azure service endpoint URL
APIVersion *EnvVar `json:"api_version,omitempty"` // Azure API version to use; defaults to "2024-10-21"
ClientID *EnvVar `json:"client_id,omitempty"` // Azure client ID for authentication
ClientSecret *EnvVar `json:"client_secret,omitempty"` // Azure client secret for authentication
TenantID *EnvVar `json:"tenant_id,omitempty"` // Azure tenant ID for authentication
Scopes []string `json:"scopes,omitempty"` // NOTE(review): presumably the OAuth scopes requested when acquiring a token — confirm
}
// VertexKeyConfig represents the Vertex-specific configuration.
// It contains Vertex-specific settings required for authentication and service access.
//
// NOTE: To use Vertex IAM role authentication, set AuthCredentials to empty string.
type VertexKeyConfig struct {
ProjectID EnvVar `json:"project_id"`
ProjectNumber EnvVar `json:"project_number"`
Region EnvVar `json:"region"`
AuthCredentials EnvVar `json:"auth_credentials"` // Leave empty to use IAM role authentication
}
// S3BucketConfig represents a single S3 bucket configuration for batch operations.
// Prefix and IsDefault are omitted from JSON when they hold their zero value.
type S3BucketConfig struct {
BucketName string `json:"bucket_name"` // S3 bucket name
Prefix string `json:"prefix,omitempty"` // S3 key prefix for batch files
IsDefault bool `json:"is_default,omitempty"` // Whether this is the default bucket for batch operations
}
// BatchS3Config holds S3 bucket configurations for Bedrock batch operations.
// Supports multiple buckets to allow flexible batch job routing.
// NOTE(review): nothing here enforces that exactly one bucket has IsDefault set — confirm where that is validated.
type BatchS3Config struct {
Buckets []S3BucketConfig `json:"buckets,omitempty"` // List of S3 bucket configurations
}
// BedrockKeyConfig represents the AWS Bedrock-specific configuration.
// It contains AWS-specific settings required for authentication and service access.
//
// NOTE: To use Bedrock IAM role authentication, set both AccessKey and SecretKey to empty strings.
// To use Bedrock API Key authentication, set Value in Key struct instead.
type BedrockKeyConfig struct {
AccessKey EnvVar `json:"access_key,omitempty"` // AWS access key for authentication
SecretKey EnvVar `json:"secret_key,omitempty"` // AWS secret access key for authentication
SessionToken *EnvVar `json:"session_token,omitempty"` // AWS session token for temporary credentials
Region *EnvVar `json:"region,omitempty"` // AWS region for service access
ARN *EnvVar `json:"arn,omitempty"` // Amazon Resource Name for resource identification
// IAM role for STS AssumeRole
RoleARN *EnvVar `json:"role_arn,omitempty"`
ExternalID *EnvVar `json:"external_id,omitempty"`
RoleSessionName *EnvVar `json:"session_name,omitempty"` // Serialized as "session_name", not "role_session_name"
BatchS3Config *BatchS3Config `json:"batch_s3_config,omitempty"` // S3 bucket configuration for batch operations
}
// VLLMKeyConfig represents the vLLM-specific key configuration.
// It allows each key to target a different vLLM server URL and model name,
// enabling per-key routing and round-robin load balancing across multiple vLLM instances.
// Both fields are always serialized (no omitempty).
type VLLMKeyConfig struct {
URL EnvVar `json:"url"` // VLLM server base URL (required, supports env. prefix)
ModelName string `json:"model_name"` // Exact model name served on this VLLM instance (used for key selection)
}
// ReplicateKeyConfig represents the Replicate-specific key configuration.
// It currently holds a single switch selecting which Replicate endpoint family is used.
type ReplicateKeyConfig struct {
UseDeploymentsEndpoint bool `json:"use_deployments_endpoint"` // Whether to use the deployments endpoint instead of the models endpoint
}
// OllamaKeyConfig represents the Ollama-specific key configuration.
// It allows each key to target a different Ollama server URL,
// enabling per-key routing and round-robin load balancing across multiple Ollama instances.
// It has the same shape as SGLKeyConfig.
type OllamaKeyConfig struct {
URL EnvVar `json:"url"` // Ollama server base URL (required, supports env. prefix)
}
// SGLKeyConfig represents the SGLang-specific key configuration.
// It allows each key to target a different SGLang server URL,
// enabling per-key routing and round-robin load balancing across multiple SGLang instances.
// It has the same shape as OllamaKeyConfig.
type SGLKeyConfig struct {
URL EnvVar `json:"url"` // SGLang server base URL (required, supports env. prefix)
}
// Account defines the interface for managing provider accounts and their configurations.
// It provides methods to access provider-specific settings, API keys, and configurations.
// Every method returns an error alongside its result so implementations can
// report lookup failures.
type Account interface {
// GetConfiguredProviders returns a list of providers that are configured
// in the account. This is used to determine which providers are available for use.
GetConfiguredProviders() ([]ModelProvider, error)
// GetKeysForProvider returns the API keys configured for a specific provider.
// The keys include their values, supported models, and weights for load balancing.
// The context can carry data from any source that sets values before the Bifrost request,
// including but not limited to plugin pre-hooks, application logic, or any in app middleware sharing the context.
// This enables dynamic key selection based on any context values present during the request.
GetKeysForProvider(ctx context.Context, providerKey ModelProvider) ([]Key, error)
// GetConfigForProvider returns the configuration for a specific provider.
// This includes network settings, authentication details, and other provider-specific
// configurations. Unlike GetKeysForProvider it takes no context.
GetConfigForProvider(providerKey ModelProvider) (*ProviderConfig, error)
}

34
core/schemas/async.go Normal file
View File

@@ -0,0 +1,34 @@
package schemas
import "time"
// AsyncJobStatus represents the status of an async job.
// It is serialized as a plain string (see AsyncJobResponse.Status).
type AsyncJobStatus string
// Known async job statuses.
const (
AsyncJobStatusPending AsyncJobStatus = "pending"
AsyncJobStatusProcessing AsyncJobStatus = "processing"
AsyncJobStatusCompleted AsyncJobStatus = "completed"
AsyncJobStatusFailed AsyncJobStatus = "failed"
)
// HTTP header names used by the async job feature. All share the "x-bf-async" prefix.
const (
// AsyncHeaderResultTTL is the header containing the result TTL for async job retrieval.
// NOTE(review): the unit of the TTL value is not visible here — confirm against the handler.
AsyncHeaderResultTTL = "x-bf-async-job-result-ttl"
// AsyncHeaderCreate is the header that triggers async job creation on integration routes.
AsyncHeaderCreate = "x-bf-async"
// AsyncHeaderGetID is the header containing the job ID for async job retrieval on integration routes.
AsyncHeaderGetID = "x-bf-async-id"
)
// AsyncJobResponse is the JSON response returned when creating or polling an async job.
// ID, Status and CreatedAt are always serialized; the remaining fields are
// omitted when nil or zero.
type AsyncJobResponse struct {
ID string `json:"id"`
Status AsyncJobStatus `json:"status"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
CompletedAt *time.Time `json:"completed_at,omitempty"`
StatusCode int `json:"status_code,omitempty"` // NOTE(review): presumably the HTTP status of the completed underlying request — confirm
Result interface{} `json:"result,omitempty"` // Untyped result payload; shape depends on the job
Error *BifrostError `json:"error,omitempty"`
}

329
core/schemas/batch.go Normal file
View File

@@ -0,0 +1,329 @@
// Package schemas defines the core schemas and types used by the Bifrost system.
package schemas
// BatchStatus represents the status of a batch job.
// The set is a union across providers; provider-specific values are marked below.
type BatchStatus string
// Known batch statuses.
const (
BatchStatusValidating BatchStatus = "validating"
BatchStatusFailed BatchStatus = "failed"
BatchStatusInProgress BatchStatus = "in_progress"
BatchStatusFinalizing BatchStatus = "finalizing"
BatchStatusCompleted BatchStatus = "completed"
BatchStatusExpired BatchStatus = "expired"
BatchStatusCancelling BatchStatus = "cancelling"
BatchStatusCancelled BatchStatus = "cancelled"
BatchStatusEnded BatchStatus = "ended" // Anthropic-specific
BatchStatusDeleted BatchStatus = "deleted" // Gemini-specific
)
// BatchEndpoint represents supported batch API endpoints.
// Each value is the URL path of the endpoint the batched requests target.
type BatchEndpoint string
// Known batch endpoints.
const (
BatchEndpointChatCompletions BatchEndpoint = "/v1/chat/completions"
BatchEndpointEmbeddings BatchEndpoint = "/v1/embeddings"
BatchEndpointCompletions BatchEndpoint = "/v1/completions"
BatchEndpointResponses BatchEndpoint = "/v1/responses"
BatchEndpointMessages BatchEndpoint = "/v1/messages" // Anthropic
)
// BatchRequestItem represents a single request in a batch (for inline requests).
// Only CustomID is always serialized; Body and Params are alternative carriers
// for the request parameters.
type BatchRequestItem struct {
CustomID string `json:"custom_id"` // User-provided unique ID for this request
Method string `json:"method,omitempty"` // HTTP method (typically "POST")
URL string `json:"url,omitempty"` // Endpoint URL (e.g., "/v1/chat/completions")
Body map[string]interface{} `json:"body,omitempty"` // Request body parameters
Params map[string]interface{} `json:"params,omitempty"` // Alternative to Body for Anthropic
}
// BatchRequestCounts tracks the counts of requests in different states.
// Total, Completed and Failed are always serialized; the other counters are
// omitted from JSON when zero.
type BatchRequestCounts struct {
Total int `json:"total"`
Completed int `json:"completed"`
Failed int `json:"failed"`
Succeeded int `json:"succeeded,omitempty"` // Anthropic-specific
Expired int `json:"expired,omitempty"` // Anthropic-specific
Canceled int `json:"canceled,omitempty"` // Anthropic-specific
Pending int `json:"pending,omitempty"` // Anthropic-specific
}
// BatchErrors represents errors encountered during batch processing.
// It is a list wrapper: Data holds the individual BatchError entries.
type BatchErrors struct {
Object string `json:"object,omitempty"` // NOTE(review): presumably the object type tag (e.g. "list") — confirm against provider responses
Data []BatchError `json:"data,omitempty"`
}
// BatchError represents a single error in batch processing.
// All fields are optional and omitted from JSON when empty.
type BatchError struct {
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
Param string `json:"param,omitempty"`
Line *int `json:"line,omitempty"` // NOTE(review): presumably the line number in the input file that caused the error — confirm
}
// BifrostBatchCreateRequest represents a request to create a batch job.
// It supports both file-based batching (InputFileID) and inline batching
// (Requests). RawRequestBody and ExtraParams are excluded from JSON.
type BifrostBatchCreateRequest struct {
Provider ModelProvider `json:"provider"`
Model *string `json:"model,omitempty"` // Model hint for routing (optional for file-based) it may or may not present depending on the provider and usage of integration vs direct API
RawRequestBody []byte `json:"-"` // Raw request body (not serialized)
// OpenAI-style: file-based batching
InputFileID string `json:"input_file_id,omitempty"` // ID of uploaded JSONL file
// Anthropic-style: inline requests
Requests []BatchRequestItem `json:"requests,omitempty"` // Inline request items
// Common fields
Endpoint BatchEndpoint `json:"endpoint,omitempty"` // Target endpoint for batch requests
CompletionWindow string `json:"completion_window,omitempty"` // Time window (e.g., "24h")
Metadata map[string]string `json:"metadata,omitempty"` // User-provided metadata
OutputExpiresAfter *BatchExpiresAfter `json:"output_expires_after,omitempty"` // Expiration for batch output (OpenAI only)
// Extra parameters for provider-specific features
ExtraParams map[string]interface{} `json:"-"`
}
// BatchExpiresAfter represents an expiration configuration for batch output.
// The expiry is expressed as a number of seconds after an anchor timestamp.
// NOTE(review): the documented Seconds range is not enforced by this type — confirm where it is validated.
type BatchExpiresAfter struct {
Anchor string `json:"anchor"` // e.g., "created_at"
Seconds int `json:"seconds"` // 3600-2592000 (1 hour to 30 days)
}
// GetRawRequestBody returns the raw request body.
// It is safe to call on a nil request, in which case it returns nil instead
// of panicking on the nil dereference.
func (request *BifrostBatchCreateRequest) GetRawRequestBody() []byte {
	if request == nil {
		return nil
	}
	return request.RawRequestBody
}
// BifrostBatchCreateResponse represents the response from creating a batch job.
// It is a union of the fields returned by the supported providers;
// provider-specific groups are marked below.
type BifrostBatchCreateResponse struct {
ID string `json:"id"`
Object string `json:"object,omitempty"` // "batch" for OpenAI
Endpoint string `json:"endpoint,omitempty"`
InputFileID string `json:"input_file_id,omitempty"`
CompletionWindow string `json:"completion_window,omitempty"`
Status BatchStatus `json:"status"`
RequestCounts BatchRequestCounts `json:"request_counts,omitempty"` // NOTE(review): with encoding/json, omitempty never omits a struct value — confirm whether always emitting this is intended
Metadata map[string]string `json:"metadata,omitempty"`
CreatedAt int64 `json:"created_at,omitempty"` // NOTE(review): presumably a Unix timestamp in seconds — confirm
ExpiresAt *int64 `json:"expires_at,omitempty"`
// Output file references (OpenAI)
OutputFileID *string `json:"output_file_id,omitempty"`
ErrorFileID *string `json:"error_file_id,omitempty"`
// Anthropic-specific
ProcessingStatus *string `json:"processing_status,omitempty"`
ResultsURL *string `json:"results_url,omitempty"`
// Gemini-specific (operation response)
OperationName *string `json:"operation_name,omitempty"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// BifrostBatchListRequest represents a request to list batch jobs.
// It carries the pagination parameters of several providers side by side;
// which ones apply depends on the provider.
type BifrostBatchListRequest struct {
Provider ModelProvider `json:"provider"`
Model *string `json:"model"` // No omitempty: a nil Model is serialized as null
// Pagination
Limit int `json:"limit,omitempty"` // Max results to return
After *string `json:"after,omitempty"` // Cursor for pagination (OpenAI)
BeforeID *string `json:"before_id,omitempty"` // Pagination cursor (Anthropic)
AfterID *string `json:"after_id,omitempty"` // Pagination cursor (Anthropic)
PageToken *string `json:"page_token,omitempty"` // For Gemini pagination
PageSize int `json:"page_size,omitempty"` // For Gemini pagination
NextCursor *string `json:"next_cursor,omitempty"` // For Gemini pagination
// Extra parameters for provider-specific features
ExtraParams map[string]interface{} `json:"-"`
}
// BifrostBatchListResponse represents the response from listing batch jobs.
// Data is always serialized; the pagination fields are omitted when unset.
type BifrostBatchListResponse struct {
Object string `json:"object,omitempty"` // "list"
Data []BifrostBatchRetrieveResponse `json:"data"`
FirstID *string `json:"first_id,omitempty"`
LastID *string `json:"last_id,omitempty"`
HasMore bool `json:"has_more,omitempty"`
// Anthropic pagination
NextCursor *string `json:"next_cursor,omitempty"` // For cursor-based pagination
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// BifrostBatchRetrieveRequest represents a request to retrieve a batch job.
// RawRequestBody and ExtraParams are excluded from JSON.
type BifrostBatchRetrieveRequest struct {
Provider ModelProvider `json:"provider"`
Model *string `json:"model"` // No omitempty: a nil Model is serialized as null
BatchID string `json:"batch_id"` // ID of the batch to retrieve
RawRequestBody []byte `json:"-"` // Raw request body (not serialized)
// Extra parameters for provider-specific features
ExtraParams map[string]interface{} `json:"-"`
}
// GetRawRequestBody returns the raw request body.
// It is safe to call on a nil request, in which case it returns nil instead
// of panicking on the nil dereference.
func (request *BifrostBatchRetrieveRequest) GetRawRequestBody() []byte {
	if request == nil {
		return nil
	}
	return request.RawRequestBody
}
// BifrostBatchRetrieveResponse represents the response from retrieving a batch job.
// It is a union of the fields returned by the supported providers, including
// one optional timestamp per lifecycle transition.
// NOTE(review): the int64 timestamps are presumably Unix seconds — confirm.
type BifrostBatchRetrieveResponse struct {
ID string `json:"id"`
Object string `json:"object,omitempty"`
Endpoint string `json:"endpoint,omitempty"`
InputFileID string `json:"input_file_id,omitempty"`
CompletionWindow string `json:"completion_window,omitempty"`
Status BatchStatus `json:"status"`
RequestCounts BatchRequestCounts `json:"request_counts,omitempty"` // NOTE(review): with encoding/json, omitempty never omits a struct value — confirm whether always emitting this is intended
Metadata map[string]string `json:"metadata,omitempty"`
CreatedAt int64 `json:"created_at,omitempty"`
ExpiresAt *int64 `json:"expires_at,omitempty"`
InProgressAt *int64 `json:"in_progress_at,omitempty"`
FinalizingAt *int64 `json:"finalizing_at,omitempty"`
CompletedAt *int64 `json:"completed_at,omitempty"`
FailedAt *int64 `json:"failed_at,omitempty"`
ExpiredAt *int64 `json:"expired_at,omitempty"`
CancellingAt *int64 `json:"cancelling_at,omitempty"`
CancelledAt *int64 `json:"cancelled_at,omitempty"`
// Output references
OutputFileID *string `json:"output_file_id,omitempty"`
ErrorFileID *string `json:"error_file_id,omitempty"`
Errors *BatchErrors `json:"errors,omitempty"`
// Anthropic-specific
ProcessingStatus *string `json:"processing_status,omitempty"`
ResultsURL *string `json:"results_url,omitempty"`
ArchivedAt *int64 `json:"archived_at,omitempty"`
// Gemini-specific
OperationName *string `json:"operation_name,omitempty"`
Done *bool `json:"done,omitempty"`
Progress *int `json:"progress,omitempty"` // Percentage progress
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// BifrostBatchCancelRequest represents a request to cancel a batch job.
// RawRequestBody and ExtraParams are excluded from JSON.
type BifrostBatchCancelRequest struct {
Provider ModelProvider `json:"provider"`
Model *string `json:"model"` // No omitempty: a nil Model is serialized as null
BatchID string `json:"batch_id"` // ID of the batch to cancel
RawRequestBody []byte `json:"-"` // Raw request body (not serialized)
// Extra parameters for provider-specific features
ExtraParams map[string]interface{} `json:"-"`
}
// GetRawRequestBody returns the raw request body.
// It is safe to call on a nil request, in which case it returns nil instead
// of panicking on the nil dereference.
func (request *BifrostBatchCancelRequest) GetRawRequestBody() []byte {
	if request == nil {
		return nil
	}
	return request.RawRequestBody
}
// BifrostBatchCancelResponse represents the response from cancelling a batch job.
// ID, Status and ExtraFields are always serialized.
type BifrostBatchCancelResponse struct {
ID string `json:"id"`
Object string `json:"object,omitempty"`
Status BatchStatus `json:"status"`
RequestCounts BatchRequestCounts `json:"request_counts,omitempty"` // NOTE(review): with encoding/json, omitempty never omits a struct value — confirm whether always emitting this is intended
CancellingAt *int64 `json:"cancelling_at,omitempty"`
CancelledAt *int64 `json:"cancelled_at,omitempty"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// BifrostBatchDeleteRequest represents a request to delete a batch job.
// RawRequestBody and ExtraParams are excluded from JSON.
type BifrostBatchDeleteRequest struct {
Provider ModelProvider `json:"provider"`
Model *string `json:"model"` // No omitempty: a nil Model is serialized as null
BatchID string `json:"batch_id"` // ID of the batch to delete
RawRequestBody []byte `json:"-"` // Raw request body (not serialized)
// Extra parameters for provider-specific features
ExtraParams map[string]interface{} `json:"-"`
}
// GetRawRequestBody returns the raw request body.
// It is safe to call on a nil request, in which case it returns nil instead
// of panicking on the nil dereference.
func (request *BifrostBatchDeleteRequest) GetRawRequestBody() []byte {
	if request == nil {
		return nil
	}
	return request.RawRequestBody
}
// BifrostBatchDeleteResponse represents the response from deleting a batch job.
// ID, Status and ExtraFields are always serialized.
type BifrostBatchDeleteResponse struct {
ID string `json:"id"`
Object string `json:"object,omitempty"`
Status BatchStatus `json:"status"`
RequestCounts BatchRequestCounts `json:"request_counts,omitempty"` // NOTE(review): with encoding/json, omitempty never omits a struct value — confirm whether always emitting this is intended
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// BifrostBatchResultsRequest represents a request to retrieve batch results.
// RawRequestBody and ExtraParams are excluded from JSON.
type BifrostBatchResultsRequest struct {
Provider ModelProvider `json:"provider"`
Model *string `json:"model"` // No omitempty: a nil Model is serialized as null
BatchID string `json:"batch_id"` // ID of the batch to get results for
RawRequestBody []byte `json:"-"` // Raw request body (not serialized)
// For OpenAI, results are retrieved via output_file_id (file download)
// For Anthropic, results are streamed from a dedicated endpoint
// Extra parameters for provider-specific features
ExtraParams map[string]interface{} `json:"-"`
}
// GetRawRequestBody returns the raw request body.
// It is safe to call on a nil request, in which case it returns nil instead
// of panicking on the nil dereference.
func (request *BifrostBatchResultsRequest) GetRawRequestBody() []byte {
	if request == nil {
		return nil
	}
	return request.RawRequestBody
}
// BatchResultItem represents a single result from a batch request.
// The payload is carried in Response (OpenAI format) or Result (Anthropic
// format); Error is set when the individual request failed.
type BatchResultItem struct {
CustomID string `json:"custom_id"`
// Result data (varies by request type)
Response *BatchResultResponse `json:"response,omitempty"` // OpenAI format
Result *BatchResultData `json:"result,omitempty"` // Anthropic format
// Error if the individual request failed
Error *BatchResultError `json:"error,omitempty"`
}
// BatchResultResponse represents OpenAI-style result response.
// StatusCode is always serialized; RequestID and Body are omitted when empty.
type BatchResultResponse struct {
StatusCode int `json:"status_code"`
RequestID string `json:"request_id,omitempty"`
Body map[string]interface{} `json:"body,omitempty"` // Untyped response body
}
// BatchResultData represents Anthropic-style result data.
// Type names the outcome of the request; Message holds the untyped payload.
type BatchResultData struct {
Type string `json:"type"` // "succeeded", "errored", "expired", "canceled"
Message map[string]interface{} `json:"message,omitempty"`
}
// BatchResultError represents an error for a single batch request.
// Both fields are omitted from JSON when empty.
type BatchResultError struct {
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
// BifrostBatchResultsResponse represents the response from retrieving batch results.
// BatchID, Results and ExtraFields are always serialized.
type BifrostBatchResultsResponse struct {
BatchID string `json:"batch_id"`
Results []BatchResultItem `json:"results"`
// For streaming results (Anthropic)
HasMore bool `json:"has_more,omitempty"`
NextCursor *string `json:"next_cursor,omitempty"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}

1287
core/schemas/bifrost.go Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

255
core/schemas/containers.go Normal file
View File

@@ -0,0 +1,255 @@
// Package schemas defines the core schemas and types used by the Bifrost system.
package schemas
// ContainerStatus represents the status of a container.
// Only the "running" status is declared here.
type ContainerStatus string
// Known container statuses.
const (
ContainerStatusRunning ContainerStatus = "running"
)
// ContainerExpiresAfter represents the expiration configuration for a container.
// The expiry is expressed as a number of minutes after an anchor timestamp.
type ContainerExpiresAfter struct {
Anchor string `json:"anchor"` // The anchor point for expiration (e.g., "last_active_at")
Minutes int `json:"minutes"` // Number of minutes after anchor point
}
// ContainerObject represents a container object returned by the API.
// ID, Name and CreatedAt are always serialized; the remaining fields are
// omitted when empty.
type ContainerObject struct {
ID string `json:"id"`
Object string `json:"object,omitempty"` // "container"
Name string `json:"name"`
CreatedAt int64 `json:"created_at"` // NOTE(review): presumably a Unix timestamp in seconds — confirm
Status ContainerStatus `json:"status,omitempty"`
ExpiresAfter *ContainerExpiresAfter `json:"expires_after,omitempty"`
LastActiveAt *int64 `json:"last_active_at,omitempty"`
MemoryLimit string `json:"memory_limit,omitempty"` // e.g., "1g", "4g"
Metadata map[string]string `json:"metadata,omitempty"`
}
// BifrostContainerCreateRequest represents a request to create a container.
// Provider and Name are always serialized; ExtraParams is excluded from JSON.
type BifrostContainerCreateRequest struct {
Provider ModelProvider `json:"provider"`
// Required fields
Name string `json:"name"` // Name of the container
// Optional fields
ExpiresAfter *ContainerExpiresAfter `json:"expires_after,omitempty"` // Expiration configuration
FileIDs []string `json:"file_ids,omitempty"` // IDs of existing files to copy into this container
MemoryLimit string `json:"memory_limit,omitempty"` // Memory limit (e.g., "1g", "4g")
Metadata map[string]string `json:"metadata,omitempty"` // User-provided metadata
// Extra parameters for provider-specific features
ExtraParams map[string]interface{} `json:"-"`
}
// BifrostContainerCreateResponse represents the response from creating a container.
// It carries the same fields as ContainerObject plus ExtraFields.
type BifrostContainerCreateResponse struct {
ID string `json:"id"`
Object string `json:"object,omitempty"` // "container"
Name string `json:"name"`
CreatedAt int64 `json:"created_at"`
Status ContainerStatus `json:"status,omitempty"`
ExpiresAfter *ContainerExpiresAfter `json:"expires_after,omitempty"`
LastActiveAt *int64 `json:"last_active_at,omitempty"`
MemoryLimit string `json:"memory_limit,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// BifrostContainerListRequest represents a request to list containers.
// NOTE(review): the documented limits and defaults are not enforced by this type — confirm where they are applied.
type BifrostContainerListRequest struct {
Provider ModelProvider `json:"provider"`
// Pagination
Limit int `json:"limit,omitempty"` // Max results to return (1-100, default 20)
After *string `json:"after,omitempty"` // Cursor for pagination
Order *string `json:"order,omitempty"` // Sort order (asc/desc), default desc
// Extra parameters for provider-specific features
ExtraParams map[string]interface{} `json:"-"`
}
// BifrostContainerListResponse represents the response from listing containers.
// Data and ExtraFields are always serialized; the pagination fields are omitted when unset.
type BifrostContainerListResponse struct {
Object string `json:"object,omitempty"` // "list"
Data []ContainerObject `json:"data"`
FirstID *string `json:"first_id,omitempty"`
LastID *string `json:"last_id,omitempty"`
HasMore bool `json:"has_more,omitempty"`
After *string `json:"after,omitempty"` // Encoded cursor for next page (includes key index for multi-key pagination)
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// BifrostContainerRetrieveRequest represents a request to retrieve a container.
// ExtraParams is excluded from JSON.
type BifrostContainerRetrieveRequest struct {
Provider ModelProvider `json:"provider"`
ContainerID string `json:"container_id"` // ID of the container to retrieve
// Extra parameters for provider-specific features
ExtraParams map[string]interface{} `json:"-"`
}
// BifrostContainerRetrieveResponse represents the response from retrieving a container.
// It carries the same fields as ContainerObject plus ExtraFields.
type BifrostContainerRetrieveResponse struct {
ID string `json:"id"`
Object string `json:"object,omitempty"` // "container"
Name string `json:"name"`
CreatedAt int64 `json:"created_at"`
Status ContainerStatus `json:"status,omitempty"`
ExpiresAfter *ContainerExpiresAfter `json:"expires_after,omitempty"`
LastActiveAt *int64 `json:"last_active_at,omitempty"`
MemoryLimit string `json:"memory_limit,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// BifrostContainerDeleteRequest represents a request to delete a container.
// ExtraParams is excluded from JSON.
type BifrostContainerDeleteRequest struct {
Provider ModelProvider `json:"provider"`
ContainerID string `json:"container_id"` // ID of the container to delete
// Extra parameters for provider-specific features
ExtraParams map[string]interface{} `json:"-"`
}
// BifrostContainerDeleteResponse represents the response from deleting a container.
// Deleted is always serialized, including when false.
type BifrostContainerDeleteResponse struct {
ID string `json:"id"`
Object string `json:"object,omitempty"` // "container.deleted"
Deleted bool `json:"deleted"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// =============================================================================
// CONTAINER FILES API
// =============================================================================
// ContainerFileObject represents a file within a container.
// All fields except Object are always serialized.
type ContainerFileObject struct {
ID string `json:"id"`
Object string `json:"object,omitempty"` // "container.file"
Bytes int64 `json:"bytes"` // NOTE(review): presumably the file size in bytes — confirm
CreatedAt int64 `json:"created_at"`
ContainerID string `json:"container_id"`
Path string `json:"path"`
Source string `json:"source"` // "user" typically
}
// BifrostContainerFileCreateRequest represents a request to create a file in a container.
// The file is supplied either as raw bytes (File) or as a reference (FileID).
type BifrostContainerFileCreateRequest struct {
Provider ModelProvider `json:"provider"`
ContainerID string `json:"container_id"` // ID of the container
// One of these must be provided
File []byte `json:"-"` // File content (for multipart upload)
FileID *string `json:"file_id,omitempty"` // Reference to existing file
// Note: the Go field is named Path but its JSON key is "file_path".
Path *string `json:"file_path,omitempty"` // Path for the file in the container
// Extra parameters for provider-specific features.
// Tagged json:"-", so they are never part of this struct's JSON encoding.
ExtraParams map[string]interface{} `json:"-"`
}
// BifrostContainerFileCreateResponse represents the response from creating a container file.
// It carries the same fields and JSON keys as ContainerFileObject, plus ExtraFields.
type BifrostContainerFileCreateResponse struct {
ID string `json:"id"`
Object string `json:"object,omitempty"` // "container.file"
Bytes int64 `json:"bytes"`
CreatedAt int64 `json:"created_at"`
ContainerID string `json:"container_id"`
Path string `json:"path"`
Source string `json:"source"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// BifrostContainerFileListRequest represents a request to list files in a container.
// Limit, After and Order are all optional (omitempty); a zero Limit is omitted from JSON.
type BifrostContainerFileListRequest struct {
Provider ModelProvider `json:"provider"`
ContainerID string `json:"container_id"` // ID of the container
// Pagination
Limit int `json:"limit,omitempty"` // Max results to return (1-100, default 20)
After *string `json:"after,omitempty"` // Cursor for pagination
Order *string `json:"order,omitempty"` // Sort order (asc/desc), default desc
// Extra parameters for provider-specific features.
// Tagged json:"-", so they are never part of this struct's JSON encoding.
ExtraParams map[string]interface{} `json:"-"`
}
// BifrostContainerFileListResponse represents the response from listing container files.
// Data holds the page of files; FirstID, LastID, HasMore and After describe pagination
// and are omitted from JSON when empty.
type BifrostContainerFileListResponse struct {
Object string `json:"object,omitempty"` // "list"
Data []ContainerFileObject `json:"data"`
FirstID *string `json:"first_id,omitempty"`
LastID *string `json:"last_id,omitempty"`
HasMore bool `json:"has_more,omitempty"` // omitempty: a false value is not encoded
After *string `json:"after,omitempty"` // Encoded cursor for next page (includes key index for multi-key pagination)
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// BifrostContainerFileRetrieveRequest represents a request to retrieve a container file.
// The file is addressed by Provider, ContainerID and FileID.
type BifrostContainerFileRetrieveRequest struct {
Provider ModelProvider `json:"provider"`
ContainerID string `json:"container_id"` // ID of the container
FileID string `json:"file_id"` // ID of the file to retrieve
// Extra parameters for provider-specific features.
// Tagged json:"-", so they are never part of this struct's JSON encoding.
ExtraParams map[string]interface{} `json:"-"`
}
// BifrostContainerFileRetrieveResponse represents the response from retrieving a container file.
// It carries the same fields and JSON keys as ContainerFileObject, plus ExtraFields.
type BifrostContainerFileRetrieveResponse struct {
ID string `json:"id"`
Object string `json:"object,omitempty"` // "container.file"
Bytes int64 `json:"bytes"`
CreatedAt int64 `json:"created_at"`
ContainerID string `json:"container_id"`
Path string `json:"path"`
Source string `json:"source"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// BifrostContainerFileContentRequest represents a request to retrieve the content of a container file.
// The file is addressed by Provider, ContainerID and FileID.
type BifrostContainerFileContentRequest struct {
Provider ModelProvider `json:"provider"`
ContainerID string `json:"container_id"` // ID of the container
FileID string `json:"file_id"` // ID of the file
// Extra parameters for provider-specific features.
// Tagged json:"-", so they are never part of this struct's JSON encoding.
ExtraParams map[string]interface{} `json:"-"`
}
// BifrostContainerFileContentResponse represents the response from retrieving container file content.
// NOTE(review): with encoding/json a []byte field is encoded as a base64 string —
// confirm the project's marshaller behaves the same way for Content.
type BifrostContainerFileContentResponse struct {
Content []byte `json:"content"` // Raw file content
ContentType string `json:"content_type"` // MIME type of the content
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// BifrostContainerFileDeleteRequest represents a request to delete a container file.
// The file is addressed by Provider, ContainerID and FileID.
type BifrostContainerFileDeleteRequest struct {
Provider ModelProvider `json:"provider"`
ContainerID string `json:"container_id"` // ID of the container
FileID string `json:"file_id"` // ID of the file to delete
// Extra parameters for provider-specific features.
// Tagged json:"-", so they are never part of this struct's JSON encoding.
ExtraParams map[string]interface{} `json:"-"`
}
// BifrostContainerFileDeleteResponse represents the response from deleting a container file.
// Deleted has no omitempty tag, so it is encoded even when false.
type BifrostContainerFileDeleteResponse struct {
ID string `json:"id"`
Object string `json:"object,omitempty"` // "container.file.deleted"
Deleted bool `json:"deleted"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}

501
core/schemas/context.go Normal file
View File

@@ -0,0 +1,501 @@
package schemas
import (
"context"
"slices"
"sync"
"sync/atomic"
"time"
)
// NoDeadline is the zero time.Time. Passing it to NewBifrostContext creates a
// context without its own deadline, because the constructor treats a zero
// deadline as "no deadline".
var NoDeadline time.Time
// reservedKeys lists the context keys that become read-only once
// BlockRestrictedWrites has been called: SetValue, ClearValue and
// GetAndSetValue silently ignore writes to these keys while the block is active.
var reservedKeys = []any{
BifrostContextKeyVirtualKey,
BifrostContextKeyAPIKeyName,
BifrostContextKeyAPIKeyID,
BifrostContextKeyRequestID,
BifrostContextKeyFallbackRequestID,
BifrostContextKeyDirectKey,
BifrostContextKeySelectedKeyID,
BifrostContextKeySelectedKeyName,
BifrostContextKeyNumberOfRetries,
BifrostContextKeyFallbackIndex,
BifrostContextKeySkipKeySelection,
BifrostContextKeyURLPath,
BifrostContextKeyDeferTraceCompletion,
BifrostContextKeyAttemptTrail,
}
// pluginLogStore holds plugin log entries accumulated during request processing.
// It is shared between the root BifrostContext and all scoped contexts derived from it.
// Uses a flat slice (not map) to minimize heap allocations.
type pluginLogStore struct {
mu sync.Mutex // guards logs
logs []PluginLogEntry // entries in append order
}
// pluginLogStorePool pools pluginLogStore instances to reduce per-request allocations.
// New stores start with an empty log slice of capacity 8.
var pluginLogStorePool = sync.Pool{
New: func() any {
return &pluginLogStore{logs: make([]PluginLogEntry, 0, 8)}
},
}
// pluginScopePool pools BifrostContext instances used as scoped plugin contexts.
// Instances are handed out by WithPluginScope and returned by ReleasePluginScope;
// New yields a zero-value BifrostContext whose fields are filled in by WithPluginScope.
var pluginScopePool = sync.Pool{
New: func() any {
return &BifrostContext{}
},
}
// BifrostContext is a custom context.Context implementation that tracks user-set values.
// It supports deadlines, can be derived from other contexts, and provides layered
// value inheritance when derived from another BifrostContext.
//
// A BifrostContext is either a root context (created by NewBifrostContext) or a
// scoped plugin context (created by WithPluginScope). Scoped contexts have a
// non-nil valueDelegate and forward value, deadline and error queries to it.
type BifrostContext struct {
parent context.Context // fallback for Value and Deadline; watched for cancellation
deadline time.Time // own deadline; meaningful only when hasDeadline is true
hasDeadline bool
done chan struct{} // closed exactly once by cancel
doneOnce sync.Once // ensures cancel runs its body a single time
err error // cancellation cause; guarded by errMu
errMu sync.RWMutex
userValues map[any]any // values set through SetValue; guarded by valuesMu
valuesMu sync.RWMutex
blockRestrictedWrites atomic.Bool // when true, writes to reservedKeys are dropped
// Plugin scoping fields
pluginScope *string // Non-nil when this is a scoped plugin context
pluginLogs atomic.Pointer[pluginLogStore] // Shared log store; lazily initialized on root, shared by scoped contexts
valueDelegate *BifrostContext // For scoped contexts: delegate Value/SetValue to this root context
}
// NewBifrostContext builds a BifrostContext layered on top of parent.
//
// A nil parent is replaced by context.Background(). A zero deadline (see
// NoDeadline) means the new context has no deadline of its own, although the
// parent's deadline, if any, still applies. The context is cancelled when its
// effective deadline passes, when the parent is cancelled, or when Cancel is
// called.
func NewBifrostContext(parent context.Context, deadline time.Time) *BifrostContext {
	if parent == nil {
		parent = context.Background()
	}

	// blockRestrictedWrites is left at its zero value (false).
	bc := &BifrostContext{
		parent:      parent,
		deadline:    deadline,
		hasDeadline: !deadline.IsZero(),
		done:        make(chan struct{}),
		userValues:  make(map[any]any),
	}

	// A watcher goroutine is only worth starting when there is something to
	// wait for:
	//   - the context has its own deadline (a timer is needed),
	//   - the parent has a deadline (it may not signal it through Done()),
	//   - the parent can be cancelled and is not a known non-cancelling type.
	_, inheritsDeadline := parent.Deadline()
	cancellableParent := parent.Done() != nil && !isNonCancellingContext(parent)

	switch {
	case bc.hasDeadline, cancellableParent, inheritsDeadline:
		go bc.watchCancellation()
	}
	return bc
}
// NewBifrostContextWithValue creates a BifrostContext exactly like
// NewBifrostContext and stores value under key on it before returning.
func NewBifrostContextWithValue(parent context.Context, deadline time.Time, key any, value any) *BifrostContext {
	return NewBifrostContext(parent, deadline).WithValue(key, value)
}
// NewBifrostContextWithTimeout creates a BifrostContext whose deadline is
// timeout from now. It is a convenience wrapper around NewBifrostContext.
// The returned cancel function cancels the context and should be called to
// release resources.
func NewBifrostContextWithTimeout(parent context.Context, timeout time.Duration) (*BifrostContext, context.CancelFunc) {
	expiresAt := time.Now().Add(timeout)
	bc := NewBifrostContext(parent, expiresAt)
	return bc, bc.Cancel
}
// NewBifrostContextWithCancel creates a BifrostContext without its own
// deadline. It is a convenience wrapper around NewBifrostContext.
// The returned cancel function cancels the context and should be called to
// release resources.
func NewBifrostContextWithCancel(parent context.Context) (*BifrostContext, context.CancelFunc) {
	bc := NewBifrostContext(parent, NoDeadline)
	return bc, bc.Cancel
}
// WithValue stores value under key on this context (via SetValue) and returns
// the same context to allow call chaining.
// Unlike context.WithValue, it does NOT create a new context: the receiver is
// mutated in place and returned.
func (bc *BifrostContext) WithValue(key any, value any) *BifrostContext {
bc.SetValue(key, value)
return bc
}
// BlockRestrictedWrites turns on write protection for reserved keys.
// While it is active, SetValue, ClearValue and GetAndSetValue silently ignore
// writes to any key listed in reservedKeys. It returns nothing.
func (bc *BifrostContext) BlockRestrictedWrites() {
bc.blockRestrictedWrites.Store(true)
}
// UnblockRestrictedWrites turns off the write protection enabled by
// BlockRestrictedWrites, so reserved keys can be written again.
func (bc *BifrostContext) UnblockRestrictedWrites() {
bc.blockRestrictedWrites.Store(false)
}
// Cancel cancels the context, closing the Done channel and setting the error to context.Canceled.
// Only the first cancellation takes effect; later calls (or a later deadline
// expiry) leave the recorded error unchanged.
func (bc *BifrostContext) Cancel() {
bc.cancel(context.Canceled)
}
// watchCancellation monitors for deadline expiration and parent cancellation.
// It is started as a goroutine by NewBifrostContext and returns as soon as the
// context is cancelled for any reason (deadline, parent, or explicit Cancel).
func (bc *BifrostContext) watchCancellation() {
// timer stays nil when there is no effective deadline; receiving from a nil
// channel blocks forever, so the corresponding select case never fires.
var timer <-chan time.Time
// Use effective deadline (considers both own and parent deadlines)
// This handles cases where parent has a deadline but doesn't properly
// cancel via Done() (e.g., fasthttp.RequestCtx)
if effectiveDeadline, hasDeadline := bc.Deadline(); hasDeadline {
duration := time.Until(effectiveDeadline)
if duration <= 0 {
// Deadline already passed
bc.cancel(context.DeadlineExceeded)
return
}
t := time.NewTimer(duration)
// Release the timer when the watcher exits, whichever select case fired.
defer t.Stop()
timer = t.C
}
// Don't watch parent.Done() for contexts known to never close it
// (e.g., fasthttp.RequestCtx pools contexts and never cancels them)
if isNonCancellingContext(bc.parent) {
select {
case <-timer:
bc.cancel(context.DeadlineExceeded)
case <-bc.done:
// Already cancelled
}
return
}
select {
case <-bc.parent.Done():
// Propagate the parent's own cancellation cause.
bc.cancel(bc.parent.Err())
case <-timer:
bc.cancel(context.DeadlineExceeded)
case <-bc.done:
// Already cancelled
}
}
// cancel closes the done channel and sets the error.
// The body runs at most once per context (guarded by doneOnce), so the first
// cancellation cause wins and the channel is never closed twice.
func (bc *BifrostContext) cancel(err error) {
bc.doneOnce.Do(func() {
// Record the error before closing done, so a reader woken by Done()
// observes a non-nil Err().
bc.errMu.Lock()
bc.err = err
bc.errMu.Unlock()
close(bc.done)
})
}
// Deadline reports the effective deadline of this context.
// Scoped contexts forward the call to their root context. Otherwise the result
// is the earlier of this context's own deadline and the parent's deadline; if
// neither exists, the zero time and false are returned.
func (bc *BifrostContext) Deadline() (time.Time, bool) {
	if root := bc.valueDelegate; root != nil {
		return root.Deadline()
	}

	inherited, hasInherited := bc.parent.Deadline()
	switch {
	case bc.hasDeadline && hasInherited:
		// Both are set: the earlier one wins; on a tie the parent's is returned.
		if bc.deadline.Before(inherited) {
			return bc.deadline, true
		}
		return inherited, true
	case bc.hasDeadline:
		return bc.deadline, true
	case hasInherited:
		return inherited, true
	default:
		return time.Time{}, false
	}
}
// Done returns a channel that is closed when the context is cancelled.
// For scoped contexts this is the root's channel, which WithPluginScope copies
// into the scoped context.
func (bc *BifrostContext) Done() <-chan struct{} {
return bc.done
}
// Err reports why the context was cancelled, or nil if it is still active.
// Scoped contexts forward the call to their root context.
func (bc *BifrostContext) Err() error {
	if root := bc.valueDelegate; root != nil {
		return root.Err()
	}
	bc.errMu.RLock()
	cause := bc.err
	bc.errMu.RUnlock()
	return cause
}
// Value looks up key and returns the associated value.
// Scoped contexts forward the lookup to their root context. Otherwise the
// context's own userValues map is consulted first; if the key is absent there,
// the parent context (when present) is asked.
func (bc *BifrostContext) Value(key any) any {
	if root := bc.valueDelegate; root != nil {
		return root.Value(key)
	}

	bc.valuesMu.RLock()
	stored, found := bc.userValues[key]
	bc.valuesMu.RUnlock()

	switch {
	case found:
		// An entry stored as nil still counts as found and shadows the parent.
		return stored
	case bc.parent == nil:
		return nil
	default:
		return bc.parent.Value(key)
	}
}
// SetValue stores value under key in this context's userValues map.
// Scoped contexts forward the write to their root context. Writes to reserved
// keys are silently dropped while restricted writes are blocked.
// The map is guarded by valuesMu, so concurrent calls are safe.
func (bc *BifrostContext) SetValue(key, value any) {
	if root := bc.valueDelegate; root != nil {
		root.SetValue(key, value)
		return
	}

	// Reserved keys become read-only once BlockRestrictedWrites was called.
	writeProtected := bc.blockRestrictedWrites.Load() && slices.Contains(reservedKeys, key)
	if writeProtected {
		return
	}

	bc.valuesMu.Lock()
	defer bc.valuesMu.Unlock()
	// The map may be nil on contexts not built by NewBifrostContext.
	if bc.userValues == nil {
		bc.userValues = make(map[any]any)
	}
	bc.userValues[key] = value
}
// ClearValue clears the value stored under key in this context's userValues map.
// Scoped contexts forward the call to their root context. Reserved keys are
// left untouched while restricted writes are blocked.
//
// The entry is overwritten with nil rather than deleted, so the key stays
// present in the map: Value(key) returns nil without consulting the parent.
func (bc *BifrostContext) ClearValue(key any) {
	if root := bc.valueDelegate; root != nil {
		root.ClearValue(key)
		return
	}

	writeProtected := bc.blockRestrictedWrites.Load() && slices.Contains(reservedKeys, key)
	if writeProtected {
		return
	}

	bc.valuesMu.Lock()
	defer bc.valuesMu.Unlock()
	if bc.userValues == nil {
		return
	}
	bc.userValues[key] = nil
}
// GetAndSetValue stores value under key and returns the value previously held
// in this context's userValues map (nil if there was none). The read and the
// write happen under a single lock acquisition.
// Scoped contexts forward the call to their root context. For reserved keys,
// while restricted writes are blocked, nothing is written and the current
// value is returned.
func (bc *BifrostContext) GetAndSetValue(key any, value any) any {
	if root := bc.valueDelegate; root != nil {
		return root.GetAndSetValue(key, value)
	}

	bc.valuesMu.Lock()
	defer bc.valuesMu.Unlock()

	previous := bc.userValues[key] // reading a nil map is safe and yields nil
	if bc.blockRestrictedWrites.Load() && slices.Contains(reservedKeys, key) {
		// Reserved key under write protection: report, but do not modify.
		return previous
	}

	if bc.userValues == nil {
		bc.userValues = make(map[any]any)
	}
	bc.userValues[key] = value
	return previous
}
// GetUserValues returns a copy of all user-set values visible in this context.
//
// Scoped plugin contexts hold no values of their own (their reads and writes
// are forwarded to the root context), so for them the call is forwarded to the
// root as well; without this a scoped context would skip the root's values.
//
// If the parent is also a *BifrostContext, its user values are merged in first
// and then overlaid with this context's values, so this context's values take
// precedence. The returned map is a fresh copy that the caller may modify.
func (bc *BifrostContext) GetUserValues() map[any]any {
	if bc.valueDelegate != nil {
		return bc.valueDelegate.GetUserValues()
	}
	result := make(map[any]any)
	// First, collect the parent's user values if the parent is a BifrostContext.
	if parentCtx, ok := bc.parent.(*BifrostContext); ok {
		for k, v := range parentCtx.GetUserValues() {
			result[k] = v
		}
	}
	// Then overlay our own values (ours take precedence).
	bc.valuesMu.RLock()
	for k, v := range bc.userValues {
		result[k] = v
	}
	bc.valuesMu.RUnlock()
	return result
}
// GetParentCtxWithUserValues returns a context derived from the parent context
// with every user-set value of this context attached via context.WithValue.
//
// Scoped plugin contexts hold no values of their own (their reads and writes
// are forwarded to the root context), so for them the call is forwarded to the
// root; without this the result would carry none of the values set through
// the scoped context.
//
// The parent itself is not modified. Values are attached in map iteration
// order, which is unspecified; since keys are unique this does not affect lookups.
func (bc *BifrostContext) GetParentCtxWithUserValues() context.Context {
	if bc.valueDelegate != nil {
		return bc.valueDelegate.GetParentCtxWithUserValues()
	}
	parentCtx := bc.parent
	bc.valuesMu.RLock()
	for k, v := range bc.userValues {
		parentCtx = context.WithValue(parentCtx, k, v)
	}
	bc.valuesMu.RUnlock()
	return parentCtx
}
// AppendRoutingEngineLog records a routing engine log entry on the context,
// timestamped with the current time in Unix milliseconds.
// Parameters:
//   - engineName: Name of the routing engine (e.g., "governance", "routing-rule")
//   - message: Human-readable log message describing the decision/action
func (bc *BifrostContext) AppendRoutingEngineLog(engineName string, message string) {
	AppendToContextList(bc, BifrostContextKeyRoutingEngineLogs, RoutingEngineLogEntry{
		Engine:    engineName,
		Message:   message,
		Timestamp: time.Now().UnixMilli(),
	})
}
// GetRoutingEngineLogs returns the routing engine log entries stored on the context.
//
// Returns:
//   - []RoutingEngineLogEntry: the stored entries, or nil when nothing is stored
//     under the routing-engine-logs key or the stored value has another type
func (bc *BifrostContext) GetRoutingEngineLogs() []RoutingEngineLogEntry {
	// A failed (or nil) assertion leaves entries as the nil slice.
	entries, _ := bc.Value(BifrostContextKeyRoutingEngineLogs).([]RoutingEngineLogEntry)
	return entries
}
// AppendToContextList appends value to the []T list stored under key on ctx.
// If nothing is stored under key, or the stored value is not a []T, a new list
// containing only value is stored. A nil ctx is a no-op.
//
// The list is read with Value and written back with SetValue as two separate
// steps, so the append as a whole is not atomic with respect to other writers
// of the same key.
//
// Parameters:
//   - ctx: The Bifrost context
//   - key: The key whose list is extended
//   - value: The value to append
func AppendToContextList[T any](ctx *BifrostContext, key BifrostContextKey, value T) {
	if ctx == nil {
		return
	}
	list := []T{}
	if current, isList := ctx.Value(key).([]T); isList {
		list = current
	}
	list = append(list, value)
	ctx.SetValue(key, list)
}
// WithPluginScope hands out a lightweight scoped BifrostContext taken from the pool.
// The scoped context shares this context's plugin log store, parent and Done
// channel, and forwards all Value/SetValue operations to this context.
// Call ReleasePluginScope() on the result when done so it returns to the pool.
func (bc *BifrostContext) WithPluginScope(name *string) *BifrostContext {
	// Create the shared log store on first use. CompareAndSwap guarantees that
	// only one store is installed when several goroutines race here.
	if bc.pluginLogs.Load() == nil {
		fresh := pluginLogStorePool.Get().(*pluginLogStore)
		if installed := bc.pluginLogs.CompareAndSwap(nil, fresh); !installed {
			// Someone else won the race; give the unused store back.
			pluginLogStorePool.Put(fresh)
		}
	}

	child := pluginScopePool.Get().(*BifrostContext)
	child.valueDelegate = bc
	child.pluginScope = name
	child.parent, child.done = bc.parent, bc.done
	child.pluginLogs.Store(bc.pluginLogs.Load())
	return child
}
// ReleasePluginScope resets a scoped context and puts it back into the pool.
// Calling it on a non-scoped context (one without a value delegate) does nothing.
// The scoped context must not be used after this call.
func (bc *BifrostContext) ReleasePluginScope() {
	isScoped := bc.valueDelegate != nil
	if !isScoped {
		return
	}
	// Drop every reference taken in WithPluginScope so the pooled object does
	// not keep the root context or its log store reachable.
	bc.valueDelegate = nil
	bc.pluginLogs.Store(nil)
	bc.pluginScope = nil
	bc.parent, bc.done = nil, nil
	pluginScopePool.Put(bc)
}
// Log records a structured log entry attributed to the current plugin scope,
// timestamped with the current time in Unix milliseconds.
// It does nothing when the context has no plugin scope or no log store.
func (bc *BifrostContext) Log(level LogLevel, msg string) {
	store := bc.pluginLogs.Load()
	if store == nil || bc.pluginScope == nil {
		return
	}

	store.mu.Lock()
	defer store.mu.Unlock()
	entry := PluginLogEntry{
		PluginName: *bc.pluginScope,
		Level:      level,
		Message:    msg,
		Timestamp:  time.Now().UnixMilli(),
	}
	store.logs = append(store.logs, entry)
}
// GetPluginLogs returns a copy of the accumulated plugin log entries, taken
// under the store's lock so it is safe to call concurrently with Log.
// The returned slice is independent of the internal one.
// Returns nil when there is no log store or no entry has been recorded.
func (bc *BifrostContext) GetPluginLogs() []PluginLogEntry {
	store := bc.pluginLogs.Load()
	if store == nil {
		return nil
	}

	store.mu.Lock()
	defer store.mu.Unlock()

	count := len(store.logs)
	if count == 0 {
		return nil
	}
	snapshot := make([]PluginLogEntry, count)
	copy(snapshot, store.logs)
	return snapshot
}
// DrainPluginLogs transfers ownership of the plugin log slice to the caller.
// The internal log store is detached from the context and returned to the pool
// after draining, so a following GetPluginLogs or DrainPluginLogs returns nil
// until a new store is created by WithPluginScope.
// Returns nil if no logs have been recorded, or if called on a scoped context.
// This should be called once on the root context after all plugin hooks have completed.
func (bc *BifrostContext) DrainPluginLogs() []PluginLogEntry {
	if bc.valueDelegate != nil {
		return nil // scoped contexts must not drain the shared log store
	}
	// Detach the store atomically. A separate Load followed by Store(nil)
	// would let two concurrent drains obtain the same store and put it into
	// the pool twice; with Swap only one caller ever sees a non-nil store.
	store := bc.pluginLogs.Swap(nil)
	if store == nil {
		return nil
	}
	store.mu.Lock()
	logs := store.logs
	// Give the store a fresh pre-allocated slice before pooling it, because the
	// old backing array now belongs to the caller.
	store.logs = make([]PluginLogEntry, 0, 8)
	store.mu.Unlock()
	// Return the store to the pool for reuse
	pluginLogStorePool.Put(store)
	if len(logs) == 0 {
		return nil
	}
	return logs
}

View File

@@ -0,0 +1,12 @@
//go:build !tinygo && !wasm
package schemas
import "github.com/valyala/fasthttp"
// isNonCancellingContext reports whether parent is a context type known to
// expose a Done() channel that never closes. In this build that is exactly
// *fasthttp.RequestCtx.
func isNonCancellingContext(parent any) bool {
	switch parent.(type) {
	case *fasthttp.RequestCtx:
		return true
	default:
		return false
	}
}

View File

@@ -0,0 +1,331 @@
package schemas
import (
"context"
"runtime"
"testing"
"time"
)
// TestNewBifrostContext_NoGoroutineLeakWithBackgroundAndNoDeadline verifies that
// contexts built on context.Background() without a deadline do not start a
// watcher goroutine, and that they still report state and cancel correctly.
// The goroutine check compares runtime.NumGoroutine() before and after, with slack.
func TestNewBifrostContext_NoGoroutineLeakWithBackgroundAndNoDeadline(t *testing.T) {
// Get baseline goroutine count
runtime.GC()
time.Sleep(10 * time.Millisecond)
baseline := runtime.NumGoroutine()
// Create multiple contexts with context.Background() and no deadline
// Previously this would leak goroutines
contexts := make([]*BifrostContext, 100)
for i := 0; i < 100; i++ {
contexts[i] = NewBifrostContext(context.Background(), NoDeadline)
}
// Give time for any goroutines to start
runtime.Gosched()
time.Sleep(10 * time.Millisecond)
// Check goroutine count - should not have increased significantly
// (allow some slack for runtime/test goroutines)
afterCreate := runtime.NumGoroutine()
// With the fix, no goroutines should be spawned since there's nothing to watch
// Allow a small margin for test framework goroutines
if afterCreate > baseline+10 {
t.Errorf("Goroutine leak detected: baseline=%d, after creating 100 contexts=%d", baseline, afterCreate)
}
// Verify the contexts still work correctly
for i, ctx := range contexts {
// Should not be cancelled
select {
case <-ctx.Done():
t.Errorf("Context %d should not be done", i)
default:
// Expected
}
// Should return nil error
if ctx.Err() != nil {
t.Errorf("Context %d Err() should be nil, got %v", i, ctx.Err())
}
// Should have no deadline
if _, ok := ctx.Deadline(); ok {
t.Errorf("Context %d should not have deadline", i)
}
}
// Explicitly cancel all contexts
for _, ctx := range contexts {
ctx.Cancel()
}
// Verify all are cancelled
for i, ctx := range contexts {
select {
case <-ctx.Done():
// Expected
default:
t.Errorf("Context %d should be done after Cancel()", i)
}
if ctx.Err() != context.Canceled {
t.Errorf("Context %d Err() should be context.Canceled, got %v", i, ctx.Err())
}
}
}
// TestNewBifrostContext_GoroutineStartsWithDeadline verifies that a context
// created with a deadline starts a watcher goroutine, detected as an increase
// in runtime.NumGoroutine() over the baseline.
func TestNewBifrostContext_GoroutineStartsWithDeadline(t *testing.T) {
// Get baseline goroutine count
runtime.GC()
time.Sleep(10 * time.Millisecond)
baseline := runtime.NumGoroutine()
// Create context with a deadline - should spawn goroutine
deadline := time.Now().Add(1 * time.Hour)
ctx := NewBifrostContext(context.Background(), deadline)
// Give time for goroutine to start
runtime.Gosched()
time.Sleep(10 * time.Millisecond)
afterCreate := runtime.NumGoroutine()
// Should have at least one more goroutine for the deadline watcher
if afterCreate <= baseline {
t.Errorf("Expected goroutine to be spawned for deadline context: baseline=%d, after=%d", baseline, afterCreate)
}
// Clean up (cancelling lets the watcher goroutine exit well before the 1h deadline)
ctx.Cancel()
}
// TestNewBifrostContext_GoroutineStartsWithCancellableParent verifies that a
// context with a cancellable parent starts a watcher goroutine even without a
// deadline, and that cancelling the parent cancels the child with context.Canceled.
func TestNewBifrostContext_GoroutineStartsWithCancellableParent(t *testing.T) {
// Get baseline goroutine count
runtime.GC()
time.Sleep(10 * time.Millisecond)
baseline := runtime.NumGoroutine()
// Create a cancellable parent
parent, parentCancel := context.WithCancel(context.Background())
// Safety net in case the test exits early; calling the cancel func twice is harmless.
defer parentCancel()
// Create BifrostContext with cancellable parent but no deadline
// Should spawn goroutine to watch parent
ctx := NewBifrostContext(parent, NoDeadline)
// Give time for goroutine to start
runtime.Gosched()
time.Sleep(10 * time.Millisecond)
afterCreate := runtime.NumGoroutine()
// Should have goroutine for watching parent cancellation
if afterCreate <= baseline {
t.Errorf("Expected goroutine to be spawned for cancellable parent: baseline=%d, after=%d", baseline, afterCreate)
}
// Verify parent cancellation propagates
parentCancel()
// Propagation happens in the watcher goroutine, so give it a moment to run.
time.Sleep(10 * time.Millisecond)
select {
case <-ctx.Done():
// Expected
default:
t.Error("Context should be cancelled when parent is cancelled")
}
if ctx.Err() != context.Canceled {
t.Errorf("Context Err() should be context.Canceled, got %v", ctx.Err())
}
}
// TestNewBifrostContext_DeadlineExpires verifies that a context with a 50ms
// deadline is not done immediately, is done after the deadline has passed, and
// then reports context.DeadlineExceeded.
func TestNewBifrostContext_DeadlineExpires(t *testing.T) {
// Create context with short deadline
deadline := time.Now().Add(50 * time.Millisecond)
ctx := NewBifrostContext(context.Background(), deadline)
// Should not be done yet
select {
case <-ctx.Done():
t.Error("Context should not be done before deadline")
default:
// Expected
}
// Wait for deadline (twice the deadline, to leave room for scheduling delay)
time.Sleep(100 * time.Millisecond)
// Should be done now
select {
case <-ctx.Done():
// Expected
default:
t.Error("Context should be done after deadline")
}
if ctx.Err() != context.DeadlineExceeded {
t.Errorf("Context Err() should be context.DeadlineExceeded, got %v", ctx.Err())
}
}
// TestNewBifrostContext_SetAndGetValue checks that a value written with
// SetValue is returned by Value, and that an unknown key yields nil.
func TestNewBifrostContext_SetAndGetValue(t *testing.T) {
	ctx := NewBifrostContext(context.Background(), NoDeadline)
	defer ctx.Cancel() // clean up

	ctx.SetValue("key1", "value1")

	// A stored key must read back unchanged.
	stored := ctx.Value("key1")
	if stored != "value1" {
		t.Errorf("Expected value1, got %v", stored)
	}

	// A key that was never set must read as nil.
	missing := ctx.Value("nonexistent")
	if missing != nil {
		t.Errorf("Expected nil for non-existent key, got %v", missing)
	}
}
// TestNewBifrostContext_NilParent checks that a nil parent is accepted
// (NewBifrostContext substitutes context.Background()) and that the resulting
// context starts un-cancelled and can be cancelled normally.
func TestNewBifrostContext_NilParent(t *testing.T) {
	var noParent context.Context //nolint:staticcheck // testing nil parent handling
	ctx := NewBifrostContext(noParent, NoDeadline)

	// A fresh context has no error.
	if err := ctx.Err(); err != nil {
		t.Errorf("New context should have nil error, got %v", err)
	}

	ctx.Cancel()

	// After Cancel the error must be context.Canceled.
	if err := ctx.Err(); err != context.Canceled {
		t.Errorf("Cancelled context should have Canceled error, got %v", err)
	}
}
// Plugin logging tests
// TestPluginLog_NoScopeIsNoop checks that Log on a context without a plugin
// scope records nothing.
func TestPluginLog_NoScopeIsNoop(t *testing.T) {
	root := NewBifrostContext(context.Background(), NoDeadline)
	root.Log(LogLevelInfo, "should be ignored")

	if logs := root.GetPluginLogs(); logs != nil {
		t.Errorf("expected nil logs without plugin scope, got %v", logs)
	}
}
// TestPluginLog_SinglePlugin checks that entries logged through one scoped
// context are readable from the root in order, with plugin name, level and
// message preserved.
func TestPluginLog_SinglePlugin(t *testing.T) {
	root := NewBifrostContext(context.Background(), NoDeadline)

	pluginName := "test-plugin"
	scope := root.WithPluginScope(&pluginName)
	scope.Log(LogLevelInfo, "hello")
	scope.Log(LogLevelError, "oops")
	scope.ReleasePluginScope()

	logs := root.GetPluginLogs()
	if got := len(logs); got != 2 {
		t.Fatalf("expected 2 logs, got %d", got)
	}

	first, second := logs[0], logs[1]
	firstOK := first.PluginName == "test-plugin" && first.Level == LogLevelInfo && first.Message == "hello"
	if !firstOK {
		t.Errorf("unexpected first log: %+v", first)
	}
	secondOK := second.Level == LogLevelError && second.Message == "oops"
	if !secondOK {
		t.Errorf("unexpected second log: %+v", second)
	}
}
// TestPluginLog_MultiplePlugins checks that entries from two successive plugin
// scopes land in the same root log store, in logging order, each attributed to
// its own plugin.
func TestPluginLog_MultiplePlugins(t *testing.T) {
	root := NewBifrostContext(context.Background(), NoDeadline)

	nameA, nameB := "plugin-a", "plugin-b"

	scopeA := root.WithPluginScope(&nameA)
	scopeA.Log(LogLevelDebug, "a-msg")
	scopeA.ReleasePluginScope()

	scopeB := root.WithPluginScope(&nameB)
	scopeB.Log(LogLevelWarn, "b-msg")
	scopeB.ReleasePluginScope()

	logs := root.GetPluginLogs()
	if got := len(logs); got != 2 {
		t.Fatalf("expected 2 logs, got %d", got)
	}
	if got := logs[0].PluginName; got != "plugin-a" {
		t.Errorf("expected plugin-a, got %s", got)
	}
	if got := logs[1].PluginName; got != "plugin-b" {
		t.Errorf("expected plugin-b, got %s", got)
	}
}
// TestPluginLog_DrainTransfersOwnership checks that DrainPluginLogs hands the
// recorded entries to the caller and leaves the context empty: both
// GetPluginLogs and a second DrainPluginLogs must return nil afterwards.
func TestPluginLog_DrainTransfersOwnership(t *testing.T) {
	root := NewBifrostContext(context.Background(), NoDeadline)

	pluginName := "drain-test"
	scope := root.WithPluginScope(&pluginName)
	scope.Log(LogLevelInfo, "msg1")
	scope.ReleasePluginScope()

	if drained := root.DrainPluginLogs(); len(drained) != 1 {
		t.Fatalf("expected 1 drained log, got %d", len(drained))
	}

	// Nothing may remain readable once the logs were drained.
	if remaining := root.GetPluginLogs(); remaining != nil {
		t.Errorf("expected nil after drain, got %v", remaining)
	}

	// Draining again must yield nothing.
	if again := root.DrainPluginLogs(); again != nil {
		t.Errorf("expected nil on second drain, got %v", again)
	}
}
// TestPluginLog_ScopedContextValueDelegation checks that a scoped context reads
// values from, and writes values to, the root context it was derived from.
func TestPluginLog_ScopedContextValueDelegation(t *testing.T) {
	type testContextKey string
	const customKey testContextKey = "custom-key"

	root := NewBifrostContext(context.Background(), NoDeadline)
	root.SetValue(BifrostContextKeyTraceID, "trace-123")

	pluginName := "delegate-test"
	scope := root.WithPluginScope(&pluginName)

	// Reads through the scope must see values set on the root.
	if got := scope.Value(BifrostContextKeyTraceID); got != "trace-123" {
		t.Errorf("expected trace-123, got %v", got)
	}

	// Writes through the scope must land on the root.
	scope.SetValue(customKey, "custom-val")
	if root.Value(customKey) != "custom-val" {
		t.Errorf("SetValue on scoped did not delegate to root")
	}

	scope.ReleasePluginScope()
}
// TestPluginLog_PoolReuse acquires and releases 100 scoped contexts in sequence,
// logging once through each, and checks that all 100 entries reach the root's
// log store.
func TestPluginLog_PoolReuse(t *testing.T) {
	root := NewBifrostContext(context.Background(), NoDeadline)

	// Each iteration acquires a scoped context and releases it again, so later
	// iterations can be served from the pool.
	for round := 0; round < 100; round++ {
		pluginName := "pool-test"
		scope := root.WithPluginScope(&pluginName)
		scope.Log(LogLevelInfo, "pooled")
		scope.ReleasePluginScope()
	}

	if drained := root.DrainPluginLogs(); len(drained) != 100 {
		t.Errorf("expected 100 logs from pool reuse, got %d", len(drained))
	}
}

View File

@@ -0,0 +1,10 @@
//go:build tinygo || wasm
package schemas
// isNonCancellingContext returns true if the context is known to have
// a Done() channel that never closes. In wasm builds, fasthttp is not
// available, so this always returns false.
// The parent argument is ignored in this build.
func isNonCancellingContext(parent any) bool {
return false
}

View File

@@ -0,0 +1,14 @@
package schemas
// BifrostCountTokensResponse captures token counts for a provided input.
// Object, InputTokensDetails, TokenStrings and OutputTokens are omitted from
// JSON when empty; Tokens and TotalTokens carry no omitempty tag.
type BifrostCountTokensResponse struct {
Object string `json:"object,omitempty"`
Model string `json:"model"`
InputTokens int `json:"input_tokens"`
InputTokensDetails *ResponsesResponseInputTokens `json:"input_tokens_details,omitempty"`
Tokens []int `json:"tokens"` // presumably the token IDs of the input — confirm against providers
TokenStrings []string `json:"token_strings,omitempty"`
OutputTokens *int `json:"output_tokens,omitempty"`
// NOTE(review): no omitempty here, so with encoding/json a nil TotalTokens is
// encoded as null — confirm this is intended.
TotalTokens *int `json:"total_tokens"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}

187
core/schemas/embedding.go Normal file
View File

@@ -0,0 +1,187 @@
package schemas
import (
"fmt"
)
// BifrostEmbeddingRequest describes an embedding request: the target provider
// and model, the input to embed, optional parameters and optional fallbacks.
type BifrostEmbeddingRequest struct {
Provider ModelProvider `json:"provider"`
Model string `json:"model"`
Input *EmbeddingInput `json:"input,omitempty"`
Params *EmbeddingParameters `json:"params,omitempty"`
Fallbacks []Fallback `json:"fallbacks,omitempty"`
RawRequestBody []byte `json:"-"` // set bifrost-use-raw-request-body to true in ctx to use the raw request body. Bifrost will directly send this to the downstream provider.
}
// GetRawRequestBody returns the request's RawRequestBody as stored (no copy is made).
// It is nil when no raw body was set.
func (r *BifrostEmbeddingRequest) GetRawRequestBody() []byte {
return r.RawRequestBody
}
// BifrostEmbeddingResponse is the normalized response of an embedding request.
// None of its fields carry omitempty.
type BifrostEmbeddingResponse struct {
Data []EmbeddingData `json:"data"` // Maps to "data" field in provider responses (e.g., OpenAI embedding format)
Model string `json:"model"`
Object string `json:"object"` // "list"
Usage *BifrostLLMUsage `json:"usage"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// EmbeddingInput represents the input for an embedding request.
// It is a one-of: MarshalJSON requires exactly one field to be non-nil and
// encodes that field's value directly (not wrapped in an object);
// UnmarshalJSON fills the first field whose shape matches the JSON.
type EmbeddingInput struct {
Text *string // a single string
Texts []string // an array of strings
Embedding []int // an array of integers (presumably token IDs — confirm)
Embeddings [][]int // an array of integer arrays
}
// MarshalJSON encodes the single populated field of the input as its bare
// JSON value (string, string array, int array or array of int arrays).
// It returns an error when no field is set or when more than one is set.
// A field counts as set when it is non-nil, so an empty non-nil slice counts.
func (e *EmbeddingInput) MarshalJSON() ([]byte, error) {
	// Enforce the one-of contract by counting populated fields.
	populated := 0
	for _, isSet := range [...]bool{
		e.Text != nil,
		e.Texts != nil,
		e.Embedding != nil,
		e.Embeddings != nil,
	} {
		if isSet {
			populated++
		}
	}

	switch {
	case populated == 0:
		return nil, fmt.Errorf("embedding input is empty")
	case populated > 1:
		return nil, fmt.Errorf("embedding input must set exactly one of: text, texts, embedding, embeddings")
	case e.Text != nil:
		return MarshalSorted(*e.Text)
	case e.Texts != nil:
		return MarshalSorted(e.Texts)
	case e.Embedding != nil:
		return MarshalSorted(e.Embedding)
	case e.Embeddings != nil:
		return MarshalSorted(e.Embeddings)
	}
	// Not reachable: exactly one field is set at this point.
	return nil, fmt.Errorf("invalid embedding input")
}
// UnmarshalJSON decodes data into whichever field matches its JSON shape.
// All fields are reset first; the shapes are then tried in a fixed order
// (string, []string, []int, [][]int) and the first successful decode wins.
// An error is returned when none of the shapes matches.
func (e *EmbeddingInput) UnmarshalJSON(data []byte) error {
	e.Text, e.Texts, e.Embedding, e.Embeddings = nil, nil, nil, nil

	// Single string.
	var text string
	if Unmarshal(data, &text) == nil {
		e.Text = &text
		return nil
	}

	// Array of strings.
	var texts []string
	if Unmarshal(data, &texts) == nil {
		e.Texts = texts
		return nil
	}

	// Array of integers.
	var ints []int
	if Unmarshal(data, &ints) == nil {
		e.Embedding = ints
		return nil
	}

	// Array of integer arrays.
	var nested [][]int
	if Unmarshal(data, &nested) == nil {
		e.Embeddings = nested
		return nil
	}

	return fmt.Errorf("unsupported embedding input shape")
}
// EmbeddingParameters holds the optional parameters of an embedding request.
// Nil pointer fields are omitted from the JSON encoding.
type EmbeddingParameters struct {
EncodingFormat *string `json:"encoding_format,omitempty"` // Format for embedding output (e.g., "float", "base64")
Dimensions *int `json:"dimensions,omitempty"` // Number of dimensions for embedding output
// Dynamic parameters that can be provider-specific, they are directly
// added to the request as is.
// Tagged json:"-", so they are not part of this struct's own JSON encoding.
ExtraParams map[string]interface{} `json:"-"`
}
// EmbeddingData is one entry of BifrostEmbeddingResponse.Data.
type EmbeddingData struct {
Index int `json:"index"`
Object string `json:"object"` // "embedding"
Embedding EmbeddingStruct `json:"embedding"` // can be string, []float64, [][]float64, []int8, or []int32
}
// EmbeddingStruct holds one embedding in one of several representations.
// Its MarshalJSON emits the first non-nil field in declaration order and
// fails when every field is nil, so at most one field should be populated.
type EmbeddingStruct struct {
// Embedding responses preserve provider precision in normalized API output.
EmbeddingStr *string
EmbeddingArray []float64
Embedding2DArray [][]float64
EmbeddingInt8Array []int8 // for int8 / binary formats
EmbeddingInt32Array []int32 // for uint8 / ubinary formats
}
// MarshalJSON emits the first populated representation, checked in the order
// string, []float64, [][]float64, []int8, []int32. It returns an error when
// no representation is set.
func (be EmbeddingStruct) MarshalJSON() ([]byte, error) {
	switch {
	case be.EmbeddingStr != nil:
		return MarshalSorted(be.EmbeddingStr)
	case be.EmbeddingArray != nil:
		return MarshalSorted(be.EmbeddingArray)
	case be.Embedding2DArray != nil:
		return MarshalSorted(be.Embedding2DArray)
	case be.EmbeddingInt8Array != nil:
		return Marshal(be.EmbeddingInt8Array)
	case be.EmbeddingInt32Array != nil:
		return Marshal(be.EmbeddingInt32Array)
	default:
		return nil, fmt.Errorf("no embedding found")
	}
}
// UnmarshalJSON decodes data into the first representation that accepts it,
// trying in order: string, []float64, [][]float64, []int8, []int32.
//
// The receiver is cleared first (mirroring EmbeddingInput.UnmarshalJSON) so
// that decoding into a reused value cannot leave a previously decoded
// representation behind; MarshalJSON picks the first non-nil field, so a
// stale EmbeddingStr would otherwise shadow a freshly decoded array.
//
// It returns an error when data matches none of the supported shapes; in that
// case the receiver is left empty.
func (be *EmbeddingStruct) UnmarshalJSON(data []byte) error {
	*be = EmbeddingStruct{}

	// First, try to unmarshal as a direct string.
	var stringContent string
	if err := Unmarshal(data, &stringContent); err == nil {
		be.EmbeddingStr = &stringContent
		return nil
	}

	// Try to unmarshal as a direct array of float64.
	var arrayContent []float64
	if err := Unmarshal(data, &arrayContent); err == nil {
		be.EmbeddingArray = arrayContent
		return nil
	}

	// Try to unmarshal as a direct 2D array of float64.
	var arrayContent2D [][]float64
	if err := Unmarshal(data, &arrayContent2D); err == nil {
		be.Embedding2DArray = arrayContent2D
		return nil
	}

	// NOTE(review): an array of integers presumably already succeeds in the
	// []float64 attempt above, which would make the two integer attempts below
	// unreachable for well-formed input; confirm before relying on them.

	// Try to unmarshal as a direct array of int8.
	var int8Content []int8
	if err := Unmarshal(data, &int8Content); err == nil {
		be.EmbeddingInt8Array = int8Content
		return nil
	}

	// Try to unmarshal as a direct array of int32.
	var int32Content []int32
	if err := Unmarshal(data, &int32Content); err == nil {
		be.EmbeddingInt32Array = int32Content
		return nil
	}

	return fmt.Errorf("embedding field is neither a string, []float64, [][]float64, []int8, nor []int32")
}

353
core/schemas/envvar.go Normal file
View File

@@ -0,0 +1,353 @@
package schemas
import (
"database/sql/driver"
"fmt"
"os"
"strconv"
"strings"
"github.com/bytedance/sonic"
)
// EnvVar is a wrapper around a value that can be sourced from an environment variable.
type EnvVar struct {
Val string `json:"value"` // resolved value; left empty when an env reference could not be resolved
EnvVar string `json:"env_var"` // reference in "env.NAME" form when the value comes from the environment, otherwise ""
FromEnv bool `json:"from_env"` // true when the value is backed by the variable named in EnvVar
}
// NewEnvVar builds an EnvVar from a string. The input may be:
//   - a JSON object with both "value" and "env_var" keys, which is decoded
//     field by field;
//   - a string starting with "env." (optionally quoted), which is treated as
//     an environment variable reference and resolved with os.LookupEnv;
//   - anything else, which is stored as a plain value.
//
// A quoted input is unquoted with strconv.Unquote first, so
// "\"{\\\"key\\\":\\\"value\\\"}\"" becomes "{\"key\":\"value\"}".
func NewEnvVar(value string) *EnvVar {
	val := value
	if stripped, err := strconv.Unquote(value); err == nil {
		val = stripped
	}

	// Object form: only taken when the input is valid JSON carrying both keys.
	// NOTE(review): validity and decoding use the raw input while the key
	// lookups use the unquoted text; kept as-is, but worth confirming the mix
	// is intended.
	if sonic.Valid([]byte(value)) {
		valueNode, _ := sonic.Get([]byte(val), "value")
		envNode, _ := sonic.Get([]byte(val), "env_var")
		if valueNode.Exists() && envNode.Exists() {
			// A local defined type carries no methods, so decoding into it
			// cannot re-enter EnvVar.UnmarshalJSON.
			type envVarAlias EnvVar
			var parsed envVarAlias
			if sonic.Unmarshal([]byte(value), &parsed) == nil {
				out := &EnvVar{
					Val:     parsed.Val,
					EnvVar:  parsed.EnvVar,
					FromEnv: parsed.FromEnv,
				}
				// When the value merely repeats the reference, resolve it now.
				if out.Val == out.EnvVar && strings.HasPrefix(out.Val, "env.") {
					// LookupEnv yields "" for an unset variable, which is the
					// value wanted for an unresolved reference.
					out.Val, _ = os.LookupEnv(strings.TrimPrefix(out.EnvVar, "env."))
					out.FromEnv = true
				}
				return out
			}
		}
	}

	envKey, isRef := strings.CutPrefix(val, "env.")
	if !isRef {
		// Plain literal value.
		return &EnvVar{Val: val, FromEnv: false, EnvVar: ""}
	}

	// Environment reference: Val stays empty when the variable is not set.
	resolved, _ := os.LookupEnv(envKey)
	return &EnvVar{Val: resolved, FromEnv: true, EnvVar: val}
}
// IsRedacted reports whether the value looks like a redacted placeholder
// rather than a real value. It returns true when:
//   - the value is backed by an environment variable (FromEnv), or
//   - the value has at most 8 bytes and consists solely of asterisks, or
//   - the value has exactly 32 bytes whose bytes 4..27 are all asterisks
//     (the shape produced by Redacted for longer values), or
//   - the value is the literal "<redacted>".
//
// A nil receiver and an empty non-env value are not redacted. The nil check
// keeps this method consistent with the other nil-safe EnvVar accessors.
func (e *EnvVar) IsRedacted() bool {
	if e == nil {
		return false
	}
	// Env-backed values are always considered redacted.
	if e.FromEnv {
		return true
	}
	if e.Val == "" {
		return false
	}
	if len(e.Val) <= 8 {
		return strings.Count(e.Val, "*") == len(e.Val)
	}
	// Check for exact redaction pattern: 4 chars + 24 asterisks + 4 chars.
	if len(e.Val) == 32 && e.Val[4:28] == strings.Repeat("*", 24) {
		return true
	}
	return e.Val == "<redacted>"
}
// Equals reports whether two EnvVars carry the same value, the same env
// reference and the same FromEnv flag. Two nil pointers are equal; a nil and
// a non-nil pointer are not.
func (e *EnvVar) Equals(other *EnvVar) bool {
	switch {
	case e == nil && other == nil:
		return true
	case e == nil || other == nil:
		return false
	}
	sameSource := e.FromEnv == other.FromEnv && e.EnvVar == other.EnvVar
	return sameSource && e.Val == other.Val
}
// Redacted returns a copy of the EnvVar whose value is masked; the receiver
// is left untouched and the EnvVar/FromEnv fields are carried over as-is.
//
// Masking rules, by byte length of Val:
//   - 0: the value stays empty;
//   - 1..8: every byte is replaced by an asterisk;
//   - 9+: the first 4 and last 4 bytes are kept around a fixed run of 24
//     asterisks, so the result is always 32 bytes long.
//
// A nil receiver yields nil.
//
// NOTE(review): for values only slightly longer than 8 bytes the kept prefix
// and suffix cover most of the secret (a 9-byte value keeps 8 bytes visible).
func (e *EnvVar) Redacted() *EnvVar {
	if e == nil {
		return nil
	}

	var masked string
	switch n := len(e.Val); {
	case n == 0:
		masked = ""
	case n <= 8:
		masked = strings.Repeat("*", n)
	default:
		masked = e.Val[:4] + strings.Repeat("*", 24) + e.Val[n-4:]
	}

	return &EnvVar{
		Val:     masked,
		FromEnv: e.FromEnv,
		EnvVar:  e.EnvVar,
	}
}
// MarshalJSON serializes the EnvVar to JSON.
//
// SECURITY: a value that was sourced from an environment variable
// (FromEnv=true) is masked via Redacted() before it is written, while the
// env_var reference and the from_env flag are emitted unchanged so a consumer
// can still tell which variable backs the field. This holds regardless of
// whether the calling code remembered to call Redacted() itself.
//
// Plain (non-env) values are emitted as-is; callers that want those masked
// must call Redacted() on the field themselves.
//
// The receiver is a value, so marshaling never alters the caller's EnvVar.
// Other paths in this file are unaffected: Value() stores the env reference
// or raw value, and GetValue() returns Val directly.
func (e EnvVar) MarshalJSON() ([]byte, error) {
	// A local defined type has no methods, so encoding it does not recurse
	// back into this MarshalJSON.
	type plainEnvVar EnvVar

	if !e.FromEnv {
		return sonic.Marshal(plainEnvVar(e))
	}
	if masked := e.Redacted(); masked != nil {
		return sonic.Marshal(plainEnvVar(*masked))
	}
	return sonic.Marshal(plainEnvVar(e))
}
// UnmarshalJSON populates the EnvVar from one of three JSON shapes:
//   - an object carrying both "value" and "env_var" keys, decoded field by
//     field (and resolved from the environment when value repeats an "env."
//     reference);
//   - a string starting with "env.", treated as an environment variable
//     reference and resolved with os.LookupEnv (Val stays empty when the
//     variable is unset);
//   - anything else, stored as a plain value with FromEnv=false.
//
// JSON strings are decoded with the JSON decoder first. strconv.Unquote
// follows Go's escape rules, which reject escapes that are legal in JSON
// (for example "\/" and UTF-16 surrogate pairs such as "\ud83d\ude00"); for
// those inputs the raw quoted text used to end up in Val. strconv.Unquote is
// kept as a fallback for input the JSON decoder does not accept.
//
// It never returns a non-nil error.
func (e *EnvVar) UnmarshalJSON(data []byte) error {
	val := string(data)

	// Unescape exactly once, so "\"{\\\"key\\\":\\\"value\\\"}\"" becomes
	// "{\"key\":\"value\"}".
	var decoded string
	if strings.HasPrefix(strings.TrimSpace(val), `"`) && sonic.Unmarshal(data, &decoded) == nil {
		val = decoded
	} else if unquoted, err := strconv.Unquote(val); err == nil {
		val = unquoted
	}

	// Object form: only taken when the input is valid JSON carrying both keys.
	if sonic.Valid(data) {
		valueNode, _ := sonic.Get(data, "value")
		envNode, _ := sonic.Get(data, "env_var")
		if valueNode.Exists() && envNode.Exists() {
			// A local defined type carries no methods, so decoding into it
			// cannot recurse back into this UnmarshalJSON.
			type envVarAlias EnvVar
			var envVar envVarAlias
			if err := sonic.Unmarshal(data, &envVar); err == nil {
				e.Val = envVar.Val
				e.FromEnv = envVar.FromEnv
				e.EnvVar = envVar.EnvVar
				// When the value merely repeats the reference, resolve it now.
				if strings.HasPrefix(e.Val, "env.") && e.Val == e.EnvVar {
					// LookupEnv yields "" for an unset variable, which is the
					// value wanted for an unresolved reference.
					e.Val, _ = os.LookupEnv(strings.TrimPrefix(e.EnvVar, "env."))
					e.FromEnv = true
				}
				return nil
			}
			// Otherwise the object does not fit the schema; fall through and
			// treat the input as an ordinary value.
		}
	}

	if envKey, ok := strings.CutPrefix(val, "env."); ok {
		// Environment reference: Val stays empty when the variable is unset.
		e.Val, _ = os.LookupEnv(envKey)
		e.FromEnv = true
		e.EnvVar = val
		return nil
	}

	e.Val = val
	e.FromEnv = false
	e.EnvVar = ""
	return nil
}
// String returns the value as a string.
// The value is returned unmasked (it is Val itself, not the Redacted form), so
// anything that formats an *EnvVar through this method receives the raw value.
// Unlike GetValue, this method does not guard against a nil receiver.
func (e *EnvVar) String() string {
return e.Val
}
// Scan loads the EnvVar from a database value.
//
// A nil value resets the receiver to its zero state. A string or []byte has
// all leading and trailing double quotes trimmed (so "\"env.TEST\"",
// "env.TEST" and "env.TEST\"" all become "env.TEST"); if the result starts
// with "env." it is kept as the reference and resolved with os.LookupEnv
// (Val stays empty when the variable is unset), otherwise it is stored as a
// plain value. Any other type is rejected with an error.
func (e *EnvVar) Scan(value any) error {
	var raw string
	switch v := value.(type) {
	case nil:
		*e = EnvVar{}
		return nil
	case []byte:
		raw = string(v)
	case string:
		raw = v
	default:
		return fmt.Errorf("failed to scan value: %v", value)
	}

	cleaned := strings.Trim(raw, "\"")
	envKey, isRef := strings.CutPrefix(cleaned, "env.")
	if !isRef {
		*e = EnvVar{Val: cleaned, FromEnv: false, EnvVar: ""}
		return nil
	}

	// LookupEnv yields "" for an unset variable, which is the value wanted
	// for an unresolved reference.
	resolved, _ := os.LookupEnv(envKey)
	*e = EnvVar{Val: resolved, FromEnv: true, EnvVar: cleaned}
	return nil
}
// Value implements driver.Valuer for database storage.
// An env-backed EnvVar is stored as its reference (e.g. "env.API_KEY"), never
// as the resolved value; any other EnvVar is stored as its raw value.
// The returned error is always nil.
func (e EnvVar) Value() (driver.Value, error) {
	stored := e.Val
	if e.FromEnv {
		stored = e.EnvVar
	}
	return stored, nil
}
// IsFromEnv returns true if the value is sourced from an environment variable.
// A nil receiver reports false, matching the other nil-safe EnvVar accessors
// (GetValue, IsSet, IsDefined) instead of panicking.
func (e *EnvVar) IsFromEnv() bool {
	return e != nil && e.FromEnv
}
// IsSet reports whether the EnvVar carries either a resolved value or an
// environment variable reference. Prefer it over GetValue() != "" when
// checking whether a field was configured: an env reference can have an empty
// Val when the variable is not available in the current environment.
// A nil receiver is not set.
func (e *EnvVar) IsSet() bool {
	return e != nil && (e.Val != "" || e.EnvVar != "")
}
// GetValue returns the stored value (unmasked), or "" for a nil receiver.
func (e *EnvVar) GetValue() string {
	if e != nil {
		return e.Val
	}
	return ""
}
// GetValuePtr returns a pointer to the stored value, or nil for a nil
// receiver. The pointer addresses the Val field itself, so writes through it
// modify the EnvVar.
func (e *EnvVar) GetValuePtr() *string {
	if e != nil {
		return &e.Val
	}
	return nil
}
// CoerceInt parses the value as a base-10 int via strconv.Atoi.
// defaultValue is returned for a nil receiver or when parsing fails.
func (e *EnvVar) CoerceInt(defaultValue int) int {
	if e == nil {
		return defaultValue
	}
	if parsed, err := strconv.Atoi(e.GetValue()); err == nil {
		return parsed
	}
	return defaultValue
}
// CoerceBool parses the value as a bool via strconv.ParseBool.
// defaultValue is returned for a nil receiver or when parsing fails.
func (e *EnvVar) CoerceBool(defaultValue bool) bool {
	if e == nil {
		return defaultValue
	}
	if parsed, err := strconv.ParseBool(e.GetValue()); err == nil {
		return parsed
	}
	return defaultValue
}
// IsDefined reports whether the EnvVar has a source: for an env-backed EnvVar
// that means a non-empty reference, otherwise a non-empty static value.
// A nil receiver is not defined.
func (e *EnvVar) IsDefined() bool {
	switch {
	case e == nil:
		return false
	case e.IsFromEnv():
		return e.EnvVar != ""
	default:
		return e.Val != ""
	}
}

609
core/schemas/envvar_test.go Normal file
View File

@@ -0,0 +1,609 @@
package schemas
import (
"encoding/json"
"os"
"testing"
)
// TestEnvVar_UnmarshalJSON_DoubleEscapedJSON checks that quoted JSON strings,
// including ones that themselves hold an escaped JSON document, are unescaped
// exactly once and stored as plain (non-env) values.
func TestEnvVar_UnmarshalJSON_DoubleEscapedJSON(t *testing.T) {
	type testCase struct {
		name     string
		input    string
		expected string
	}
	cases := []testCase{
		{
			name:     "service account credentials with escaped JSON",
			input:    `"{\"type\":\"service_account\",\"project_id\":\"test-project\"}"`,
			expected: `{"type":"service_account","project_id":"test-project"}`,
		},
		{
			name:     "nested JSON object with multiple levels of escaping",
			input:    `"{\"key\":\"value\",\"nested\":{\"inner\":\"data\"}}"`,
			expected: `{"key":"value","nested":{"inner":"data"}}`,
		},
		{
			name:     "JSON with escaped newlines in private key",
			input:    `"{\"private_key\":\"-----BEGIN PRIVATE KEY-----\\nMIIE...\\n-----END PRIVATE KEY-----\\n\"}"`,
			expected: `{"private_key":"-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----\n"}`,
		},
		{
			name:     "simple string value",
			input:    `"sk-test-api-key-12345"`,
			expected: "sk-test-api-key-12345",
		},
		{
			name:     "empty string",
			input:    `""`,
			expected: "",
		},
		{
			name:     "string with special characters",
			input:    `"hello\"world"`,
			expected: `hello"world`,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var got EnvVar
			if err := got.UnmarshalJSON([]byte(tc.input)); err != nil {
				t.Fatalf("UnmarshalJSON failed: %v", err)
			}
			if got.Val != tc.expected {
				t.Errorf("Expected Val=%q, got Val=%q", tc.expected, got.Val)
			}
			if got.FromEnv {
				t.Errorf("Expected FromEnv=false, got FromEnv=true")
			}
		})
	}
}
// TestEnvVar_UnmarshalJSON_EnvVarReference checks that "env.NAME" strings are
// kept as the reference and resolved from the environment, with an empty Val
// when the variable is not set.
func TestEnvVar_UnmarshalJSON_EnvVarReference(t *testing.T) {
	// Provide the variable the first case resolves against.
	os.Setenv("TEST_API_KEY", "actual-api-key-value")
	defer os.Unsetenv("TEST_API_KEY")

	type testCase struct {
		name            string
		input           string
		expectedVal     string
		expectedEnvVar  string
		expectedFromEnv bool
	}
	cases := []testCase{
		{
			name:            "env var reference with value present",
			input:           `"env.TEST_API_KEY"`,
			expectedVal:     "actual-api-key-value",
			expectedEnvVar:  "env.TEST_API_KEY",
			expectedFromEnv: true,
		},
		{
			name:            "env var reference with missing value",
			input:           `"env.NONEXISTENT_VAR"`,
			expectedVal:     "",
			expectedEnvVar:  "env.NONEXISTENT_VAR",
			expectedFromEnv: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var got EnvVar
			if err := got.UnmarshalJSON([]byte(tc.input)); err != nil {
				t.Fatalf("UnmarshalJSON failed: %v", err)
			}
			if got.Val != tc.expectedVal {
				t.Errorf("Expected Val=%q, got Val=%q", tc.expectedVal, got.Val)
			}
			if got.EnvVar != tc.expectedEnvVar {
				t.Errorf("Expected EnvVar=%q, got EnvVar=%q", tc.expectedEnvVar, got.EnvVar)
			}
			if got.FromEnv != tc.expectedFromEnv {
				t.Errorf("Expected FromEnv=%v, got FromEnv=%v", tc.expectedFromEnv, got.FromEnv)
			}
		})
	}
}
// TestEnvVar_UnmarshalJSON_FullStructure checks decoding of the object form,
// i.e. input that is already an EnvVar JSON object with all three keys.
func TestEnvVar_UnmarshalJSON_FullStructure(t *testing.T) {
	const payload = `{"value":"my-api-key","env_var":"env.MY_KEY","from_env":true}`

	var decoded EnvVar
	if err := decoded.UnmarshalJSON([]byte(payload)); err != nil {
		t.Fatalf("UnmarshalJSON failed: %v", err)
	}

	if want := "my-api-key"; decoded.Val != want {
		t.Errorf("Expected Val=%q, got Val=%q", want, decoded.Val)
	}
	if want := "env.MY_KEY"; decoded.EnvVar != want {
		t.Errorf("Expected EnvVar=%q, got EnvVar=%q", want, decoded.EnvVar)
	}
	if !decoded.FromEnv {
		t.Errorf("Expected FromEnv=true, got FromEnv=false")
	}
}
// TestNewEnvVar_DoubleEscapedJSON checks that NewEnvVar strips one level of
// quoting/escaping from quoted input and leaves unquoted input untouched.
func TestNewEnvVar_DoubleEscapedJSON(t *testing.T) {
	type testCase struct {
		name     string
		input    string
		expected string
	}
	cases := []testCase{
		{
			name:     "service account credentials with escaped JSON",
			input:    `"{\"type\":\"service_account\",\"project_id\":\"test-project\"}"`,
			expected: `{"type":"service_account","project_id":"test-project"}`,
		},
		{
			name:     "JSON with escaped newlines",
			input:    `"{\"private_key\":\"-----BEGIN-----\\nDATA\\n-----END-----\\n\"}"`,
			expected: `{"private_key":"-----BEGIN-----\nDATA\n-----END-----\n"}`,
		},
		{
			name:     "simple string without quotes",
			input:    "sk-test-api-key",
			expected: "sk-test-api-key",
		},
		{
			name:     "simple string with outer quotes",
			input:    `"sk-test-api-key"`,
			expected: "sk-test-api-key",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := NewEnvVar(tc.input); got.Val != tc.expected {
				t.Errorf("Expected Val=%q, got Val=%q", tc.expected, got.Val)
			}
		})
	}
}
// TestNewEnvVar_EnvVarReference checks that NewEnvVar recognizes "env.NAME"
// input (quoted or not), keeps it as the reference and resolves it from the
// environment, with an empty Val when the variable is not set.
func TestNewEnvVar_EnvVarReference(t *testing.T) {
	// Provide the variable the first two cases resolve against.
	os.Setenv("TEST_NEW_ENVVAR_KEY", "resolved-value")
	defer os.Unsetenv("TEST_NEW_ENVVAR_KEY")

	type testCase struct {
		name            string
		input           string
		expectedVal     string
		expectedEnvVar  string
		expectedFromEnv bool
	}
	cases := []testCase{
		{
			name:            "env var reference with value present",
			input:           "env.TEST_NEW_ENVVAR_KEY",
			expectedVal:     "resolved-value",
			expectedEnvVar:  "env.TEST_NEW_ENVVAR_KEY",
			expectedFromEnv: true,
		},
		{
			name:            "env var reference with quotes",
			input:           `"env.TEST_NEW_ENVVAR_KEY"`,
			expectedVal:     "resolved-value",
			expectedEnvVar:  "env.TEST_NEW_ENVVAR_KEY",
			expectedFromEnv: true,
		},
		{
			name:            "env var reference missing",
			input:           "env.MISSING_VAR",
			expectedVal:     "",
			expectedEnvVar:  "env.MISSING_VAR",
			expectedFromEnv: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := NewEnvVar(tc.input)
			if got.Val != tc.expectedVal {
				t.Errorf("Expected Val=%q, got Val=%q", tc.expectedVal, got.Val)
			}
			if got.EnvVar != tc.expectedEnvVar {
				t.Errorf("Expected EnvVar=%q, got EnvVar=%q", tc.expectedEnvVar, got.EnvVar)
			}
			if got.FromEnv != tc.expectedFromEnv {
				t.Errorf("Expected FromEnv=%v, got FromEnv=%v", tc.expectedFromEnv, got.FromEnv)
			}
		})
	}
}
// TestEnvVar_RealWorldVertexCredentials covers the use case behind the
// double-escaping bug: a config document whose auth_credentials field holds a
// Vertex AI service account JSON embedded as an escaped string.
func TestEnvVar_RealWorldVertexCredentials(t *testing.T) {
	// Mirrors the shape of a key entry parsed from config.json.
	type VertexKeyConfig struct {
		ProjectID       EnvVar `json:"project_id"`
		Region          EnvVar `json:"region"`
		AuthCredentials EnvVar `json:"auth_credentials"`
	}

	jsonInput := `{
"project_id": "my-project",
"region": "us-central1",
"auth_credentials": "{\"type\":\"service_account\",\"project_id\":\"my-project\",\"private_key_id\":\"abc123\",\"private_key\":\"-----BEGIN PRIVATE KEY-----\\nMIIE...\\n-----END PRIVATE KEY-----\\n\",\"client_email\":\"test@my-project.iam.gserviceaccount.com\"}"
}`

	var parsed VertexKeyConfig
	if err := json.Unmarshal([]byte(jsonInput), &parsed); err != nil {
		t.Fatalf("Failed to unmarshal: %v", err)
	}

	// The embedded credentials must come out unescaped exactly once.
	expectedAuthCreds := `{"type":"service_account","project_id":"my-project","private_key_id":"abc123","private_key":"-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----\n","client_email":"test@my-project.iam.gserviceaccount.com"}`
	if got := parsed.AuthCredentials.Val; got != expectedAuthCreds {
		t.Errorf("AuthCredentials not properly unescaped.\nExpected: %s\nGot: %s", expectedAuthCreds, got)
	}

	// Plain string fields pass through unchanged.
	if got := parsed.ProjectID.Val; got != "my-project" {
		t.Errorf("Expected ProjectID=%q, got %q", "my-project", got)
	}
	if got := parsed.Region.Val; got != "us-central1" {
		t.Errorf("Expected Region=%q, got %q", "us-central1", got)
	}
}
// TestEnvVar_MixedConfigParsing parses a config that combines an env var
// reference with embedded JSON credentials and checks that each field is
// handled according to its own shape.
func TestEnvVar_MixedConfigParsing(t *testing.T) {
	os.Setenv("TEST_PROJECT_ID", "env-project-id")
	defer os.Unsetenv("TEST_PROJECT_ID")

	type Config struct {
		ProjectID   EnvVar `json:"project_id"`
		Credentials EnvVar `json:"credentials"`
	}

	jsonInput := `{
"project_id": "env.TEST_PROJECT_ID",
"credentials": "{\"type\":\"service_account\",\"key\":\"value\"}"
}`

	var parsed Config
	if err := json.Unmarshal([]byte(jsonInput), &parsed); err != nil {
		t.Fatalf("Failed to unmarshal: %v", err)
	}

	// The reference must be resolved from the environment.
	if got := parsed.ProjectID.Val; got != "env-project-id" {
		t.Errorf("Expected ProjectID=%q, got %q", "env-project-id", got)
	}
	if !parsed.ProjectID.FromEnv {
		t.Errorf("Expected ProjectID.FromEnv=true")
	}

	// The credentials must be unescaped and stay a plain value.
	expectedCreds := `{"type":"service_account","key":"value"}`
	if got := parsed.Credentials.Val; got != expectedCreds {
		t.Errorf("Expected Credentials=%q, got %q", expectedCreds, got)
	}
	if parsed.Credentials.FromEnv {
		t.Errorf("Expected Credentials.FromEnv=false")
	}
}
// TestEnvVar_Equals covers nil handling on both sides as well as matching and
// differing field values.
func TestEnvVar_Equals(t *testing.T) {
	type testCase struct {
		name     string
		a        *EnvVar
		b        *EnvVar
		expected bool
	}
	cases := []testCase{
		{
			name:     "both nil",
			a:        nil,
			b:        nil,
			expected: true,
		},
		{
			name:     "first nil",
			a:        nil,
			b:        &EnvVar{Val: "test"},
			expected: false,
		},
		{
			name:     "second nil",
			a:        &EnvVar{Val: "test"},
			b:        nil,
			expected: false,
		},
		{
			name:     "equal values",
			a:        &EnvVar{Val: "test", EnvVar: "env.TEST", FromEnv: true},
			b:        &EnvVar{Val: "test", EnvVar: "env.TEST", FromEnv: true},
			expected: true,
		},
		{
			name:     "different values",
			a:        &EnvVar{Val: "test1"},
			b:        &EnvVar{Val: "test2"},
			expected: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := tc.a.Equals(tc.b); got != tc.expected {
				t.Errorf("Expected Equals=%v, got %v", tc.expected, got)
			}
		})
	}
}
// TestEnvVar_Redacted checks the three masking regimes: empty value, value of
// at most 8 characters, and longer value.
func TestEnvVar_Redacted(t *testing.T) {
	type testCase struct {
		name        string
		input       EnvVar
		expectedVal string
	}
	cases := []testCase{
		{
			name:        "empty value",
			input:       EnvVar{Val: ""},
			expectedVal: "",
		},
		{
			name:        "short value (8 chars)",
			input:       EnvVar{Val: "12345678"},
			expectedVal: "********",
		},
		{
			name:        "long value",
			input:       EnvVar{Val: "sk-1234567890abcdefghijklmnop"},
			expectedVal: "sk-1************************mnop",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			masked := tc.input.Redacted()
			if masked.Val != tc.expectedVal {
				t.Errorf("Expected Redacted Val=%q, got %q", tc.expectedVal, masked.Val)
			}
		})
	}
}
// TestEnvVar_IsRedacted checks detection of redacted placeholders: env-backed
// values, all-asterisk short values and the 32-character mask, against plain
// and empty values.
func TestEnvVar_IsRedacted(t *testing.T) {
	type testCase struct {
		name     string
		input    EnvVar
		expected bool
	}
	cases := []testCase{
		{
			name:     "empty not from env",
			input:    EnvVar{Val: "", FromEnv: false},
			expected: false,
		},
		{
			name:     "from env",
			input:    EnvVar{Val: "test", FromEnv: true},
			expected: true,
		},
		{
			name:     "short all asterisks",
			input:    EnvVar{Val: "****"},
			expected: true,
		},
		{
			name:     "redacted pattern 32 chars",
			input:    EnvVar{Val: "sk-1************************mnop"},
			expected: true,
		},
		{
			name:     "normal value",
			input:    EnvVar{Val: "sk-test-key"},
			expected: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := tc.input.IsRedacted(); got != tc.expected {
				t.Errorf("Expected IsRedacted=%v, got %v", tc.expected, got)
			}
		})
	}
}
// TestEnvVar_IsSet pins the difference between GetValue() != "" and IsSet():
// IsSet() must be true whenever the EnvVar references an env var, even if
// that variable has not been resolved to a non-empty Val. The BeforeSave
// hooks rely on this so env var references survive persistence.
func TestEnvVar_IsSet(t *testing.T) {
	type testCase struct {
		name     string
		input    *EnvVar
		expected bool
	}
	cases := []testCase{
		{
			name:     "nil envvar",
			input:    nil,
			expected: false,
		},
		{
			name:     "completely empty",
			input:    &EnvVar{},
			expected: false,
		},
		{
			name:     "only Val set (plain value)",
			input:    &EnvVar{Val: "abc"},
			expected: true,
		},
		{
			name:     "only EnvVar reference set (env not resolved on this server)",
			input:    &EnvVar{EnvVar: "env.MISSING", FromEnv: true},
			expected: true,
		},
		{
			name:     "Val and EnvVar both set (env was resolved)",
			input:    &EnvVar{Val: "resolved-secret", EnvVar: "env.X", FromEnv: true},
			expected: true,
		},
		{
			name:     "FromEnv true but no reference and no value",
			input:    &EnvVar{FromEnv: true},
			expected: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := tc.input.IsSet()
			if got != tc.expected {
				t.Errorf("IsSet() = %v, want %v", got, tc.expected)
			}
		})
	}
}
// TestEnvVar_MarshalJSON_AutoRedactsEnvBackedValues checks that every EnvVar
// marshaled with FromEnv=true is masked in the JSON output, whether or not the
// calling code invoked Redacted() itself, and that plain values are emitted
// unchanged. This is the defense-in-depth guarantee that keeps env-resolved
// secrets out of unredacted fields.
func TestEnvVar_MarshalJSON_AutoRedactsEnvBackedValues(t *testing.T) {
	type testCase struct {
		name        string
		input       EnvVar
		wantValue   string
		wantEnvVar  string
		wantFromEnv bool
	}
	cases := []testCase{
		{
			name:        "env-backed long secret is redacted",
			input:       EnvVar{Val: "sk-1234567890abcdefghijklmnop", EnvVar: "env.OPENAI_API_KEY", FromEnv: true},
			wantValue:   "sk-1************************mnop",
			wantEnvVar:  "env.OPENAI_API_KEY",
			wantFromEnv: true,
		},
		{
			name:        "env-backed short secret is fully masked",
			input:       EnvVar{Val: "12345678", EnvVar: "env.SHORT", FromEnv: true},
			wantValue:   "********",
			wantEnvVar:  "env.SHORT",
			wantFromEnv: true,
		},
		{
			name:        "env-backed unresolved on this server keeps empty value",
			input:       EnvVar{Val: "", EnvVar: "env.MISSING", FromEnv: true},
			wantValue:   "",
			wantEnvVar:  "env.MISSING",
			wantFromEnv: true,
		},
		{
			name:        "plain value (not from env) is NOT redacted",
			input:       EnvVar{Val: "2024-10-21", EnvVar: "", FromEnv: false},
			wantValue:   "2024-10-21",
			wantEnvVar:  "",
			wantFromEnv: false,
		},
		{
			name:        "empty plain value passes through",
			input:       EnvVar{Val: "", EnvVar: "", FromEnv: false},
			wantValue:   "",
			wantEnvVar:  "",
			wantFromEnv: false,
		},
	}

	// wireForm mirrors the JSON object shape without any custom decoding.
	type wireForm struct {
		Value   string `json:"value"`
		EnvVar  string `json:"env_var"`
		FromEnv bool   `json:"from_env"`
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			encoded, err := json.Marshal(tc.input)
			if err != nil {
				t.Fatalf("Marshal failed: %v", err)
			}

			var got wireForm
			if err := json.Unmarshal(encoded, &got); err != nil {
				t.Fatalf("Unmarshal of marshaled output failed: %v", err)
			}

			if got.Value != tc.wantValue {
				t.Errorf("value: got %q, want %q", got.Value, tc.wantValue)
			}
			if got.EnvVar != tc.wantEnvVar {
				t.Errorf("env_var: got %q, want %q", got.EnvVar, tc.wantEnvVar)
			}
			if got.FromEnv != tc.wantFromEnv {
				t.Errorf("from_env: got %v, want %v", got.FromEnv, tc.wantFromEnv)
			}
		})
	}
}
// TestEnvVar_MarshalJSON_DoesNotMutateOriginal checks that the auto-redaction
// performed by MarshalJSON leaves the marshaled value itself untouched. The
// inference path reads GetValue() to build the real HTTP request to the LLM
// provider, so the original Val has to survive marshaling.
func TestEnvVar_MarshalJSON_DoesNotMutateOriginal(t *testing.T) {
	const secret = "real-secret-value"
	original := EnvVar{Val: secret, EnvVar: "env.SECRET", FromEnv: true}

	_, err := json.Marshal(original)
	if err != nil {
		t.Fatalf("Marshal failed: %v", err)
	}

	if original.Val != secret {
		t.Errorf("MarshalJSON mutated Val: got %q, want %q", original.Val, secret)
	}
	if got := original.GetValue(); got != secret {
		t.Errorf("GetValue() returns mutated value: got %q", got)
	}
}
// TestEnvVar_MarshalJSON_RoundTripIsRedacted checks that an env-backed EnvVar
// which is marshaled and then unmarshaled is still recognized as redacted and
// keeps its env reference. The merge logic in provider_keys.go depends on
// this to detect "the UI sent back the same redacted value, don't overwrite".
func TestEnvVar_MarshalJSON_RoundTripIsRedacted(t *testing.T) {
	source := EnvVar{Val: "sk-1234567890abcdefghijklmnop", EnvVar: "env.KEY", FromEnv: true}

	encoded, err := json.Marshal(source)
	if err != nil {
		t.Fatalf("Marshal failed: %v", err)
	}

	var roundTripped EnvVar
	if err := json.Unmarshal(encoded, &roundTripped); err != nil {
		t.Fatalf("Unmarshal failed: %v", err)
	}

	if !roundTripped.IsRedacted() {
		t.Errorf("Round-tripped env-backed value should be IsRedacted, got Val=%q", roundTripped.Val)
	}
	if want := "env.KEY"; roundTripped.EnvVar != want {
		t.Errorf("env_var reference lost in round-trip: got %q, want %q", roundTripped.EnvVar, want)
	}
}
// TestEnvVar_MarshalJSON_DoesNotAffectGetValue is a safety net for the
// inference path: marshaling an EnvVar to JSON must not change what GetValue()
// returns. GetValue() feeds outgoing LLM requests, so if marshaling mutated
// the value, every request after a UI fetch would silently use the redacted
// mask as the API key.
func TestEnvVar_MarshalJSON_DoesNotAffectGetValue(t *testing.T) {
	const realKey = "sk-real-secret-1234567890abcdef"
	os.Setenv("MY_REAL_API_KEY", realKey)
	defer os.Unsetenv("MY_REAL_API_KEY")

	ev := NewEnvVar("env.MY_REAL_API_KEY")
	if got := ev.GetValue(); got != realKey {
		t.Fatalf("setup: GetValue() = %q, want resolved env value", got)
	}

	// The JSON output is redacted, but the in-memory Val must stay as it was.
	if _, err := json.Marshal(ev); err != nil {
		t.Fatalf("Marshal failed: %v", err)
	}

	if got := ev.GetValue(); got != realKey {
		t.Errorf("GetValue() returns mutated value after MarshalJSON: got %q", got)
	}
}

253
core/schemas/files.go Normal file
View File

@@ -0,0 +1,253 @@
// Package schemas defines the core schemas and types used by the Bifrost system.
package schemas
// FilePurpose represents the purpose of an uploaded file.
// It is a string type, so it serializes as the plain string value.
type FilePurpose string
// FilePurpose values declared by this package.
const (
FilePurposeBatch FilePurpose = "batch"
FilePurposeAssistants FilePurpose = "assistants"
FilePurposeFineTune FilePurpose = "fine-tune"
FilePurposeVision FilePurpose = "vision"
FilePurposeBatchOutput FilePurpose = "batch_output"
FilePurposeUserData FilePurpose = "user_data"
FilePurposeResponses FilePurpose = "responses"
FilePurposeEvals FilePurpose = "evals"
)
// FileStatus represents the status of a file.
// It is a string type, so it serializes as the plain string value.
type FileStatus string
// FileStatus values declared by this package.
const (
FileStatusUploaded FileStatus = "uploaded"
FileStatusProcessed FileStatus = "processed"
FileStatusProcessing FileStatus = "processing"
FileStatusError FileStatus = "error"
FileStatusDeleted FileStatus = "deleted"
)
// FileStorageBackend represents the storage backend type.
// It is a string type, so it serializes as the plain string value.
type FileStorageBackend string
// FileStorageBackend values declared by this package.
const (
FileStorageAPI FileStorageBackend = "api" // OpenAI/Azure REST API
FileStorageS3 FileStorageBackend = "s3" // AWS S3
FileStorageGCS FileStorageBackend = "gcs" // Google Cloud Storage
FileStorageMemory FileStorageBackend = "memory" // In-memory (for Anthropic virtual files)
)
// FileObject represents a file object returned by the API.
// Fields tagged omitempty are left out of the JSON output when they hold
// their zero value (or nil for the pointer fields).
type FileObject struct {
ID string `json:"id"`
Object string `json:"object,omitempty"` // "file"
Bytes int64 `json:"bytes"`
CreatedAt int64 `json:"created_at"` // presumably a Unix timestamp; confirm against providers
UpdatedAt int64 `json:"updated_at,omitempty"` // presumably a Unix timestamp; confirm against providers
Filename string `json:"filename"`
Purpose FilePurpose `json:"purpose"`
Status FileStatus `json:"status,omitempty"`
StatusDetails *string `json:"status_details,omitempty"`
ExpiresAt *int64 `json:"expires_at,omitempty"` // nil when no expiry is reported
}
// BifrostFileUploadRequest represents a request to upload a file.
// File and ExtraParams are tagged json:"-" and are therefore excluded from
// struct-tag based JSON encoding of the request.
type BifrostFileUploadRequest struct {
Provider ModelProvider `json:"provider"`
Model *string `json:"model"` // pointer without omitempty: encodes as null when nil
// File content
File []byte `json:"-"` // Raw file content (not serialized)
Filename string `json:"filename"` // Original filename
Purpose FilePurpose `json:"purpose"` // Purpose of the file (e.g., "batch")
ContentType *string `json:"content_type,omitempty"` // MIME type of the file
// Storage configuration (for S3/GCS backends)
StorageConfig *FileStorageConfig `json:"storage_config,omitempty"`
// Expiration configuration (OpenAI only)
ExpiresAfter *FileExpiresAfter `json:"expires_after,omitempty"`
// Extra parameters for provider-specific features
ExtraParams map[string]interface{} `json:"-"`
}
// S3StorageConfig represents AWS S3 storage configuration.
// Every field is optional in JSON (omitempty).
type S3StorageConfig struct {
Bucket string `json:"bucket,omitempty"`
Region string `json:"region,omitempty"`
Prefix string `json:"prefix,omitempty"` // presumably an object key prefix; confirm against the S3 backend
}
// GCSStorageConfig represents Google Cloud Storage configuration.
// Every field is optional in JSON (omitempty).
type GCSStorageConfig struct {
Bucket string `json:"bucket,omitempty"`
Project string `json:"project,omitempty"`
Prefix string `json:"prefix,omitempty"` // presumably an object name prefix; confirm against the GCS backend
}
// FileExpiresAfter represents an expiration configuration for uploaded files.
// NOTE(review): the documented range for Seconds is not enforced by this type;
// confirm where (or whether) it is validated.
type FileExpiresAfter struct {
Anchor string `json:"anchor"` // e.g., "created_at"
Seconds int `json:"seconds"` // 3600-2592000 (1 hour to 30 days)
}
// FileStorageConfig represents storage configuration for cloud storage backends.
// Each backend section is a pointer and is omitted from JSON when nil.
type FileStorageConfig struct {
S3 *S3StorageConfig `json:"s3,omitempty"`
GCS *GCSStorageConfig `json:"gcs,omitempty"`
}
// BifrostFileUploadResponse represents the response from uploading a file.
// Apart from UpdatedAt, it repeats the FileObject fields and adds storage
// backend details plus ExtraFields.
type BifrostFileUploadResponse struct {
ID string `json:"id"`
Object string `json:"object,omitempty"` // "file"
Bytes int64 `json:"bytes"`
CreatedAt int64 `json:"created_at"`
Filename string `json:"filename"`
Purpose FilePurpose `json:"purpose"`
Status FileStatus `json:"status,omitempty"`
StatusDetails *string `json:"status_details,omitempty"`
ExpiresAt *int64 `json:"expires_at,omitempty"`
// Storage backend info
StorageBackend FileStorageBackend `json:"storage_backend,omitempty"`
StorageURI string `json:"storage_uri,omitempty"` // S3/GCS URI if applicable
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// BifrostFileListRequest represents a request to list files.
// RawRequestBody and ExtraParams are tagged json:"-" and are therefore
// excluded from struct-tag based JSON encoding of the request.
type BifrostFileListRequest struct {
Provider ModelProvider `json:"provider"`
Model *string `json:"model"` // pointer without omitempty: encodes as null when nil
RawRequestBody []byte `json:"-"` // Raw request body (not serialized)
// Filters
Purpose FilePurpose `json:"purpose,omitempty"` // Filter by purpose
// Pagination
Limit int `json:"limit,omitempty"` // Max results to return
After *string `json:"after,omitempty"` // Cursor for pagination
Order *string `json:"order,omitempty"` // Sort order (asc/desc)
// Storage configuration (for S3/GCS backends)
StorageConfig *FileStorageConfig `json:"storage_config,omitempty"`
// Extra parameters for provider-specific features
ExtraParams map[string]interface{} `json:"-"`
}
// GetRawRequestBody returns the request's RawRequestBody field.
// The slice is returned as-is, not copied.
func (request *BifrostFileListRequest) GetRawRequestBody() []byte {
return request.RawRequestBody
}
// BifrostFileListResponse represents the response from listing files.
// Data and ExtraFields are always serialized; a nil Data slice has no omitempty
// and is therefore encoded as null rather than dropped.
type BifrostFileListResponse struct {
Object string `json:"object,omitempty"` // "list"
Data []FileObject `json:"data"`
HasMore bool `json:"has_more,omitempty"`
After *string `json:"after,omitempty"` // Continuation token for pagination
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// BifrostFileRetrieveRequest represents a request to retrieve file metadata.
// RawRequestBody and ExtraParams are tagged `json:"-"` and never appear in the
// JSON encoding of the request.
type BifrostFileRetrieveRequest struct {
Provider ModelProvider `json:"provider"`
// Model has no omitempty, so a nil pointer is serialized as null.
Model *string `json:"model"`
RawRequestBody []byte `json:"-"` // Raw request body (not serialized)
FileID string `json:"file_id"` // ID of the file to retrieve
// Storage configuration (for S3/GCS backends)
StorageConfig *FileStorageConfig `json:"storage_config,omitempty"`
// Extra parameters for provider-specific features
ExtraParams map[string]interface{} `json:"-"`
}
// GetRawRequestBody returns the request's RawRequestBody field.
// The slice is returned as-is, not copied.
func (request *BifrostFileRetrieveRequest) GetRawRequestBody() []byte {
return request.RawRequestBody
}
// BifrostFileRetrieveResponse represents the response from retrieving file metadata.
// It carries the same fields as BifrostFileUploadResponse plus UpdatedAt.
type BifrostFileRetrieveResponse struct {
ID string `json:"id"`
Object string `json:"object,omitempty"` // "file"
Bytes int64 `json:"bytes"`
// NOTE(review): CreatedAt/UpdatedAt are presumably Unix timestamps — the unit is not visible here, confirm.
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at,omitempty"`
Filename string `json:"filename"`
Purpose FilePurpose `json:"purpose"`
Status FileStatus `json:"status,omitempty"`
StatusDetails *string `json:"status_details,omitempty"`
ExpiresAt *int64 `json:"expires_at,omitempty"`
// Storage backend info
StorageBackend FileStorageBackend `json:"storage_backend,omitempty"`
StorageURI string `json:"storage_uri,omitempty"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// BifrostFileDeleteRequest represents a request to delete a file.
// RawRequestBody and ExtraParams are tagged `json:"-"` and never appear in the
// JSON encoding of the request.
type BifrostFileDeleteRequest struct {
Provider ModelProvider `json:"provider"`
// Model has no omitempty, so a nil pointer is serialized as null.
Model *string `json:"model"`
FileID string `json:"file_id"` // ID of the file to delete
RawRequestBody []byte `json:"-"` // Raw request body (not serialized)
// Storage configuration (for S3/GCS backends)
StorageConfig *FileStorageConfig `json:"storage_config,omitempty"`
// Extra parameters for provider-specific features
ExtraParams map[string]interface{} `json:"-"`
}
// GetRawRequestBody returns the request's RawRequestBody field.
// The slice is returned as-is, not copied.
func (request *BifrostFileDeleteRequest) GetRawRequestBody() []byte {
return request.RawRequestBody
}
// BifrostFileDeleteResponse represents the response from deleting a file.
// Deleted is always serialized, so a false value is visible to clients.
type BifrostFileDeleteResponse struct {
ID string `json:"id"`
Object string `json:"object,omitempty"` // "file"
Deleted bool `json:"deleted"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// BifrostFileContentRequest represents a request to download file content.
// RawRequestBody and ExtraParams are tagged `json:"-"` and never appear in the
// JSON encoding of the request.
type BifrostFileContentRequest struct {
Provider ModelProvider `json:"provider"`
// Model has no omitempty, so a nil pointer is serialized as null.
Model *string `json:"model"`
FileID string `json:"file_id"` // ID of the file to download
RawRequestBody []byte `json:"-"` // Raw request body (not serialized)
// Storage configuration (for S3/GCS backends)
StorageConfig *FileStorageConfig `json:"storage_config,omitempty"`
// Extra parameters for provider-specific features
ExtraParams map[string]interface{} `json:"-"`
}
// GetRawRequestBody returns the request's RawRequestBody field.
// The slice is returned as-is, not copied.
func (request *BifrostFileContentRequest) GetRawRequestBody() []byte {
return request.RawRequestBody
}
// BifrostFileContentResponse represents the response from downloading file content.
// The file bytes live in Content, which is excluded from JSON; only the file ID,
// content type and extra fields are serialized.
type BifrostFileContentResponse struct {
FileID string `json:"file_id"`
Content []byte `json:"-"` // Raw file content (not serialized)
ContentType string `json:"content_type,omitempty"` // MIME type
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}

338
core/schemas/images.go Normal file
View File

@@ -0,0 +1,338 @@
package schemas
// ImageEventType identifies the kind of event carried by a streamed image
// generation or image edit response.
type ImageEventType string
// Event type values. Note that ImageGenerationEventTypeError and
// ImageEditEventTypeError share the same wire value "error", so the two
// cannot be told apart by comparing the value alone.
const (
ImageGenerationEventTypePartial ImageEventType = "image_generation.partial_image"
ImageGenerationEventTypeCompleted ImageEventType = "image_generation.completed"
ImageGenerationEventTypeError ImageEventType = "error"
ImageEditEventTypePartial ImageEventType = "image_edit.partial_image"
ImageEditEventTypeCompleted ImageEventType = "image_edit.completed"
ImageEditEventTypeError ImageEventType = "error"
)
// BifrostImageGenerationRequest represents an image generation request in bifrost format.
// Input has no omitempty and is serialized as null when nil; Params and Fallbacks
// are omitted when empty; RawRequestBody is never serialized.
type BifrostImageGenerationRequest struct {
Provider ModelProvider `json:"provider"`
Model string `json:"model"`
Input *ImageGenerationInput `json:"input"`
Params *ImageGenerationParameters `json:"params,omitempty"`
Fallbacks []Fallback `json:"fallbacks,omitempty"`
RawRequestBody []byte `json:"-"`
}
// GetRawRequestBody implements utils.RequestBodyGetter.
// It returns the RawRequestBody field as-is, without copying the slice.
func (b *BifrostImageGenerationRequest) GetRawRequestBody() []byte {
return b.RawRequestBody
}
// ImageGenerationInput is the input payload of an image generation request.
// It carries only the text prompt.
type ImageGenerationInput struct {
Prompt string `json:"prompt"`
}
// ImageGenerationParameters holds the optional tuning parameters of an image
// generation request. Every serialized field is a pointer or slice tagged
// omitempty, so unset values are left out of the JSON encoding. The value
// ranges listed in the trailing comments are documentation only; this type
// does not validate them. ExtraParams is never serialized.
type ImageGenerationParameters struct {
N *int `json:"n,omitempty"` // Number of images (1-10)
Background *string `json:"background,omitempty"` // "transparent", "opaque", "auto"
Moderation *string `json:"moderation,omitempty"` // "low", "auto"
PartialImages *int `json:"partial_images,omitempty"` // 0-3
Size *string `json:"size,omitempty"` // "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792", "1536x1024", "1024x1536", "auto"
Quality *string `json:"quality,omitempty"` // "auto", "high", "medium", "low", "hd", "standard"
OutputCompression *int `json:"output_compression,omitempty"` // compression level (0-100%)
OutputFormat *string `json:"output_format,omitempty"` // "png", "webp", "jpeg"
Style *string `json:"style,omitempty"` // "natural", "vivid"
ResponseFormat *string `json:"response_format,omitempty"` // "url", "b64_json"
Seed *int `json:"seed,omitempty"` // seed for image generation
NegativePrompt *string `json:"negative_prompt,omitempty"` // negative prompt for image generation
NumInferenceSteps *int `json:"num_inference_steps,omitempty"` // number of inference steps
User *string `json:"user,omitempty"`
InputImages []string `json:"input_images,omitempty"` // input images for image generation, base64 encoded or URL
AspectRatio *string `json:"aspect_ratio,omitempty"` // aspect ratio of the image
ExtraParams map[string]interface{} `json:"-"`
}
// BifrostImageGenerationResponse represents the image generation response in bifrost format.
// ImageGenerationResponseParameters is embedded by pointer and may be nil, so
// code must check it before touching Size, Quality and the other promoted fields.
type BifrostImageGenerationResponse struct {
ID string `json:"id,omitempty"`
Created int64 `json:"created,omitempty"`
Model string `json:"model,omitempty"`
Data []ImageData `json:"data"`
*ImageGenerationResponseParameters
Usage *ImageUsage `json:"usage,omitempty"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields,omitempty"`
}
// BackfillParams copies values from the originating request onto the response
// when the provider did not return them. These values feed cost calculation.
//
// It fills in:
//   - Model, taken from whichever image request (generation, edit, variation)
//     is set on req, when the response has no model;
//   - Usage.NumInputImages, whenever the request carried at least one input
//     image (Usage is allocated if nil);
//   - Size and Quality on the embedded ImageGenerationResponseParameters, only
//     when the response value is empty (the embedded struct is allocated if nil).
//     Quality is backfilled only for low, medium, high and auto.
//
// The call is a no-op when either the receiver or req is nil.
func (r *BifrostImageGenerationResponse) BackfillParams(req *BifrostRequest) {
	if r == nil || req == nil {
		return
	}

	inputImageCount, reqSize, reqQuality := getNumInputImagesSizeAndQualityFromRequest(req)

	// Some provider APIs (notably OpenAI /v1/images/*) omit the model in the
	// response body, so fall back to the model named in the request.
	if r.Model == "" {
		r.Model = getModelFromRequest(req)
	}

	if inputImageCount > 0 {
		if r.Usage == nil {
			r.Usage = &ImageUsage{}
		}
		r.Usage.NumInputImages = inputImageCount
	}

	// Size: never overwrite a value the provider already returned.
	if reqSize != "" {
		if r.ImageGenerationResponseParameters == nil {
			r.ImageGenerationResponseParameters = &ImageGenerationResponseParameters{}
		}
		if r.ImageGenerationResponseParameters.Size == "" {
			r.ImageGenerationResponseParameters.Size = reqSize
		}
	}

	// Quality: same rule as Size.
	if reqQuality != "" {
		if r.ImageGenerationResponseParameters == nil {
			r.ImageGenerationResponseParameters = &ImageGenerationResponseParameters{}
		}
		if r.ImageGenerationResponseParameters.Quality == "" {
			r.ImageGenerationResponseParameters.Quality = reqQuality
		}
	}
}
// getModelFromRequest returns the model named by the image request carried in
// req. The generation, edit and variation requests are checked in that order
// and the first non-nil one wins. It returns "" when req is nil or carries no
// image request.
func getModelFromRequest(req *BifrostRequest) string {
	if req == nil {
		return ""
	}
	if gen := req.ImageGenerationRequest; gen != nil {
		return gen.Model
	}
	if edit := req.ImageEditRequest; edit != nil {
		return edit.Model
	}
	if variation := req.ImageVariationRequest; variation != nil {
		return variation.Model
	}
	return ""
}
// getNumInputImagesSizeAndQualityFromRequest pulls the request parameters that
// cost calculation needs out of the image request carried in req.
//
// Only the first non-nil request is inspected, in the order generation, edit,
// variation:
//   - generation: input image count is len(Params.InputImages);
//   - edit: input image count is len(Input.Images);
//   - variation: input image count is 1 when Input is set; quality is never
//     reported.
//
// size is the request's Size parameter when present. quality is passed through
// normalizeImageQuality, so only low, medium, high and auto survive. Missing
// values are returned as 0 or "".
func getNumInputImagesSizeAndQualityFromRequest(req *BifrostRequest) (numInputImages int, size string, quality string) {
	if req == nil {
		return 0, "", ""
	}

	if gen := req.ImageGenerationRequest; gen != nil {
		if params := gen.Params; params != nil {
			numInputImages = len(params.InputImages)
			if params.Size != nil {
				size = *params.Size
			}
			if params.Quality != nil {
				quality = normalizeImageQuality(*params.Quality)
			}
		}
		return numInputImages, size, quality
	}

	if edit := req.ImageEditRequest; edit != nil {
		if edit.Input != nil {
			numInputImages = len(edit.Input.Images)
		}
		if params := edit.Params; params != nil {
			if params.Size != nil {
				size = *params.Size
			}
			if params.Quality != nil {
				quality = normalizeImageQuality(*params.Quality)
			}
		}
		return numInputImages, size, quality
	}

	if variation := req.ImageVariationRequest; variation != nil {
		if variation.Input != nil {
			numInputImages = 1
		}
		if variation.Params != nil && variation.Params.Size != nil {
			size = *variation.Params.Size
		}
	}
	return numInputImages, size, quality
}
// normalizeImageQuality filters a quality value down to the set supported by
// gpt-image-1.5: "low", "medium", "high" and "auto". A supported value is
// returned unchanged; anything else (for example "hd" or "standard") yields "".
// The comparison is exact and case-sensitive.
func normalizeImageQuality(q string) string {
	if q == "low" || q == "medium" || q == "high" || q == "auto" {
		return q
	}
	return ""
}
// ImageGenerationResponseParameters holds the generation parameters echoed back
// in an image response. Every field is omitted from JSON when empty.
type ImageGenerationResponseParameters struct {
Background string `json:"background,omitempty"`
OutputFormat string `json:"output_format,omitempty"`
Quality string `json:"quality,omitempty"`
Size string `json:"size,omitempty"`
// NOTE(review): FinishReasons and Seeds look like per-image lists — confirm against the provider code.
FinishReasons []*string `json:"finish_reasons,omitempty"`
Seeds []int `json:"seeds,omitempty"`
}
// ImageData describes a single image in an image response. URL, B64JSON and
// RevisedPrompt are omitted from JSON when empty; Index is always serialized,
// including when it is 0.
type ImageData struct {
URL string `json:"url,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
Index int `json:"index"`
}
// ImageUsage reports token usage for an image request. All fields are omitted
// from JSON when zero or nil.
type ImageUsage struct {
InputTokens int `json:"input_tokens,omitempty"` // Always text tokens unless InputTokensDetails is not nil
InputTokensDetails *ImageTokenDetails `json:"input_tokens_details,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
OutputTokens int `json:"output_tokens,omitempty"` // Always image tokens unless OutputTokensDetails is not nil
OutputTokensDetails *ImageTokenDetails `json:"output_tokens_details,omitempty"`
NumInputImages int `json:"num_input_images,omitempty"` // Number of input images from the request (populated by Bifrost)
}
// ImageTokenDetails breaks a token count down into image and text tokens.
// NImages is internal bookkeeping and is excluded from JSON.
type ImageTokenDetails struct {
NImages int `json:"-"` // Number of images generated (used internally for bifrost)
ImageTokens int `json:"image_tokens,omitempty"`
TextTokens int `json:"text_tokens,omitempty"`
}
// BifrostImageGenerationStreamResponse is a single streamed event of an image
// generation or image edit response.
//
// Unlike BifrostImageGenerationResponse, Size, Quality, Background and
// OutputFormat are plain fields of this struct rather than promoted from an
// embedded parameters struct. Index, ChunkIndex, RawRequest and RawResponse are
// excluded from JSON.
type BifrostImageGenerationStreamResponse struct {
ID string `json:"id,omitempty"`
Type ImageEventType `json:"type,omitempty"`
Index int `json:"-"` // Which image (0-N)
ChunkIndex int `json:"-"` // Chunk order within image
PartialImageIndex *int `json:"partial_image_index,omitempty"`
SequenceNumber int `json:"sequence_number,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
URL string `json:"url,omitempty"`
CreatedAt int64 `json:"created_at,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
Background string `json:"background,omitempty"`
OutputFormat string `json:"output_format,omitempty"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
Usage *ImageUsage `json:"usage,omitempty"`
Error *BifrostError `json:"error,omitempty"`
RawRequest string `json:"-"`
RawResponse string `json:"-"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields,omitempty"`
}
// BackfillParams copies values from the originating request onto the stream
// event when the provider did not return them. These values feed cost
// calculation.
//
// It fills in:
//   - Usage.NumInputImages, whenever the request carried at least one input
//     image (Usage is allocated if nil);
//   - Size, only when the event's Size is empty;
//   - Quality, only when the event's Quality is empty and the request value is
//     one of low, medium, high, auto.
//
// The call is a no-op when either the receiver or req is nil, matching
// BifrostImageGenerationResponse.BackfillParams. Without the receiver check a
// nil event would panic as soon as the request had anything to backfill.
func (r *BifrostImageGenerationStreamResponse) BackfillParams(req *BifrostRequest) {
	if r == nil || req == nil {
		return
	}

	numInputImages, size, quality := getNumInputImagesSizeAndQualityFromRequest(req)

	// Backfill NumInputImages
	if numInputImages > 0 {
		if r.Usage == nil {
			r.Usage = &ImageUsage{}
		}
		r.Usage.NumInputImages = numInputImages
	}

	// Backfill Size if not already present from provider response
	if size != "" && r.Size == "" {
		r.Size = size
	}

	// Backfill Quality if not already present (only low, medium, high, auto)
	if quality != "" && r.Quality == "" {
		r.Quality = quality
	}
}
// BifrostImageEditRequest represents an image edit request in bifrost format.
// Input has no omitempty and is serialized as null when nil; Params and Fallbacks
// are omitted when empty; RawRequestBody is never serialized.
type BifrostImageEditRequest struct {
Provider ModelProvider `json:"provider"`
Model string `json:"model"`
Input *ImageEditInput `json:"input"`
Params *ImageEditParameters `json:"params,omitempty"`
Fallbacks []Fallback `json:"fallbacks,omitempty"`
RawRequestBody []byte `json:"-"`
}
// GetRawRequestBody implements [utils.RequestBodyGetter].
// It returns the RawRequestBody field as-is, without copying the slice.
func (b *BifrostImageEditRequest) GetRawRequestBody() []byte {
return b.RawRequestBody
}
// ImageEditInput is the input payload of an image edit request: the images to
// edit and the text prompt describing the edit.
type ImageEditInput struct {
Images []ImageInput `json:"images"`
Prompt string `json:"prompt"`
}
// ImageInput wraps the raw bytes of a single input image.
// As a []byte field it is base64-encoded by standard-library-compatible JSON encoders.
type ImageInput struct {
Image []byte `json:"image"`
}
// ImageEditParameters holds the optional tuning parameters of an image edit
// request. Every serialized field is tagged omitempty, so unset values are left
// out of the JSON encoding. The value sets listed in the trailing comments are
// documentation only; this type does not validate them. ExtraParams is never
// serialized.
type ImageEditParameters struct {
Type *string `json:"type,omitempty"` // "inpainting", "outpainting", "background_removal", "remove_background", "erase_object", "recolor", "search_replace", "control_sketch", "control_structure", "style_guide", "style_transfer", "upscale_fast", "upscale_creative", "upscale_conservative"
Background *string `json:"background,omitempty"` // "transparent", "opaque", "auto"
InputFidelity *string `json:"input_fidelity,omitempty"` // "low", "high"
Mask []byte `json:"mask,omitempty"`
N *int `json:"n,omitempty"` // number of images to generate (1-10)
OutputCompression *int `json:"output_compression,omitempty"` // compression level (0-100%)
OutputFormat *string `json:"output_format,omitempty"` // "png", "webp", "jpeg"
PartialImages *int `json:"partial_images,omitempty"` // 0-3
Quality *string `json:"quality,omitempty"` // "auto", "high", "medium", "low", "standard"
ResponseFormat *string `json:"response_format,omitempty"` // "url", "b64_json"
Size *string `json:"size,omitempty"` // "256x256", "512x512", "1024x1024", "1536x1024", "1024x1536", "auto"
User *string `json:"user,omitempty"`
NegativePrompt *string `json:"negative_prompt,omitempty"` // negative prompt for image editing
Seed *int `json:"seed,omitempty"` // seed for image editing
NumInferenceSteps *int `json:"num_inference_steps,omitempty"` // number of inference steps
ExtraParams map[string]interface{} `json:"-"`
}
// BifrostImageVariationRequest represents an image variation request in bifrost format.
// Input has no omitempty and is serialized as null when nil; Params and Fallbacks
// are omitted when empty; RawRequestBody is never serialized.
type BifrostImageVariationRequest struct {
Provider ModelProvider `json:"provider"`
Model string `json:"model"`
Input *ImageVariationInput `json:"input"`
Params *ImageVariationParameters `json:"params,omitempty"`
Fallbacks []Fallback `json:"fallbacks,omitempty"`
RawRequestBody []byte `json:"-"`
}
// GetRawRequestBody implements [utils.RequestBodyGetter].
// It returns the RawRequestBody field as-is, without copying the slice.
func (b *BifrostImageVariationRequest) GetRawRequestBody() []byte {
return b.RawRequestBody
}
// ImageVariationInput is the input payload of an image variation request.
// It carries exactly one source image, held by value.
type ImageVariationInput struct {
Image ImageInput `json:"image"`
}
// ImageVariationParameters holds the optional tuning parameters of an image
// variation request. Unset pointer fields are omitted from JSON. There is no
// Quality field here. ExtraParams is never serialized.
type ImageVariationParameters struct {
N *int `json:"n,omitempty"` // Number of images (1-10)
ResponseFormat *string `json:"response_format,omitempty"` // "url", "b64_json"
Size *string `json:"size,omitempty"` // "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792", "1536x1024", "1024x1536", "auto"
User *string `json:"user,omitempty"`
ExtraParams map[string]interface{} `json:"-"`
}
// BifrostImageVariationResponse represents the image variation response in bifrost format.
// It is a type alias (not a distinct type) for BifrostImageGenerationResponse,
// so the two are interchangeable and share the same method set.
type BifrostImageVariationResponse = BifrostImageGenerationResponse

135
core/schemas/json_native.go Normal file
View File

@@ -0,0 +1,135 @@
//go:build !tinygo && !wasm
package schemas
import (
"bytes"
"encoding/json"
"reflect"
"github.com/bytedance/sonic"
)
// Marshal encodes v to JSON bytes using the high-performance sonic library.
// It is a direct pass-through to sonic.Marshal; use MarshalSorted when
// deterministic map-key order is required.
func Marshal(v interface{}) ([]byte, error) {
return sonic.Marshal(v)
}
// MarshalString encodes v to a JSON string using sonic.
// It is a direct pass-through to sonic.MarshalString.
func MarshalString(v interface{}) (string, error) {
return sonic.MarshalString(v)
}
// Unmarshal decodes JSON data into v using sonic.
// It is a direct pass-through to sonic.Unmarshal; v is expected to be a pointer
// to the destination value.
func Unmarshal(data []byte, v interface{}) error {
return sonic.Unmarshal(data, v)
}
// Compact removes insignificant whitespace from JSON-encoded src
// and appends the result to dst.
// It delegates to the standard library's json.Compact, not to sonic.
func Compact(dst *bytes.Buffer, src []byte) error {
return json.Compact(dst, src)
}
// MarshalSorted encodes v to JSON with map keys sorted alphabetically.
// Use this when deterministic output is needed (e.g., hashing, caching keys).
// Uses sonic.ConfigStd which has SortMapKeys enabled.
// Values that serialize themselves through a custom MarshalJSON are not
// re-sorted by this function; use MarshalDeeplySorted for those.
func MarshalSorted(v interface{}) ([]byte, error) {
return sonic.ConfigStd.Marshal(v)
}
// MarshalSortedIndent encodes v to indented JSON with map keys sorted alphabetically.
// prefix and indent are forwarded unchanged to sonic.ConfigStd.MarshalIndent.
func MarshalSortedIndent(v interface{}, prefix, indent string) ([]byte, error) {
return sonic.ConfigStd.MarshalIndent(v, prefix, indent)
}
// ConvertViaJSON converts src into a value of type T by encoding it with
// MarshalSorted and decoding the bytes back with Unmarshal. It is meant as a
// fallback for when a direct type assertion fails, for example when src is a
// map[string]interface{} that came out of a JSON decode.
//
// On any encode or decode error the zero value of T is returned with the
// error; a partially decoded value is never returned.
func ConvertViaJSON[T any](src interface{}) (T, error) {
	var (
		converted T
		empty     T
	)

	encoded, marshalErr := MarshalSorted(src)
	if marshalErr != nil {
		return empty, marshalErr
	}

	if decodeErr := Unmarshal(encoded, &converted); decodeErr != nil {
		// Return the untouched zero value, not the half-filled one.
		return empty, decodeErr
	}
	return converted, nil
}
// MarshalDeeplySorted encodes v to JSON with all map keys sorted alphabetically,
// including nested maps inside OrderedMap and other custom types with MarshalJSON.
// This ensures fully deterministic output for hashing/caching purposes.
//
// Unlike MarshalSorted which relies on sonic's SortMapKeys (which doesn't affect
// types with custom MarshalJSON like OrderedMap), this function first normalizes
// the entire structure to plain maps, then marshals with sorted keys.
func MarshalDeeplySorted(v interface{}) ([]byte, error) {
normalized := normalizeForSortedMarshal(v)
return sonic.ConfigStd.Marshal(normalized)
}
// normalizeForSortedMarshal rewrites v into a tree of plain
// map[string]interface{} and []interface{} values so that
// sonic.ConfigStd.Marshal can sort every key deterministically.
//
// Handling by dynamic type:
//   - nil, or a nil *OrderedMap: nil;
//   - *OrderedMap / OrderedMap: a plain map with each value normalized;
//   - map[string]interface{} and []interface{}: a copy with each element
//     normalized;
//   - a struct, or a non-nil pointer to one: marshaled to JSON, decoded into a
//     map, and that map normalized. If either step fails, v is returned as-is;
//   - a nil pointer of any other type: nil;
//   - everything else: returned unchanged.
func normalizeForSortedMarshal(v interface{}) interface{} {
	if v == nil {
		return nil
	}

	switch typed := v.(type) {
	case *OrderedMap:
		if typed == nil {
			return nil
		}
		plain := make(map[string]interface{}, typed.Len())
		typed.Range(func(key string, item interface{}) bool {
			plain[key] = normalizeForSortedMarshal(item)
			return true
		})
		return plain
	case OrderedMap:
		plain := make(map[string]interface{}, typed.Len())
		typed.Range(func(key string, item interface{}) bool {
			plain[key] = normalizeForSortedMarshal(item)
			return true
		})
		return plain
	case map[string]interface{}:
		plain := make(map[string]interface{}, len(typed))
		for key, item := range typed {
			plain[key] = normalizeForSortedMarshal(item)
		}
		return plain
	case []interface{}:
		items := make([]interface{}, len(typed))
		for idx := range typed {
			items[idx] = normalizeForSortedMarshal(typed[idx])
		}
		return items
	}

	// Anything else: only structs (directly or behind a pointer) need work.
	target := reflect.ValueOf(v)
	if target.Kind() == reflect.Ptr {
		if target.IsNil() {
			return nil
		}
		target = target.Elem()
	}
	if target.Kind() != reflect.Struct {
		return v
	}

	// Intentional JSON round-trip: it flattens structs, including those with a
	// custom MarshalJSON, into plain maps whose keys can then be sorted.
	encoded, err := sonic.Marshal(v)
	if err != nil {
		return v
	}
	var fields map[string]interface{}
	if err := sonic.Unmarshal(encoded, &fields); err != nil {
		return v
	}
	return normalizeForSortedMarshal(fields)
}

View File

@@ -0,0 +1,92 @@
package schemas
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test helper types for ConvertViaJSON tests

// convertTestTarget is the destination type for the map- and slice-based cases.
type convertTestTarget struct {
Type string `json:"type"`
Name string `json:"name,omitempty"`
}
// convertTestSource is the source type for the struct-to-struct case. Its JSON
// field names match convertTestDest.
type convertTestSource struct {
Name string `json:"name"`
Value int `json:"value"`
}
// convertTestDest is the destination type for the struct-to-struct case. Its
// JSON field names match convertTestSource.
type convertTestDest struct {
Name string `json:"name"`
Value int `json:"value"`
}
// TestConvertViaJSON exercises ConvertViaJSON across the conversions it is
// used for: map to struct, struct to struct, slice of maps to slice of structs,
// an input that cannot decode into the target, and a nil input.
func TestConvertViaJSON(t *testing.T) {
	type scenario struct {
		name  string
		check func(t *testing.T)
	}

	scenarios := []scenario{
		{
			name: "map_to_struct",
			check: func(t *testing.T) {
				src := map[string]interface{}{"type": "json_object", "name": "test_format"}
				got, err := ConvertViaJSON[convertTestTarget](src)
				require.NoError(t, err)
				assert.Equal(t, "json_object", got.Type)
				assert.Equal(t, "test_format", got.Name)
			},
		},
		{
			name: "struct_to_struct",
			check: func(t *testing.T) {
				got, err := ConvertViaJSON[convertTestDest](convertTestSource{Name: "hello", Value: 42})
				require.NoError(t, err)
				assert.Equal(t, "hello", got.Name)
				assert.Equal(t, 42, got.Value)
			},
		},
		{
			name: "slice_conversion",
			check: func(t *testing.T) {
				src := []interface{}{
					map[string]interface{}{"type": "image", "name": "a"},
					map[string]interface{}{"type": "video", "name": "b"},
				}
				got, err := ConvertViaJSON[[]convertTestTarget](src)
				require.NoError(t, err)
				require.Len(t, got, 2)
				assert.Equal(t, "image", got[0].Type)
				assert.Equal(t, "a", got[0].Name)
				assert.Equal(t, "video", got[1].Type)
				assert.Equal(t, "b", got[1].Name)
			},
		},
		{
			name: "invalid_input_returns_error",
			check: func(t *testing.T) {
				// A bare number cannot decode into a struct.
				got, err := ConvertViaJSON[convertTestTarget](42)
				require.Error(t, err)
				assert.Equal(t, convertTestTarget{}, got)
			},
		},
		{
			name: "nil_input",
			check: func(t *testing.T) {
				got, err := ConvertViaJSON[convertTestTarget](nil)
				// nil marshals to "null" which unmarshals to zero struct — no error
				require.NoError(t, err)
				assert.Equal(t, convertTestTarget{}, got)
			},
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, sc.check)
	}
}

92
core/schemas/json_wasm.go Normal file
View File

@@ -0,0 +1,92 @@
//go:build tinygo || wasm
package schemas
import (
"bytes"
"encoding/json"
)
// Marshal encodes v to JSON bytes using the standard library.
// It is a direct pass-through to json.Marshal. This file is the tinygo/wasm
// counterpart of the sonic-based implementation.
func Marshal(v interface{}) ([]byte, error) {
return json.Marshal(v)
}
// MarshalString encodes v to JSON with the standard library and returns the
// result as a string. On failure it returns "" together with the error from
// json.Marshal.
func MarshalString(v interface{}) (string, error) {
	encoded, err := json.Marshal(v)
	if err == nil {
		return string(encoded), nil
	}
	return "", err
}
// Unmarshal decodes JSON data into v using the standard library.
// It is a direct pass-through to json.Unmarshal; v is expected to be a pointer
// to the destination value.
func Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}
// Compact removes insignificant whitespace from JSON-encoded src
// and appends the result to dst.
// It is a direct pass-through to json.Compact.
func Compact(dst *bytes.Buffer, src []byte) error {
return json.Compact(dst, src)
}
// MarshalSorted encodes v to JSON with map keys sorted alphabetically.
// Use it when deterministic output is needed (e.g., hashing, caching keys).
// The value is first passed through normalizeForSortedMarshal, which turns
// OrderedMap values into plain maps; json.Marshal then emits those maps with
// sorted keys.
func MarshalSorted(v interface{}) ([]byte, error) {
	return json.Marshal(normalizeForSortedMarshal(v))
}
// MarshalDeeplySorted encodes v to JSON with map keys sorted alphabetically,
// including maps nested inside OrderedMap values, for deterministic hashing
// and cache keys.
//
// In tinygo/wasm builds it performs exactly the same steps as MarshalSorted:
// normalize with normalizeForSortedMarshal, then json.Marshal. Only OrderedMap,
// map[string]interface{} and []interface{} values are normalized in this
// build; any other value is handed to json.Marshal unchanged.
func MarshalDeeplySorted(v interface{}) ([]byte, error) {
	return json.Marshal(normalizeForSortedMarshal(v))
}
// normalizeForSortedMarshal recursively converts OrderedMap values into plain
// map[string]interface{} so that json.Marshal, which writes map keys in sorted
// order, produces deterministic output.
//
// Handling by dynamic type:
//   - nil, or a nil *OrderedMap: nil;
//   - *OrderedMap / OrderedMap: a plain map with each value normalized;
//   - map[string]interface{} and []interface{}: a copy with each element
//     normalized;
//   - everything else, including structs: returned unchanged. Unlike the
//     non-wasm build, this variant does not flatten structs into maps.
func normalizeForSortedMarshal(v interface{}) interface{} {
	if v == nil {
		return nil
	}

	switch typed := v.(type) {
	case *OrderedMap:
		if typed == nil {
			return nil
		}
		plain := make(map[string]interface{}, typed.Len())
		typed.Range(func(key string, item interface{}) bool {
			plain[key] = normalizeForSortedMarshal(item)
			return true
		})
		return plain
	case OrderedMap:
		plain := make(map[string]interface{}, typed.Len())
		typed.Range(func(key string, item interface{}) bool {
			plain[key] = normalizeForSortedMarshal(item)
			return true
		})
		return plain
	case map[string]interface{}:
		plain := make(map[string]interface{}, len(typed))
		for key, item := range typed {
			plain[key] = normalizeForSortedMarshal(item)
		}
		return plain
	case []interface{}:
		items := make([]interface{}, len(typed))
		for idx := range typed {
			items[idx] = normalizeForSortedMarshal(typed[idx])
		}
		return items
	}
	return v
}

View File

@@ -0,0 +1,207 @@
package schemas
import (
"bytes"
"encoding/json"
)
// JSONKeyOrder is a lightweight helper that preserves JSON key ordering through
// struct serialization round-trips. Add it to a struct as an unexported field
// tagged `json:"-"`, as in the usage example below.
//
// LLMs are autoregressive sequence models that are sensitive to JSON key ordering
// in tool schemas. This helper ensures that when Bifrost deserializes and
// re-serializes JSON, the original key order from the client is preserved.
//
// Only the top-level keys of the object are tracked. The zero value is ready to
// use and makes Apply a no-op until Capture has recorded at least one key.
//
// Usage:
//
// type MyStruct struct {
// keyOrder JSONKeyOrder `json:"-"`
// Field1 string `json:"field1"`
// Field2 string `json:"field2"`
// }
//
// func (s *MyStruct) UnmarshalJSON(data []byte) error {
// type Alias MyStruct
// if err := Unmarshal(data, (*Alias)(s)); err != nil { return err }
// s.keyOrder.Capture(data)
// return nil
// }
//
// func (s MyStruct) MarshalJSON() ([]byte, error) {
// type Alias MyStruct
// data, err := Marshal(Alias(s))
// if err != nil { return nil, err }
// return s.keyOrder.Apply(data)
// }
type JSONKeyOrder struct {
// keys holds the top-level keys in the order they appeared in the captured JSON.
keys []string
}
// Capture extracts and stores the top-level key order from raw JSON data.
// Call this at the end of UnmarshalJSON.
// Any previously captured order is replaced. If data is not a JSON object the
// stored order becomes nil, so a later Apply leaves its input untouched.
func (o *JSONKeyOrder) Capture(data []byte) {
o.keys = ExtractTopLevelKeyOrder(data)
}
// Apply rewrites serialized JSON so that its top-level keys follow the order
// recorded by Capture. When no order has been recorded (for example the value
// was built in code rather than decoded), data is returned unchanged.
// Call this at the end of MarshalJSON.
func (o *JSONKeyOrder) Apply(data []byte) ([]byte, error) {
	if len(o.keys) > 0 {
		return ReorderJSONKeys(data, o.keys)
	}
	return data, nil
}
// ExtractTopLevelKeyOrder returns the top-level keys of a JSON object in
// document order. Capture the order with it before decoding into a struct,
// which loses it, so that re-serialization can restore the original order.
//
// It returns nil when data is empty or does not start with '{' after trimming
// surrounding whitespace. Parsing is best-effort: if the input turns out to be
// malformed part-way through, the keys collected up to that point are
// returned. A key that occurs more than once in the input appears more than
// once in the result.
func ExtractTopLevelKeyOrder(data []byte) []string {
	body := bytes.TrimSpace(data)
	if len(body) == 0 || body[0] != '{' {
		return nil
	}

	dec := json.NewDecoder(bytes.NewReader(body))
	// Consume the opening '{'.
	if _, err := dec.Token(); err != nil {
		return nil
	}

	var ordered []string
	for dec.More() {
		tok, err := dec.Token()
		if err != nil {
			return ordered
		}
		name, isString := tok.(string)
		if !isString {
			return ordered
		}
		ordered = append(ordered, name)

		// Step over the value, however deeply nested, to reach the next key.
		if skipJSONValue(dec) != nil {
			return ordered
		}
	}
	return ordered
}
// skipJSONValue reads and discards a single JSON value from a decoder.
func skipJSONValue(dec *json.Decoder) error {
tok, err := dec.Token()
if err != nil {
return err
}
delim, ok := tok.(json.Delim)
if !ok {
return nil // primitive value, already consumed
}
switch delim {
case '{':
for dec.More() {
// skip key
if _, err := dec.Token(); err != nil {
return err
}
if err := skipJSONValue(dec); err != nil {
return err
}
}
_, err = dec.Token() // closing '}'
return err
case '[':
for dec.More() {
if err := skipJSONValue(dec); err != nil {
return err
}
}
_, err = dec.Token() // closing ']'
return err
}
return nil
}
// ReorderJSONKeys takes serialized JSON and a desired key order, and returns
// the same JSON object with its top-level keys reordered. Keys named in order
// are emitted first, in that order; any remaining keys follow in their
// original document order. Values are copied through verbatim.
//
// Details:
//   - Entries in order that are absent from data are ignored.
//   - A key listed more than once in order is emitted only once, at its first
//     position. ExtractTopLevelKeyOrder can return such duplicates when the
//     captured JSON repeated a key.
//   - When data itself repeats a key that is named in order, the key is
//     emitted once with the value of its last occurrence.
//
// The function is best-effort: if data is not a JSON object or cannot be
// parsed, the input is returned unchanged. The returned error is always nil.
func ReorderJSONKeys(data []byte, order []string) ([]byte, error) {
	trimmed := bytes.TrimSpace(data)
	if len(trimmed) < 2 || trimmed[0] != '{' {
		return data, nil
	}

	// Walk the object with encoding/json so each value is captured as raw bytes
	// while the original key order is preserved.
	dec := json.NewDecoder(bytes.NewReader(trimmed))
	dec.UseNumber()
	if _, err := dec.Token(); err != nil { // '{'
		return data, nil
	}

	type kvPair struct {
		key string
		val json.RawMessage
	}
	var pairs []kvPair
	pairMap := make(map[string]json.RawMessage)
	for dec.More() {
		tok, err := dec.Token()
		if err != nil {
			return data, nil
		}
		key, ok := tok.(string)
		if !ok {
			return data, nil
		}
		var val json.RawMessage
		if err := dec.Decode(&val); err != nil {
			return data, nil
		}
		pairs = append(pairs, kvPair{key, val})
		pairMap[key] = val
	}

	var buf bytes.Buffer
	buf.Grow(len(trimmed))
	buf.WriteByte('{')

	// writePair appends one `"key":value` member, preceded by a comma when it
	// is not the first member. It reports false if the key cannot be encoded.
	writePair := func(key string, val json.RawMessage) bool {
		keyBytes, err := MarshalSorted(key)
		if err != nil {
			return false
		}
		if buf.Len() > 1 {
			buf.WriteByte(',')
		}
		buf.Write(keyBytes)
		buf.WriteByte(':')
		buf.Write(val)
		return true
	}

	// First the keys from order, each at most once.
	emitted := make(map[string]bool, len(order))
	for _, key := range order {
		if emitted[key] {
			// Duplicate entry in order: emitting it again would repeat the key.
			continue
		}
		val, exists := pairMap[key]
		if !exists {
			continue
		}
		if !writePair(key, val) {
			return data, nil
		}
		emitted[key] = true
	}

	// Then the remaining keys in their original document order.
	for _, kv := range pairs {
		if emitted[kv.key] {
			continue
		}
		if !writePair(kv.key, kv.val) {
			return data, nil
		}
	}

	buf.WriteByte('}')
	return buf.Bytes(), nil
}

17
core/schemas/kvstore.go Normal file
View File

@@ -0,0 +1,17 @@
package schemas
import "time"
// KVStore is a minimal interface for a key-value store used by Bifrost internals.
// The concrete implementation (e.g. framework/kvstore.Store) is injected by the
// caller and must satisfy this interface. Passing nil disables KV-backed features.
//
// NOTE(review): the exact semantics of each method (what Get returns for a
// missing key, what the bool results mean) are defined by the implementation
// and are not visible here — confirm against framework/kvstore.
type KVStore interface {
Get(key string) (any, error)
SetWithTTL(key string, value any, ttl time.Duration) error
// NOTE(review): "NX" presumably means set-only-if-absent, with the bool
// reporting whether the value was stored — confirm.
SetNXWithTTL(key string, value any, ttl time.Duration) (bool, error)
Delete(key string) (bool, error)
}
const (
// DefaultSessionStickyTTL is one hour.
// NOTE(review): presumably the default lifetime of session-stickiness entries
// kept in a KVStore — confirm against the callers.
DefaultSessionStickyTTL = time.Hour
)

79
core/schemas/logger.go Normal file
View File

@@ -0,0 +1,79 @@
// Package schemas defines the core schemas and types used by the Bifrost system.
package schemas
// LogLevel represents the severity level of a log message.
// Internally it maps to zerolog.Level for interoperability.
type LogLevel string
// LogLevel constants for different severity levels.
// There is no constant for the fatal level even though Logger exposes a Fatal
// method; only these four values are defined here.
const (
LogLevelDebug LogLevel = "debug"
LogLevelInfo LogLevel = "info"
LogLevelWarn LogLevel = "warn"
LogLevelError LogLevel = "error"
)
// LoggerOutputType represents the output type of a logger.
// It is the argument type of Logger.SetOutputType.
type LoggerOutputType string
// LoggerOutputType constants for different output types.
// NOTE(review): "json" presumably selects machine-readable output and "pretty"
// human-readable console output — the rendering lives in the implementation.
const (
LoggerOutputTypeJSON LoggerOutputType = "json"
LoggerOutputTypePretty LoggerOutputType = "pretty"
)
// Logger defines the interface for logging operations in the Bifrost system.
// Implementations of this interface should provide methods for logging messages
// at different severity levels.
//
// NOTE(review): the interface does not say whether args are printf-style
// format arguments or key/value pairs; confirm against the implementation
// before relying on either.
type Logger interface {
// Debug logs a debug-level message.
// This is used for detailed debugging information that is typically only needed
// during development or troubleshooting.
Debug(msg string, args ...any)
// Info logs an info-level message.
// This is used for general informational messages about normal operation.
Info(msg string, args ...any)
// Warn logs a warning-level message.
// This is used for potentially harmful situations that don't prevent normal operation.
Warn(msg string, args ...any)
// Error logs an error-level message.
// This is used for serious problems that need attention and may prevent normal operation.
Error(msg string, args ...any)
// Fatal logs a fatal-level message.
// This is used for critical situations that require immediate attention and will terminate the program.
Fatal(msg string, args ...any)
// SetLevel sets the log level for the logger.
SetLevel(level LogLevel)
// SetOutputType sets the output type for the logger.
SetOutputType(outputType LoggerOutputType)
// LogHTTPRequest returns a LogEventBuilder for structured HTTP access logging.
// The level parameter controls the log severity, msg is sent when Send() is called.
// Use the fluent builder to attach typed fields before calling Send().
// Implementations without structured logging can return NoopLogEvent.
LogHTTPRequest(level LogLevel, msg string) LogEventBuilder
}
// LogEventBuilder is a fluent builder for one structured log entry: each
// field method returns a builder so calls can be chained, and Send finishes
// the entry.
type LogEventBuilder interface {
	// Str attaches a string field.
	Str(key, val string) LogEventBuilder
	// Int attaches an int field.
	Int(key string, val int) LogEventBuilder
	// Int64 attaches an int64 field.
	Int64(key string, val int64) LogEventBuilder
	// Send completes the entry.
	Send()
}

// noopLogEventBuilder satisfies LogEventBuilder while discarding every field;
// it is meant for loggers that have no structured-logging support.
type noopLogEventBuilder struct{}

// Str discards the field and hands back the same no-op builder.
func (b noopLogEventBuilder) Str(_, _ string) LogEventBuilder {
	return b
}

// Int discards the field and hands back the same no-op builder.
func (b noopLogEventBuilder) Int(_ string, _ int) LogEventBuilder {
	return b
}

// Int64 discards the field and hands back the same no-op builder.
func (b noopLogEventBuilder) Int64(_ string, _ int64) LogEventBuilder {
	return b
}

// Send does nothing.
func (noopLogEventBuilder) Send() {}

// NoopLogEvent is the shared no-op LogEventBuilder value.
var NoopLogEvent LogEventBuilder = noopLogEventBuilder{}

243
core/schemas/mcp.go Normal file
View File

@@ -0,0 +1,243 @@
//go:build !tinygo && !wasm
// Package schemas defines the core schemas and types used by the Bifrost system.
package schemas
import (
"context"
"errors"
"strings"
"time"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/server"
)
// OAuth-related sentinel errors. Compare against them with errors.Is.
// ErrOAuth2ConfigNotFound and ErrOAuth2ProviderNotAvailable are returned by
// MCPClientConfig.HttpHeaders; the remaining values are used elsewhere.
var (
ErrOAuth2ConfigNotFound = errors.New("oauth2 config not found")
ErrOAuth2ProviderNotAvailable = errors.New("oauth2 provider not available")
ErrOAuth2TokenExpired = errors.New("oauth2 token expired")
ErrOAuth2TokenInvalid = errors.New("oauth2 token invalid")
ErrOAuth2RefreshFailed = errors.New("oauth2 token refresh failed")
ErrOAuth2NotPerUserSession = errors.New("state does not match a per-user oauth session")
ErrOAuth2TokenNotFound = errors.New("per-user oauth token not found for this identity and mcp server")
ErrPerUserOAuthPendingFlowExpired = errors.New("per-user oauth pending flow has expired")
)
// MCPUserOAuthRequiredError is returned when a per-user OAuth MCP server requires
// the user to authenticate before tool execution can proceed.
//
// NOTE(review): the field meanings below are inferred from their names and
// JSON tags; confirm against the code that constructs this error.
type MCPUserOAuthRequiredError struct {
	MCPClientID   string `json:"mcp_client_id"`   // presumably the ID of the MCP client that needs authentication
	MCPClientName string `json:"mcp_client_name"` // presumably the display name of that MCP client
	AuthorizeURL  string `json:"authorize_url"`   // presumably the URL the user must visit to authorize
	SessionID     string `json:"session_id"`      // presumably identifies the pending OAuth session
	Message       string `json:"message"`         // text returned by Error()
}

// Error implements the error interface and returns e.Message.
//
// A nil receiver yields a fixed generic message instead of panicking, so a
// typed-nil *MCPUserOAuthRequiredError that ends up in an error value can
// still be logged safely.
func (e *MCPUserOAuthRequiredError) Error() string {
	if e == nil {
		return "mcp user oauth required"
	}
	return e.Message
}
// MCPConfig represents the configuration for MCP integration in Bifrost.
// It enables tool auto-discovery and execution from local and external MCP servers.
// The three function-valued fields are tagged json:"-", so they are never
// serialized and must be wired up in code.
type MCPConfig struct {
ClientConfigs []*MCPClientConfig `json:"client_configs,omitempty"` // Per-client execution configurations
ToolManagerConfig *MCPToolManagerConfig `json:"tool_manager_config,omitempty"` // MCP tool manager configuration
ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Global default interval for syncing tools from MCP servers (0 = use default 10 min)
// FetchNewRequestIDFunc fetches a new request ID for each tool call result message in agent mode,
// so that tool call result messages are unique and can be tracked in plugins or by the user.
// This ID is attached to ctx.Value(schemas.BifrostContextKeyRequestID) in agent mode.
// If not provided, the same request ID is used for all tool call result messages without any overrides.
FetchNewRequestIDFunc func(ctx *BifrostContext) string `json:"-"`
// PluginPipelineProvider returns a plugin pipeline for running MCP plugin hooks.
// Used when executeCode tool calls nested MCP tools to ensure plugins run for them.
// The plugin pipeline should be released back to the pool using ReleasePluginPipeline.
PluginPipelineProvider func() interface{} `json:"-"`
// ReleasePluginPipeline releases a plugin pipeline back to the pool.
// This should be called after the plugin pipeline is no longer needed.
ReleasePluginPipeline func(pipeline interface{}) `json:"-"`
}
// MCPToolManagerConfig holds the settings of the MCP tool manager.
type MCPToolManagerConfig struct {
ToolExecutionTimeout time.Duration `json:"tool_execution_timeout"` // Timeout for tool execution; NOTE(review): presumably per tool call — confirm
MaxAgentDepth int `json:"max_agent_depth"` // Maximum agent depth; NOTE(review): presumably bounds nested agent-mode iterations — confirm
CodeModeBindingLevel CodeModeBindingLevel `json:"code_mode_binding_level,omitempty"` // How tools are exposed in VFS: "server" or "tool"
DisableAutoToolInject bool `json:"disable_auto_tool_inject,omitempty"` // When true, MCP tools are not injected into requests by default
}
// Default values for the tool manager.
// NOTE(review): presumably applied when the matching MCPToolManagerConfig
// field is left at its zero value — confirm where the defaults are consumed.
const (
DefaultMaxAgentDepth = 10
DefaultToolExecutionTimeout = 30 * time.Second
)
// CodeModeBindingLevel defines how tools are exposed in the VFS for code execution.
// It is the type of MCPToolManagerConfig.CodeModeBindingLevel.
type CodeModeBindingLevel string
// Supported binding levels.
// NOTE(review): "server" presumably exposes one binding per MCP server and
// "tool" one binding per tool — confirm against the code-mode implementation.
const (
CodeModeBindingLevelServer CodeModeBindingLevel = "server"
CodeModeBindingLevelTool CodeModeBindingLevel = "tool"
)
// MCPAuthType defines the authentication type for MCP connections.
// MCPClientConfig.HttpHeaders switches on this value; any value not listed
// here is treated like MCPAuthTypeHeaders there.
type MCPAuthType string
const (
MCPAuthTypeNone MCPAuthType = "none" // No authentication
MCPAuthTypeHeaders MCPAuthType = "headers" // Header-based authentication (API keys, etc.)
MCPAuthTypeOauth MCPAuthType = "oauth" // OAuth 2.0 authentication (server-level, admin authenticates once)
MCPAuthTypePerUserOauth MCPAuthType = "per_user_oauth" // Per-user OAuth 2.0 authentication (each user authenticates individually)
)
// MCPClientConfig describes a single MCP client: how to connect to it, how to
// authenticate, and which of its tools may be executed or auto-executed.
// Fields tagged json:"-" are runtime-only and are not part of the JSON form.
type MCPClientConfig struct {
ID string `json:"client_id"` // Client ID
Name string `json:"name"` // Client name
IsCodeModeClient bool `json:"is_code_mode_client"` // Whether the client is a code mode client
ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, SSE, or InProcess)
ConnectionString *EnvVar `json:"connection_string,omitempty"` // HTTP or SSE URL (required for HTTP or SSE connections)
StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty"` // STDIO configuration (required for STDIO connections)
AuthType MCPAuthType `json:"auth_type"` // Authentication type (none, headers, oauth, or per_user_oauth)
OauthConfigID *string `json:"oauth_config_id,omitempty"` // OAuth config ID (references oauth_configs table); required by HttpHeaders for the oauth auth type
State string `json:"state,omitempty"` // Connection state (connected, disconnected, error)
Headers map[string]EnvVar `json:"headers,omitempty"` // Headers to send with the request (for headers auth type; HttpHeaders also uses them for unrecognized auth types)
AllowedExtraHeaders WhiteList `json:"allowed_extra_headers,omitempty"` // Allowlist of request-level headers that callers may forward to this MCP server at execution time
InProcessServer *server.MCPServer `json:"-"` // MCP server instance for in-process connections (Go package only)
ToolsToExecute WhiteList `json:"tools_to_execute,omitempty"` // Include-only list.
// ToolsToExecute semantics:
// - ["*"] => all tools are included
// - [] => no tools are included (deny-by-default)
// - nil/omitted => treated as [] (no tools)
// - ["tool1", "tool2"] => include only the specified tools
ToolsToAutoExecute WhiteList `json:"tools_to_auto_execute,omitempty"` // Auto-execute list.
// ToolsToAutoExecute semantics:
// - ["*"] => all tools are auto-executed
// - [] => no tools are auto-executed (deny-by-default)
// - nil/omitted => treated as [] (no tools)
// - ["tool1", "tool2"] => auto-execute only the specified tools
// Note: If a tool is in ToolsToAutoExecute but not in ToolsToExecute, it will be skipped.
IsPingAvailable *bool `json:"is_ping_available,omitempty"` // Whether the MCP server supports ping for health checks (nil/true = ping; false = listTools). Defaults to true.
ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Per-client override for tool sync interval (0 = use global, negative = disabled)
ToolPricing map[string]float64 `json:"tool_pricing,omitempty"` // Tool pricing for each tool (cost per execution)
ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized)
AllowOnAllVirtualKeys bool `json:"allow_on_all_virtual_keys"` // Whether to allow the MCP client to run on all virtual keys
// Discovered tools for per-user OAuth clients (persisted so they survive restart).
// NOTE(review): both fields are tagged json:"-", so any persistence must
// happen outside JSON marshaling of this struct — confirm where.
DiscoveredTools map[string]ChatTool `json:"-"` // Discovered tool schemas keyed by prefixed name
DiscoveredToolNameMapping map[string]string `json:"-"` // Mapping from sanitized tool names to original MCP names
}
// NewMCPClientConfigFromMap builds an MCPClientConfig from a generic map by
// round-tripping it through MarshalSorted and Unmarshal, so the map keys must
// match the struct's JSON tags.
// It returns nil when either the marshal or the unmarshal step fails; the
// underlying error is not reported to the caller.
func NewMCPClientConfigFromMap(configMap map[string]any) *MCPClientConfig {
	encoded, err := MarshalSorted(configMap)
	if err != nil {
		return nil
	}

	cfg := new(MCPClientConfig)
	if err = Unmarshal(encoded, cfg); err != nil {
		return nil
	}
	return cfg
}
// HttpHeaders builds the connection-level HTTP headers for this MCP client
// according to c.AuthType:
//
//   - oauth: requires c.OauthConfigID and a non-nil oauth2Provider, fetches the
//     access token and returns it as an "Authorization: Bearer ..." header.
//     The token is trimmed; an empty token or one containing \n, \r or \t is rejected.
//   - none / per_user_oauth: returns an empty map (per-user OAuth headers are
//     injected per call in executeToolInternal, not at connection level).
//   - headers, or any unrecognized value: returns the resolved c.Headers
//     (unknown values fall back to header behaviour for backward compatibility).
//
// The returned map is always non-nil when the error is nil.
func (c *MCPClientConfig) HttpHeaders(ctx context.Context, oauth2Provider OAuth2Provider) (map[string]string, error) {
	result := make(map[string]string)

	switch c.AuthType {
	case MCPAuthTypeNone, MCPAuthTypePerUserOauth:
		// Nothing to attach at connection level.

	case MCPAuthTypeOauth:
		if c.OauthConfigID == nil {
			return nil, ErrOAuth2ConfigNotFound
		}
		if oauth2Provider == nil {
			return nil, ErrOAuth2ProviderNotAvailable
		}
		token, err := oauth2Provider.GetAccessToken(ctx, *c.OauthConfigID)
		if err != nil {
			return nil, err
		}
		// Surrounding whitespace is tolerated; control characters inside the
		// token are not, since they would corrupt the header.
		token = strings.TrimSpace(token)
		switch {
		case token == "":
			return nil, errors.New("access token is empty")
		case strings.ContainsAny(token, "\n\r\t"):
			return nil, errors.New("access token contains invalid characters")
		}
		result["Authorization"] = "Bearer " + token

	default:
		// MCPAuthTypeHeaders and, for backward compatibility, any unknown type.
		for name, envVar := range c.Headers {
			result[name] = envVar.GetValue()
		}
	}

	return result, nil
}
// MCPConnectionType defines the communication protocol for MCP connections.
// It is used by MCPClientConfig.ConnectionType and MCPClientConnectionInfo.Type.
type MCPConnectionType string
const (
MCPConnectionTypeHTTP MCPConnectionType = "http" // HTTP-based connection
MCPConnectionTypeSTDIO MCPConnectionType = "stdio" // STDIO-based connection
MCPConnectionTypeSSE MCPConnectionType = "sse" // Server-Sent Events connection
MCPConnectionTypeInProcess MCPConnectionType = "inprocess" // In-process (in-memory) connection
)
// MCPStdioConfig defines how to launch a STDIO-based MCP server.
// It is referenced by MCPClientConfig.StdioConfig for STDIO connections.
type MCPStdioConfig struct {
Command string `json:"command"` // Executable command to run
Args []string `json:"args"` // Command line arguments
Envs []string `json:"envs"` // Environment variables required; NOTE(review): unclear from here whether entries are names or NAME=value pairs — confirm
}
// MCPConnectionState is the connection state of an MCP client, as reported in
// MCPClientState.State and MCPClient.State.
type MCPConnectionState string
const (
MCPConnectionStateConnected MCPConnectionState = "connected" // Client is connected and ready to use
MCPConnectionStateDisconnected MCPConnectionState = "disconnected" // Client is not connected
MCPConnectionStateError MCPConnectionState = "error" // Client is in an error state, and cannot be used
MCPConnectionStatePendingTools MCPConnectionState = "pending_tools" // Connected but tools not yet populated
)
// MCPClientState represents a connected MCP client with its configuration and tools.
// It is used internally by the MCP manager to track the state of a connected MCP client.
// Most fields carry no JSON tag; only ConnectionInfo and CancelFunc are tagged.
type MCPClientState struct {
Name string // Unique name for this client
Conn *client.Client // Active MCP client connection
ExecutionConfig *MCPClientConfig // Tool filtering settings
ToolMap map[string]ChatTool // Available tools mapped by name
ToolNameMapping map[string]string // Maps sanitized_name -> original_mcp_name (e.g., "notion_search" -> "notion-search")
ConnectionInfo *MCPClientConnectionInfo `json:"connection_info"` // Connection metadata for management
CancelFunc context.CancelFunc `json:"-"` // Cancel function for SSE connections (not serialized)
State MCPConnectionState // Connection state (connected, disconnected, error, pending_tools)
}
// MCPClientConnectionInfo stores metadata about how a client is connected.
// The two pointer fields are optional and omitted from JSON when nil.
type MCPClientConnectionInfo struct {
Type MCPConnectionType `json:"type"` // Connection type (HTTP, STDIO, SSE, or InProcess)
ConnectionURL *string `json:"connection_url,omitempty"` // HTTP/SSE endpoint URL (for HTTP/SSE connections)
StdioCommandString *string `json:"stdio_command_string,omitempty"` // Command string for display (for STDIO connections)
}
// MCPClient represents a connected MCP client with its configuration and tools,
// and connection information, after it has been initialized.
// It is returned by GetMCPClients() method in bifrost.
type MCPClient struct {
Config *MCPClientConfig `json:"config"` // Client configuration (connection, auth and tool filtering settings)
Tools []ChatToolFunction `json:"tools"` // Available tools
State MCPConnectionState `json:"state"` // Connection state
}

7
core/schemas/mcp_wasm.go Normal file
View File

@@ -0,0 +1,7 @@
//go:build tinygo || wasm
package schemas
// MCPConfig is a stub for WASM builds.
// MCP functionality is not available in WASM plugins.
// This empty type is compiled only under the `tinygo || wasm` build
// constraint; the full definition in mcp.go is built under `!tinygo && !wasm`.
type MCPConfig struct{}

269
core/schemas/models.go Normal file
View File

@@ -0,0 +1,269 @@
package schemas
import (
"encoding/base64"
"encoding/json"
"fmt"
)
// DefaultPageSize is the default page size for listing models.
const DefaultPageSize = 1000
// MaxPaginationRequests is the maximum number of pagination requests to make.
// NOTE(review): presumably a safety cap when walking a provider's paginated
// model list — confirm against the callers.
const MaxPaginationRequests = 20
// ListModelsByKeyResult collects the outcome of one per-key list-models call
// made from a goroutine: either a response or an error, plus the key it belongs to.
type ListModelsByKeyResult struct {
Response *BifrostListModelsResponse
Err *BifrostError
KeyID string
}
// KeyStatus represents the status of model listing for a specific key.
// It has a custom MarshalJSON that strips Error.ExtraFields.KeyStatuses to
// avoid a circular reference during serialization.
type KeyStatus struct {
KeyID string `json:"key_id"` // Empty for keyless providers
Status KeyStatusType `json:"status"` // KeyStatusSuccess ("success") or KeyStatusListModelsFailed ("list_models_failed")
Provider ModelProvider `json:"provider"` // Always populated
Error *BifrostError `json:"error,omitempty"`
}
// MarshalJSON serializes a KeyStatus while breaking the reference cycle
// KeyStatus.Error → BifrostError.ExtraFields.KeyStatuses → KeyStatus.
// The error is shallow-copied and the copy's ExtraFields.KeyStatuses is
// cleared, so the caller's BifrostError is left untouched.
func (k KeyStatus) MarshalJSON() ([]byte, error) {
	// A locally defined type drops the MarshalJSON method, so encoding it
	// does not recurse back into this function.
	type plain KeyStatus
	out := plain(k)

	if src := out.Error; src != nil {
		sanitized := *src
		sanitized.ExtraFields.KeyStatuses = nil
		out.Error = &sanitized
	}
	return MarshalSorted(out)
}
// BifrostListModelsRequest is the request for listing the models of a provider.
// Unfiltered and ExtraParams are tagged json:"-" and are not part of the JSON form.
type BifrostListModelsRequest struct {
Provider ModelProvider `json:"provider"`
PageSize int `json:"page_size"`
// PageToken: Token received from previous request to retrieve next page
PageToken string `json:"page_token"`
// Unfiltered: If true, the response will include all models for the provider, regardless of the allowed models (internal bifrost use only, not sent to the provider)
Unfiltered bool `json:"-"`
// ExtraParams: Additional provider-specific query parameters
// This allows for flexibility to pass any custom parameters that specific providers might support
ExtraParams map[string]interface{} `json:"-"`
}
// BifrostListModelsResponse is the response of a list-models call.
// ApplyPagination returns a page-sized view of it; note that the page it
// builds copies Data, ExtraFields, KeyStatuses and NextPageToken only, not the
// Anthropic-specific fields below.
type BifrostListModelsResponse struct {
Data []Model `json:"data"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
NextPageToken string `json:"next_page_token,omitempty"` // Token to retrieve next page
// Key-level status tracking for multi-key providers
KeyStatuses []KeyStatus `json:"key_statuses,omitempty"`
// Anthropic specific fields (never serialized)
FirstID *string `json:"-"`
LastID *string `json:"-"`
HasMore *bool `json:"-"`
}
// ApplyPagination applies offset-based pagination to a BifrostListModelsResponse.
//
// pageToken is the opaque token produced by a previous call (empty for the
// first page). It carries the offset plus the ID of the last item of the
// previous page; if that ID no longer matches the item before the offset, the
// cursor is considered stale and pagination restarts from the beginning.
//
// Behaviour:
//   - nil receiver            → nil
//   - pageSize <= 0           → the receiver itself, unpaginated
//   - offset past the data    → a new response with an empty, non-nil Data and no token
//   - otherwise               → a new response holding up to pageSize items, with
//     NextPageToken set only when more items remain
//
// The returned page carries ExtraFields and KeyStatuses from the receiver.
// Its Data shares the receiver's backing array but has its capacity clamped,
// so appending to the page cannot overwrite the receiver's later items.
func (response *BifrostListModelsResponse) ApplyPagination(pageSize int, pageToken string) *BifrostListModelsResponse {
	if response == nil {
		return nil
	}
	if pageSize <= 0 {
		return response
	}
	totalItems := len(response.Data)

	cursor := decodePaginationCursor(pageToken)
	offset := cursor.Offset
	// Validate cursor integrity if LastID is present.
	if cursor.LastID != "" && !validatePaginationCursor(cursor, response.Data) {
		// Invalid cursor: reset to beginning.
		offset = 0
	}

	page := &BifrostListModelsResponse{
		Data:        []Model{},
		ExtraFields: response.ExtraFields,
		KeyStatuses: response.KeyStatuses,
	}
	if offset >= totalItems {
		// Empty page, no next token.
		return page
	}

	// Compare pageSize against the remaining item count instead of computing
	// offset+pageSize first: for a very large pageSize that sum overflows to
	// a negative value and the slice expression below would panic.
	endIndex := totalItems
	if pageSize < totalItems-offset {
		endIndex = offset + pageSize
	}
	page.Data = response.Data[offset:endIndex:endIndex]

	if endIndex < totalItems {
		// The page is non-empty here (pageSize >= 1 and offset < totalItems),
		// so endIndex-1 is a valid index; its ID anchors the next cursor.
		nextToken, err := encodePaginationCursor(endIndex, response.Data[endIndex-1].ID)
		if err == nil {
			page.NextPageToken = nextToken
		}
		// On an encoding error the token stays empty; the signature has no
		// way to report it, which matches the previous behaviour.
	}
	return page
}
// Model describes a single model entry in a list-models response.
// Only ID is mandatory; every other serialized field is optional and omitted
// from JSON when unset.
type Model struct {
ID string `json:"id"`
CanonicalSlug *string `json:"canonical_slug,omitempty"`
Name *string `json:"name,omitempty"`
Alias *string `json:"alias,omitempty"` // Provider API identifier this model alias maps to (e.g. Azure deployment name, Bedrock ARN)
Created *int64 `json:"created,omitempty"` // NOTE(review): presumably a Unix timestamp — confirm unit
ContextLength *int `json:"context_length,omitempty"`
MaxInputTokens *int `json:"max_input_tokens,omitempty"`
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
Architecture *Architecture `json:"architecture,omitempty"`
Pricing *Pricing `json:"pricing,omitempty"`
TopProvider *TopProvider `json:"top_provider,omitempty"`
PerRequestLimits *PerRequestLimits `json:"per_request_limits,omitempty"`
SupportedParameters []string `json:"supported_parameters,omitempty"`
DefaultParameters *DefaultParameters `json:"default_parameters,omitempty"`
HuggingFaceID *string `json:"hugging_face_id,omitempty"`
Description *string `json:"description,omitempty"`
OwnedBy *string `json:"owned_by,omitempty"`
SupportedMethods []string `json:"supported_methods,omitempty"`
// ProviderExtra carries opaque provider-specific data (e.g. Anthropic capabilities)
// through the Bifrost pipeline for integration reverse-conversion. Never serialized.
ProviderExtra json.RawMessage `json:"-"`
}
// Architecture holds the optional architecture metadata of a Model
// (Model.Architecture). All fields are omitted from JSON when unset.
type Architecture struct {
Modality *string `json:"modality,omitempty"`
Tokenizer *string `json:"tokenizer,omitempty"`
InstructType *string `json:"instruct_type,omitempty"`
InputModalities []string `json:"input_modalities,omitempty"`
OutputModalities []string `json:"output_modalities,omitempty"`
}
// Pricing holds the optional pricing metadata of a Model (Model.Pricing).
// Every price is carried as a string.
// NOTE(review): units and currency are not visible here — confirm against the
// provider data that populates it.
type Pricing struct {
Prompt *string `json:"prompt,omitempty"`
Completion *string `json:"completion,omitempty"`
Request *string `json:"request,omitempty"`
Image *string `json:"image,omitempty"`
WebSearch *string `json:"web_search,omitempty"`
InternalReasoning *string `json:"internal_reasoning,omitempty"`
InputCacheRead *string `json:"input_cache_read,omitempty"`
InputCacheWrite *string `json:"input_cache_write,omitempty"`
}
// TopProvider holds the optional top-provider metadata of a Model
// (Model.TopProvider). All fields are omitted from JSON when unset.
type TopProvider struct {
IsModerated *bool `json:"is_moderated,omitempty"`
ContextLength *int `json:"context_length,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
}
// PerRequestLimits holds the optional per-request token limits of a Model
// (Model.PerRequestLimits).
type PerRequestLimits struct {
PromptTokens *int `json:"prompt_tokens,omitempty"`
CompletionTokens *int `json:"completion_tokens,omitempty"`
}
// DefaultParameters holds the optional default sampling parameters of a Model
// (Model.DefaultParameters).
type DefaultParameters struct {
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
}
// paginationCursor represents the internal cursor structure for pagination.
// It is the JSON payload inside the opaque page token; the short keys "o" and
// "l" keep the encoded token small.
type paginationCursor struct {
Offset int `json:"o"` // Index of the first item of the page this cursor points to
LastID string `json:"l,omitempty"` // ID of the item just before Offset, used to detect a stale cursor
}
// encodePaginationCursor creates an opaque base64-encoded page token from cursor data.
// Returns empty string if offset is 0 or negative.
func encodePaginationCursor(offset int, lastID string) (string, error) {
if offset <= 0 {
return "", nil
}
cursor := paginationCursor{
Offset: offset,
LastID: lastID,
}
jsonData, err := MarshalSorted(cursor)
if err != nil {
return "", fmt.Errorf("failed to marshal pagination cursor: %w", err)
}
// Use URL-safe base64 encoding without padding for opaque token
encoded := base64.RawURLEncoding.EncodeToString(jsonData)
return encoded, nil
}
// decodePaginationCursor parses an opaque page token back into a cursor.
// Anything that cannot be trusted — an empty token, invalid base64, invalid
// JSON, or a negative offset — yields the zero cursor, i.e. the first page.
func decodePaginationCursor(token string) paginationCursor {
	var firstPage paginationCursor
	if len(token) == 0 {
		return firstPage
	}

	raw, err := base64.RawURLEncoding.DecodeString(token)
	if err != nil {
		return firstPage
	}

	var parsed paginationCursor
	if err := Unmarshal(raw, &parsed); err != nil || parsed.Offset < 0 {
		return firstPage
	}
	return parsed
}
// validatePaginationCursor reports whether cursor still points at the same
// position in data as when it was issued.
//
// A cursor without a LastID carries nothing to verify and is always valid.
// Otherwise the offset must lie in [1, len(data)] and the item immediately
// before it must have the recorded ID.
func validatePaginationCursor(cursor paginationCursor, data []Model) bool {
	if cursor.LastID == "" {
		return true
	}
	if cursor.Offset <= 0 || cursor.Offset > len(data) {
		return false
	}
	// The guard above guarantees 0 <= Offset-1 < len(data), so the index is
	// always in range and no further bounds check is needed.
	return data[cursor.Offset-1].ID == cursor.LastID
}

117
core/schemas/models_test.go Normal file
View File

@@ -0,0 +1,117 @@
package schemas
import (
"bytes"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Raw types that mirror the real KeyStatus → BifrostError → ExtraFields → KeyStatus
// chain but without any custom MarshalJSON. Used to reproduce the cycle error
// that would occur without the fix.

// rawKeyStatusExtraFields mirrors only the KeyStatuses part of the error's extra fields.
type rawKeyStatusExtraFields struct {
KeyStatuses []rawKeyStatus `json:"key_statuses,omitempty"`
}
// rawKeyStatusBifrostError is the error stand-in; it points back at
// rawKeyStatus through ExtraFields.KeyStatuses.
type rawKeyStatusBifrostError struct {
IsBifrostError bool `json:"is_bifrost_error"`
Error *ErrorField `json:"error"`
ExtraFields rawKeyStatusExtraFields `json:"extra_fields"`
}
// rawKeyStatus has the same fields as KeyStatus but no MarshalJSON method.
type rawKeyStatus struct {
KeyID string `json:"key_id"`
Status KeyStatusType `json:"status"`
Provider ModelProvider `json:"provider"`
Error *rawKeyStatusBifrostError `json:"error,omitempty"`
}
// TestKeyStatusMarshalJSON_ReproduceCycle shows the problem the custom
// MarshalJSON solves: with plain mirror types, the pointer cycle between a key
// status and its error makes encoding/json fail with a cycle error.
func TestKeyStatusMarshalJSON_ReproduceCycle(t *testing.T) {
	cyclicErr := new(rawKeyStatusBifrostError)
	cyclicErr.IsBifrostError = true
	cyclicErr.Error = &ErrorField{Message: "test error"}

	status := rawKeyStatus{
		KeyID:    "key-1",
		Status:   KeyStatusListModelsFailed,
		Provider: "test-provider",
		Error:    cyclicErr,
	}

	// Close the loop the same way HandleKeylessListModelsRequest does.
	cyclicErr.ExtraFields.KeyStatuses = append(cyclicErr.ExtraFields.KeyStatuses, status)

	// No custom MarshalJSON is involved, so marshaling must report the cycle.
	_, marshalErr := json.Marshal(status)
	require.Error(t, marshalErr, "expected cycle error without the MarshalJSON fix")
	assert.Contains(t, marshalErr.Error(), "cycle", "error should mention a cycle")
}
// TestKeyStatusMarshalJSON_NoCycle checks that the real KeyStatus marshals
// successfully despite the circular reference, and that the nested
// key_statuses list is dropped from the embedded error.
func TestKeyStatusMarshalJSON_NoCycle(t *testing.T) {
	cyclicErr := &BifrostError{IsBifrostError: true}
	cyclicErr.Error = &ErrorField{Message: "test error"}

	status := KeyStatus{
		KeyID:    "key-1",
		Status:   KeyStatusListModelsFailed,
		Provider: "test-provider",
		Error:    cyclicErr,
	}

	// Close the loop the same way HandleKeylessListModelsRequest does.
	cyclicErr.ExtraFields.KeyStatuses = []KeyStatus{status}

	encoded, err := Marshal(status)
	require.NoError(t, err, "Marshal should not fail on circular KeyStatus")

	// The nested key_statuses must not appear anywhere in the output.
	hasNested := bytes.Contains(encoded, []byte(`"key_statuses"`))
	assert.False(t, hasNested,
		"expected key_statuses to be omitted from nested error")
}
// TestKeyStatusMarshalJSON_NilError checks that a KeyStatus without an error
// marshals cleanly and that the "error" key is omitted from the output.
func TestKeyStatusMarshalJSON_NilError(t *testing.T) {
	status := KeyStatus{KeyID: "key-2", Status: "success", Provider: "test-provider"}

	encoded, err := Marshal(status)
	require.NoError(t, err, "Marshal should not fail on KeyStatus with nil error")

	out := string(encoded)
	assert.Contains(t, out, `"key_id":"key-2"`)
	assert.NotContains(t, out, `"error"`)
}
// TestKeyStatusMarshalJSON_PreservesErrorFields checks that stripping the
// nested key_statuses does not drop any other information from the error.
func TestKeyStatusMarshalJSON_PreservesErrorFields(t *testing.T) {
	code := 401
	unauthorized := &BifrostError{
		IsBifrostError: true,
		StatusCode:     &code,
		Error:          &ErrorField{Message: "unauthorized"},
		ExtraFields: BifrostErrorExtraFields{
			Provider:               "openai",
			OriginalModelRequested: "gpt-4",
		},
	}
	status := KeyStatus{
		KeyID:    "key-3",
		Status:   KeyStatusListModelsFailed,
		Provider: "openai",
		Error:    unauthorized,
	}

	// Create the cycle.
	unauthorized.ExtraFields.KeyStatuses = []KeyStatus{status}

	encoded, err := Marshal(status)
	require.NoError(t, err)

	// Everything except key_statuses must survive in the output.
	out := string(encoded)
	for _, fragment := range []string{
		`"unauthorized"`,
		`"original_model_requested":"gpt-4"`,
		`"status_code":401`,
	} {
		assert.Contains(t, out, fragment)
	}
}

2378
core/schemas/mux.go Normal file

File diff suppressed because it is too large Load Diff

702
core/schemas/mux_test.go Normal file
View File

@@ -0,0 +1,702 @@
package schemas
import "testing"
func TestToChatMessages_PreservesDeveloperRole(t *testing.T) {
messages := []ResponsesMessage{
{
Role: Ptr(ResponsesInputMessageRoleDeveloper),
Content: &ResponsesMessageContent{
ContentBlocks: []ResponsesMessageContentBlock{
{
Type: ResponsesInputMessageContentBlockTypeText,
Text: Ptr("You are helpful"),
},
},
},
},
}
chatMessages := ToChatMessages(messages)
if len(chatMessages) != 1 {
t.Fatalf("expected 1 message, got %d", len(chatMessages))
}
if chatMessages[0].Role != ChatMessageRoleDeveloper {
t.Fatalf("expected role %q, got %q", ChatMessageRoleDeveloper, chatMessages[0].Role)
}
}
func TestToChatRequest_NormalizesDeveloperRoleToSystemForFallback(t *testing.T) {
req := &BifrostResponsesRequest{
Input: []ResponsesMessage{
{
Role: Ptr(ResponsesInputMessageRoleDeveloper),
Content: &ResponsesMessageContent{
ContentBlocks: []ResponsesMessageContentBlock{
{
Type: ResponsesInputMessageContentBlockTypeText,
Text: Ptr("You are helpful"),
},
},
},
},
},
Params: &ResponsesParameters{},
}
chatReq := req.ToChatRequest()
if chatReq == nil {
t.Fatal("expected non-nil chat request")
}
if len(chatReq.Input) != 1 {
t.Fatalf("expected 1 chat message, got %d", len(chatReq.Input))
}
if chatReq.Input[0].Role != ChatMessageRoleSystem {
t.Fatalf("expected role %q in fallback conversion, got %q", ChatMessageRoleSystem, chatReq.Input[0].Role)
}
}
// TestToChatMessages_LeavesExistingSupportedRolesUnchanged checks that system,
// user and assistant roles pass through the conversion unchanged and in order.
func TestToChatMessages_LeavesExistingSupportedRolesUnchanged(t *testing.T) {
	input := []ResponsesMessage{
		{Role: Ptr(ResponsesInputMessageRoleSystem)},
		{Role: Ptr(ResponsesInputMessageRoleUser)},
		{Role: Ptr(ResponsesInputMessageRoleAssistant)},
	}

	converted := ToChatMessages(input)
	if len(converted) != len(input) {
		t.Fatalf("expected %d messages, got %d", len(input), len(converted))
	}

	// Cases are evaluated top to bottom, so the first mismatch is reported.
	switch {
	case converted[0].Role != ChatMessageRoleSystem:
		t.Fatalf("expected system role, got %q", converted[0].Role)
	case converted[1].Role != ChatMessageRoleUser:
		t.Fatalf("expected user role, got %q", converted[1].Role)
	case converted[2].Role != ChatMessageRoleAssistant:
		t.Fatalf("expected assistant role, got %q", converted[2].Role)
	}
}
func TestToChatRequest_FiltersUnsupportedResponsesToolsForFallback(t *testing.T) {
validName := "valid_tool"
invalidName := " "
req := &BifrostResponsesRequest{
Params: &ResponsesParameters{
Tools: []ResponsesTool{
{
Type: ResponsesToolTypeFunction,
Name: &validName,
ResponsesToolFunction: &ResponsesToolFunction{
Parameters: &ToolFunctionParameters{
Type: "object",
Properties: &OrderedMap{},
},
},
},
{
Type: ResponsesToolTypeFunction,
Name: &invalidName,
},
{
Type: ResponsesToolTypeMCP,
Name: Ptr("mcp_tool"),
},
{
Type: ResponsesToolTypeWebSearch,
Name: Ptr("web_search"),
},
},
},
}
chatReq := req.ToChatRequest()
if chatReq == nil || chatReq.Params == nil {
t.Fatal("expected non-nil chat request params")
}
if len(chatReq.Params.Tools) != 1 {
t.Fatalf("expected 1 valid fallback tool, got %d", len(chatReq.Params.Tools))
}
if chatReq.Params.Tools[0].Type != ChatToolTypeFunction {
t.Fatalf("expected tool type %q, got %q", ChatToolTypeFunction, chatReq.Params.Tools[0].Type)
}
if chatReq.Params.Tools[0].Function == nil || chatReq.Params.Tools[0].Function.Name != validName {
t.Fatalf("expected function tool %q to be preserved", validName)
}
}
func TestToChatRequest_DropsInvalidToolChoiceForFallback(t *testing.T) {
validName := "valid_tool"
invalidChoiceName := "missing_tool"
req := &BifrostResponsesRequest{
Params: &ResponsesParameters{
Tools: []ResponsesTool{
{
Type: ResponsesToolTypeFunction,
Name: &validName,
},
},
ToolChoice: &ResponsesToolChoice{
ResponsesToolChoiceStruct: &ResponsesToolChoiceStruct{
Type: ResponsesToolChoiceTypeFunction,
Name: &invalidChoiceName,
},
},
},
}
chatReq := req.ToChatRequest()
if chatReq == nil || chatReq.Params == nil {
t.Fatal("expected non-nil chat request params")
}
if chatReq.Params.ToolChoice != nil {
t.Fatal("expected incompatible tool choice to be removed for fallback")
}
}
// TestToChatRequest_AllNonFunctionToolsDropsToolsAndToolChoice checks that when
// every tool is unsupported for fallback (MCP, web search), the converted
// request has nil tools and no tool choice, even for a string "auto" choice.
func TestToChatRequest_AllNonFunctionToolsDropsToolsAndToolChoice(t *testing.T) {
	autoChoice := string(ChatToolChoiceTypeAuto)

	params := &ResponsesParameters{
		Tools: []ResponsesTool{
			{Type: ResponsesToolTypeMCP, Name: Ptr("mcp")},
			{Type: ResponsesToolTypeWebSearch, Name: Ptr("search")},
		},
		ToolChoice: &ResponsesToolChoice{ResponsesToolChoiceStr: &autoChoice},
	}
	req := &BifrostResponsesRequest{Params: params}

	chatReq := req.ToChatRequest()
	if chatReq == nil || chatReq.Params == nil {
		t.Fatal("expected non-nil chat request params")
	}
	if tools := chatReq.Params.Tools; tools != nil {
		t.Fatalf("expected nil tools when all tools are unsupported, got %d", len(tools))
	}
	if chatReq.Params.ToolChoice != nil {
		t.Fatal("expected tool choice to be dropped when no valid tools remain")
	}
}
// TestToChatRequest_DropsAllowedToolsAndCustomToolChoiceForFallback runs one
// subtest per structured tool-choice type that the fallback cannot express
// (allowed_tools, custom) and checks that the tool choice is dropped.
func TestToChatRequest_DropsAllowedToolsAndCustomToolChoiceForFallback(t *testing.T) {
	validName := "valid_tool"
	unsupportedChoices := []ResponsesToolChoiceType{
		ResponsesToolChoiceTypeAllowedTools,
		ResponsesToolChoiceTypeCustom,
	}

	for _, choiceType := range unsupportedChoices {
		t.Run(string(choiceType), func(t *testing.T) {
			params := &ResponsesParameters{
				Tools: []ResponsesTool{
					{Type: ResponsesToolTypeFunction, Name: &validName},
				},
				ToolChoice: &ResponsesToolChoice{
					ResponsesToolChoiceStruct: &ResponsesToolChoiceStruct{Type: choiceType},
				},
			}
			req := &BifrostResponsesRequest{Params: params}

			chatReq := req.ToChatRequest()
			if chatReq == nil || chatReq.Params == nil {
				t.Fatal("expected non-nil chat request params")
			}
			if chatReq.Params.ToolChoice != nil {
				t.Fatalf("expected %q tool choice to be dropped for fallback", choiceType)
			}
		})
	}
}
// TestToChatRequest_PreservesStringToolChoiceAutoAndNone runs one subtest per
// string tool choice (auto, none) and checks that the converted chat request
// carries the same string tool choice.
func TestToChatRequest_PreservesStringToolChoiceAutoAndNone(t *testing.T) {
	toolName := "valid_tool"

	for _, choice := range []string{
		string(ChatToolChoiceTypeAuto),
		string(ChatToolChoiceTypeNone),
	} {
		t.Run(choice, func(t *testing.T) {
			params := &ResponsesParameters{
				Tools: []ResponsesTool{
					{Type: ResponsesToolTypeFunction, Name: &toolName},
				},
				ToolChoice: &ResponsesToolChoice{ResponsesToolChoiceStr: &choice},
			}

			converted := (&BifrostResponsesRequest{Params: params}).ToChatRequest()
			if converted == nil || converted.Params == nil {
				t.Fatal("expected non-nil chat request params")
			}

			gotChoice := converted.Params.ToolChoice
			if gotChoice == nil || gotChoice.ChatToolChoiceStr == nil {
				t.Fatal("expected string tool choice to be preserved")
			}
			if got := *gotChoice.ChatToolChoiceStr; got != choice {
				t.Fatalf("expected tool choice %q, got %q", choice, got)
			}
		})
	}
}
// TestToBifrostResponsesStreamResponse_PopulatesFinalDoneTextAndCompletedOutput
// feeds a role chunk, two text chunks and a stop chunk through one shared
// stream state. It checks that both the output_text.done event and the
// completed event's first content block carry the concatenated text.
func TestToBifrostResponsesStreamResponse_PopulatesFinalDoneTextAndCompletedOutput(t *testing.T) {
	state := AcquireChatToResponsesStreamState()
	defer ReleaseChatToResponsesStreamState(state)

	const wantText = "Hello world"
	assistantRole := string(ChatMessageRoleAssistant)
	first, second := "Hello", " world"
	stopReason := string(BifrostFinishReasonStop)

	// newChunk builds a single-choice streaming chat chunk.
	newChunk := func(role, content, finishReason *string) *BifrostChatResponse {
		choice := BifrostResponseChoice{
			FinishReason: finishReason,
			ChatStreamResponseChoice: &ChatStreamResponseChoice{
				Delta: &ChatStreamResponseChoiceDelta{Role: role, Content: content},
			},
		}
		return &BifrostChatResponse{
			ID:      "chatcmpl-test",
			Model:   "test-model",
			Choices: []BifrostResponseChoice{choice},
		}
	}

	chunks := []*BifrostChatResponse{
		newChunk(&assistantRole, nil, nil),
		newChunk(nil, &first, nil),
		newChunk(nil, &second, nil),
		newChunk(nil, nil, &stopReason),
	}

	// Convert the chunks in stream order; the state accumulates across calls.
	var events []*BifrostResponsesStreamResponse
	for _, chunk := range chunks {
		events = append(events, chunk.ToBifrostResponsesStreamResponse(state)...)
	}

	// Keep the last event of each type of interest.
	var textDone, completed *BifrostResponsesStreamResponse
	for _, evt := range events {
		if evt == nil {
			continue
		}
		switch evt.Type {
		case ResponsesStreamResponseTypeOutputTextDone:
			textDone = evt
		case ResponsesStreamResponseTypeCompleted:
			completed = evt
		}
	}

	if textDone == nil || textDone.Text == nil {
		t.Fatal("expected response.output_text.done with text")
	}
	if got := *textDone.Text; got != wantText {
		t.Fatalf("expected output_text.done text %q, got %q", wantText, got)
	}

	if completed == nil || completed.Response == nil || len(completed.Response.Output) != 1 {
		t.Fatal("expected response.completed with one output message")
	}
	message := completed.Response.Output[0]
	if message.Content == nil || len(message.Content.ContentBlocks) == 0 || message.Content.ContentBlocks[0].Text == nil {
		t.Fatal("expected completed output message to include text content block")
	}
	if got := *message.Content.ContentBlocks[0].Text; got != wantText {
		t.Fatalf("expected completed output text %q, got %q", wantText, got)
	}
}
// TestToBifrostResponsesResponse_MapsLengthToIncomplete checks that a chat
// response whose only choice finished with BifrostFinishReasonLength converts
// to status "incomplete" with incomplete_details.reason "max_output_tokens".
func TestToBifrostResponsesResponse_MapsLengthToIncomplete(t *testing.T) {
	finish := string(BifrostFinishReasonLength)
	chat := &BifrostChatResponse{
		Choices: []BifrostResponseChoice{{FinishReason: &finish}},
	}

	converted := chat.ToBifrostResponsesResponse()
	if converted == nil || converted.Status == nil {
		t.Fatal("expected status to be set")
	}
	if got := *converted.Status; got != "incomplete" {
		t.Fatalf("expected status %q, got %q", "incomplete", got)
	}

	details := converted.IncompleteDetails
	if details == nil {
		t.Fatal("expected incomplete_details to be set")
	}
	if details.Reason != "max_output_tokens" {
		t.Fatalf("expected incomplete_details.reason %q, got %q", "max_output_tokens", details.Reason)
	}
}
// TestToBifrostResponsesResponse_MapsToolCallsToCompleted checks that a chat
// response whose only choice finished with BifrostFinishReasonToolCalls
// converts to status "completed" with no incomplete details.
func TestToBifrostResponsesResponse_MapsToolCallsToCompleted(t *testing.T) {
	finish := string(BifrostFinishReasonToolCalls)
	chat := &BifrostChatResponse{
		Choices: []BifrostResponseChoice{{FinishReason: &finish}},
	}

	converted := chat.ToBifrostResponsesResponse()
	if converted == nil || converted.Status == nil {
		t.Fatal("expected status to be set")
	}
	if got := *converted.Status; got != "completed" {
		t.Fatalf("expected status %q, got %q", "completed", got)
	}
	if converted.IncompleteDetails != nil {
		t.Fatal("expected incomplete_details to be nil")
	}
}
// TestToBifrostResponsesResponse_PrioritizesLengthAcrossChoices checks that
// when one choice finished with stop and another with length, the converted
// response is "incomplete" with reason "max_output_tokens".
func TestToBifrostResponsesResponse_PrioritizesLengthAcrossChoices(t *testing.T) {
	stopReason := string(BifrostFinishReasonStop)
	lengthReason := string(BifrostFinishReasonLength)
	chat := &BifrostChatResponse{
		Choices: []BifrostResponseChoice{
			{FinishReason: &stopReason},
			{FinishReason: &lengthReason},
		},
	}

	converted := chat.ToBifrostResponsesResponse()
	if converted == nil || converted.Status == nil {
		t.Fatal("expected status to be set")
	}
	if got := *converted.Status; got != "incomplete" {
		t.Fatalf("expected status %q, got %q", "incomplete", got)
	}
	if details := converted.IncompleteDetails; details == nil || details.Reason != "max_output_tokens" {
		t.Fatal("expected max_output_tokens incomplete_details")
	}
}
// TestToBifrostResponsesResponse_UnknownFinishReasonLeavesStatusUnset checks
// that the finish reason "content_filter" yields a non-nil response with
// neither a status nor incomplete details.
func TestToBifrostResponsesResponse_UnknownFinishReasonLeavesStatusUnset(t *testing.T) {
	finish := "content_filter"
	chat := &BifrostChatResponse{
		Choices: []BifrostResponseChoice{{FinishReason: &finish}},
	}

	converted := chat.ToBifrostResponsesResponse()
	if converted == nil {
		t.Fatal("expected non-nil response")
	}
	if status := converted.Status; status != nil {
		t.Fatalf("expected status to be nil, got %q", *status)
	}
	if converted.IncompleteDetails != nil {
		t.Fatal("expected incomplete_details to be nil")
	}
}
// TestToBifrostResponsesStreamResponse_IncludesFunctionCallsInCompletedOutput
// streams a role chunk, a text chunk, a tool call split across two argument
// fragments, and a tool_calls finish chunk through one shared stream state. It
// then checks that the response.completed event lists both the text message
// and the function call, with the argument fragments concatenated.
func TestToBifrostResponsesStreamResponse_IncludesFunctionCallsInCompletedOutput(t *testing.T) {
	state := AcquireChatToResponsesStreamState()
	defer ReleaseChatToResponsesStreamState(state)

	role := string(ChatMessageRoleAssistant)
	part1 := "Let me help"
	toolCallsFinish := string(BifrostFinishReasonToolCalls)
	funcName := "get_weather"
	toolCallID := "call_abc123"
	const wantArgs = `{"city":"Paris"}`

	var all []*BifrostResponsesStreamResponse
	// feed wraps a delta in a single-choice stream chunk, converts it with the
	// shared state, and collects the emitted events in arrival order.
	feed := func(delta *ChatStreamResponseChoiceDelta, finishReason *string) {
		chunk := &BifrostChatResponse{
			ID:    "chatcmpl-test",
			Model: "test-model",
			Choices: []BifrostResponseChoice{
				{
					FinishReason:             finishReason,
					ChatStreamResponseChoice: &ChatStreamResponseChoice{Delta: delta},
				},
			},
		}
		all = append(all, chunk.ToBifrostResponsesStreamResponse(state)...)
	}

	// Role chunk
	feed(&ChatStreamResponseChoiceDelta{Role: &role}, nil)
	// Text content chunk
	feed(&ChatStreamResponseChoiceDelta{Content: &part1}, nil)
	// Tool call chunk with function name
	feed(&ChatStreamResponseChoiceDelta{
		ToolCalls: []ChatAssistantMessageToolCall{
			{
				Index:    0,
				ID:       &toolCallID,
				Function: ChatAssistantMessageToolCallFunction{Name: &funcName, Arguments: `{"city":`},
			},
		},
	}, nil)
	// Tool call argument continuation
	feed(&ChatStreamResponseChoiceDelta{
		ToolCalls: []ChatAssistantMessageToolCall{
			{
				Index:    0,
				Function: ChatAssistantMessageToolCallFunction{Arguments: `"Paris"}`},
			},
		},
	}, nil)
	// Finish with tool_calls
	feed(&ChatStreamResponseChoiceDelta{}, &toolCallsFinish)

	var completed *BifrostResponsesStreamResponse
	for _, evt := range all {
		if evt != nil && evt.Type == ResponsesStreamResponseTypeCompleted {
			completed = evt
		}
	}
	if completed == nil || completed.Response == nil {
		t.Fatal("expected response.completed event")
	}

	output := completed.Response.Output
	if len(output) < 2 {
		t.Fatalf("expected at least 2 output items (text + function_call), got %d", len(output))
	}

	var hasText, hasFunctionCall bool
	for _, item := range output {
		if item.Type != nil && *item.Type == ResponsesMessageTypeMessage {
			hasText = true
			if item.Content == nil || len(item.Content.ContentBlocks) == 0 || item.Content.ContentBlocks[0].Text == nil {
				t.Fatal("text message missing content")
			}
			if *item.Content.ContentBlocks[0].Text != "Let me help" {
				t.Fatalf("expected text %q, got %q", "Let me help", *item.Content.ContentBlocks[0].Text)
			}
		}
		if item.Type != nil && *item.Type == ResponsesMessageTypeFunctionCall {
			hasFunctionCall = true
			if item.ResponsesToolMessage == nil {
				t.Fatal("function_call item missing ResponsesToolMessage")
			}
			// Nil and mismatch are reported separately so the message shows the
			// actual value; formatting the pointer with %v prints an address.
			if item.Name == nil {
				t.Fatalf("expected function name %q, got nil", "get_weather")
			}
			if *item.Name != "get_weather" {
				t.Fatalf("expected function name %q, got %q", "get_weather", *item.Name)
			}
			if item.Arguments == nil {
				t.Fatalf("expected arguments %q, got nil", wantArgs)
			}
			if *item.Arguments != wantArgs {
				t.Fatalf("expected arguments %q, got %q", wantArgs, *item.Arguments)
			}
			if item.CallID == nil {
				t.Fatalf("expected call_id %q, got nil", toolCallID)
			}
			if *item.CallID != toolCallID {
				t.Fatalf("expected call_id %q, got %q", toolCallID, *item.CallID)
			}
		}
	}
	if !hasText {
		t.Fatal("expected text message in completed output")
	}
	if !hasFunctionCall {
		t.Fatal("expected function_call item in completed output")
	}
}
// TestToBifrostResponsesStreamResponse_ToolCallsOnlyInCompletedOutput streams
// a role chunk, a single tool call chunk (no text content), and a tool_calls
// finish chunk through one shared stream state. It then checks that the
// response.completed event contains exactly one function_call output item.
func TestToBifrostResponsesStreamResponse_ToolCallsOnlyInCompletedOutput(t *testing.T) {
	state := AcquireChatToResponsesStreamState()
	defer ReleaseChatToResponsesStreamState(state)

	role := string(ChatMessageRoleAssistant)
	toolCallsFinish := string(BifrostFinishReasonToolCalls)
	funcName := "get_weather"
	toolCallID := "call_xyz789"
	const wantArgs = `{"q":"test"}`

	var all []*BifrostResponsesStreamResponse
	// feed wraps a delta in a single-choice stream chunk, converts it with the
	// shared state, and collects the emitted events in arrival order.
	feed := func(delta *ChatStreamResponseChoiceDelta, finishReason *string) {
		chunk := &BifrostChatResponse{
			ID:    "chatcmpl-test",
			Model: "test-model",
			Choices: []BifrostResponseChoice{
				{
					FinishReason:             finishReason,
					ChatStreamResponseChoice: &ChatStreamResponseChoice{Delta: delta},
				},
			},
		}
		all = append(all, chunk.ToBifrostResponsesStreamResponse(state)...)
	}

	// Role chunk
	feed(&ChatStreamResponseChoiceDelta{Role: &role}, nil)
	// Tool call chunk (no text content at all)
	feed(&ChatStreamResponseChoiceDelta{
		ToolCalls: []ChatAssistantMessageToolCall{
			{
				Index:    0,
				ID:       &toolCallID,
				Function: ChatAssistantMessageToolCallFunction{Name: &funcName, Arguments: wantArgs},
			},
		},
	}, nil)
	// Finish with tool_calls
	feed(&ChatStreamResponseChoiceDelta{}, &toolCallsFinish)

	var completed *BifrostResponsesStreamResponse
	for _, evt := range all {
		if evt != nil && evt.Type == ResponsesStreamResponseTypeCompleted {
			completed = evt
		}
	}
	if completed == nil || completed.Response == nil {
		t.Fatal("expected response.completed event")
	}

	output := completed.Response.Output
	if len(output) != 1 {
		t.Fatalf("expected 1 output item (function_call only), got %d", len(output))
	}

	item := output[0]
	if item.Type == nil || *item.Type != ResponsesMessageTypeFunctionCall {
		t.Fatal("expected function_call type")
	}
	if item.ResponsesToolMessage == nil {
		t.Fatal("function_call item missing ResponsesToolMessage")
	}
	// Nil and mismatch are reported separately so the message shows the actual
	// value; formatting the pointer with %v prints an address.
	if item.Name == nil {
		t.Fatalf("expected function name %q, got nil", "get_weather")
	}
	if *item.Name != "get_weather" {
		t.Fatalf("expected function name %q, got %q", "get_weather", *item.Name)
	}
	if item.Arguments == nil {
		t.Fatalf("expected arguments %q, got nil", wantArgs)
	}
	if *item.Arguments != wantArgs {
		t.Fatalf("expected arguments %q, got %q", wantArgs, *item.Arguments)
	}
}
// TestToBifrostResponsesStreamResponse_MapsLengthToIncompleteEvent streams a
// role chunk, a text chunk and a length-finish chunk through one shared stream
// state. It checks that no response.completed event is emitted and that a
// response.incomplete event carries status "incomplete" with reason
// "max_output_tokens".
func TestToBifrostResponsesStreamResponse_MapsLengthToIncompleteEvent(t *testing.T) {
	state := AcquireChatToResponsesStreamState()
	defer ReleaseChatToResponsesStreamState(state)

	assistantRole := string(ChatMessageRoleAssistant)
	text := "Hello"
	lengthReason := string(BifrostFinishReasonLength)

	// newChunk builds a single-choice streaming chat chunk.
	newChunk := func(role, content, finishReason *string) *BifrostChatResponse {
		choice := BifrostResponseChoice{
			FinishReason: finishReason,
			ChatStreamResponseChoice: &ChatStreamResponseChoice{
				Delta: &ChatStreamResponseChoiceDelta{Role: role, Content: content},
			},
		}
		return &BifrostChatResponse{
			ID:      "chatcmpl-test",
			Model:   "test-model",
			Choices: []BifrostResponseChoice{choice},
		}
	}

	chunks := []*BifrostChatResponse{
		newChunk(&assistantRole, nil, nil),
		newChunk(nil, &text, nil),
		newChunk(nil, nil, &lengthReason),
	}

	// Convert the chunks in stream order; the state accumulates across calls.
	var events []*BifrostResponsesStreamResponse
	for _, chunk := range chunks {
		events = append(events, chunk.ToBifrostResponsesStreamResponse(state)...)
	}

	var completed, incomplete *BifrostResponsesStreamResponse
	for _, evt := range events {
		if evt == nil {
			continue
		}
		switch evt.Type {
		case ResponsesStreamResponseTypeCompleted:
			completed = evt
		case ResponsesStreamResponseTypeIncomplete:
			incomplete = evt
		}
	}

	if completed != nil {
		t.Fatal("did not expect response.completed for finish_reason=length")
	}
	if incomplete == nil || incomplete.Response == nil {
		t.Fatal("expected response.incomplete with response payload")
	}

	final := incomplete.Response
	if final.Status == nil || *final.Status != "incomplete" {
		t.Fatal("expected terminal response status to be incomplete")
	}
	if final.IncompleteDetails == nil || final.IncompleteDetails.Reason != "max_output_tokens" {
		t.Fatal("expected incomplete_details.reason to be max_output_tokens")
	}
}

99
core/schemas/oauth.go Normal file
View File

@@ -0,0 +1,99 @@
package schemas
import (
"context"
"time"
)
// OAuth2Provider defines the OAuth 2.0 token operations. It groups two sets of
// methods: server-level operations keyed by an oauth_config_id, and per-user
// operations keyed by a session token (or, for lookups, by caller identity).
type OAuth2Provider interface {
// GetAccessToken retrieves the access token for a given oauth_config_id (server-level OAuth)
GetAccessToken(ctx context.Context, oauthConfigID string) (string, error)
// RefreshAccessToken refreshes the access token for a given oauth_config_id
RefreshAccessToken(ctx context.Context, oauthConfigID string) error
// ValidateToken checks if the token is still valid
ValidateToken(ctx context.Context, oauthConfigID string) (bool, error)
// RevokeToken revokes the OAuth token
RevokeToken(ctx context.Context, oauthConfigID string) error
// Per-user OAuth methods
// GetUserAccessToken retrieves the access token for a per-user OAuth session.
// If the token is expired, it automatically attempts a refresh.
GetUserAccessToken(ctx context.Context, sessionToken string) (string, error)
// GetUserAccessTokenByIdentity retrieves the upstream access token for a user
// identified by virtualKeyID, userID, or sessionToken (fallback), for a specific
// MCP client. Tokens looked up by identity persist across sessions.
GetUserAccessTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (string, error)
// InitiateUserOAuthFlow creates a per-user OAuth session and returns the authorization URL.
// Returns (flow initiation details, session ID for polling, error).
InitiateUserOAuthFlow(ctx context.Context, oauthConfigID string, mcpClientID string, redirectURI string) (*OAuth2FlowInitiation, string, error)
// CompleteUserOAuthFlow handles the OAuth callback for a per-user flow.
// Returns the session token that the user should send on subsequent requests.
CompleteUserOAuthFlow(ctx context.Context, state string, code string) (string, error)
// RefreshUserAccessToken refreshes a per-user OAuth access token.
RefreshUserAccessToken(ctx context.Context, sessionToken string) error
// RevokeUserToken revokes a per-user OAuth token and marks the session as revoked.
RevokeUserToken(ctx context.Context, sessionToken string) error
}
// OAuth2Config represents OAuth client configuration.
// Only ID, RedirectURI and ServerURL are always serialized; every other field
// carries omitempty and is dropped from JSON when it holds its zero value.
type OAuth2Config struct {
ID string `json:"id"`
ClientID string `json:"client_id,omitempty"` // Optional: Will be obtained via dynamic registration (RFC 7591) if not provided
ClientSecret string `json:"client_secret,omitempty"` // Optional: For public clients using PKCE, or obtained via dynamic registration
AuthorizeURL string `json:"authorize_url,omitempty"` // Optional: Will be discovered from ServerURL if not provided
TokenURL string `json:"token_url,omitempty"` // Optional: Will be discovered from ServerURL if not provided
RegistrationURL *string `json:"registration_url,omitempty"` // Optional: For dynamic client registration (RFC 7591), can be discovered
RedirectURI string `json:"redirect_uri"` // Required
Scopes []string `json:"scopes,omitempty"` // Optional: Can be discovered
ServerURL string `json:"server_url"` // MCP server URL for OAuth discovery (required if URLs not provided)
UseDiscovery bool `json:"use_discovery,omitempty"` // Deprecated: Discovery now happens automatically when URLs are missing
}
// OAuth2Token represents OAuth access and refresh tokens.
type OAuth2Token struct {
ID string `json:"id"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
// ExpiresAt is an absolute timestamp (time.Time), unlike the relative
// ExpiresIn seconds count on OAuth2TokenExchangeResponse.
ExpiresAt time.Time `json:"expires_at"`
Scopes []string `json:"scopes"`
// LastRefreshedAt is a pointer and is omitted from JSON when nil.
LastRefreshedAt *time.Time `json:"last_refreshed_at,omitempty"`
}
// OAuth2FlowInitiation represents the response when initiating an OAuth flow.
// It is the first value returned by OAuth2Provider.InitiateUserOAuthFlow.
type OAuth2FlowInitiation struct {
OauthConfigID string `json:"oauth_config_id"`
AuthorizeURL string `json:"authorize_url"`
State string `json:"state"`
// NOTE(review): presumably the deadline for completing this flow — confirm against the provider implementation.
ExpiresAt time.Time `json:"expires_at"`
}
// OAuth2TokenExchangeRequest represents the OAuth token exchange request.
// GrantType and ClientID are always serialized; the remaining fields carry
// omitempty and are dropped from JSON when empty.
type OAuth2TokenExchangeRequest struct {
GrantType string `json:"grant_type"`
Code string `json:"code,omitempty"`
RedirectURI string `json:"redirect_uri,omitempty"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
CodeVerifier string `json:"code_verifier,omitempty"` // PKCE verifier for authorization_code grant
}
// OAuth2TokenExchangeResponse represents the OAuth token exchange response.
type OAuth2TokenExchangeResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
TokenType string `json:"token_type"`
// NOTE(review): presumably a lifetime in seconds, per the usual OAuth 2.0 token response — confirm.
ExpiresIn int `json:"expires_in"`
// Scope is a single string here, whereas OAuth2Token.Scopes is a []string.
Scope string `json:"scope,omitempty"`
}

89
core/schemas/ocr.go Normal file
View File

@@ -0,0 +1,89 @@
package schemas
// OCRDocumentType specifies the type of document input for an OCR request.
// It is the type of OCRDocument.Type, so its string values are what appear in
// that struct's JSON "type" field.
type OCRDocumentType string
const (
// OCRDocumentTypeDocumentURL represents a document URL input (e.g., PDF URL or base64 data URL).
OCRDocumentTypeDocumentURL OCRDocumentType = "document_url"
// OCRDocumentTypeImageURL represents an image URL input.
OCRDocumentTypeImageURL OCRDocumentType = "image_url"
)
// OCRDocument represents the document input for an OCR request.
type OCRDocument struct {
// Type is always serialized.
Type OCRDocumentType `json:"type"`
// DocumentURL and ImageURL are optional pointers, omitted from JSON when nil.
// NOTE(review): presumably DocumentURL is set for type "document_url" and
// ImageURL for type "image_url" — confirm against the callers.
DocumentURL *string `json:"document_url,omitempty"`
ImageURL *string `json:"image_url,omitempty"`
}
// OCRParameters contains optional parameters for an OCR request.
// Every serialized field carries omitempty; the pointer fields are left out of
// JSON when nil and Pages when empty.
type OCRParameters struct {
IncludeImageBase64 *bool `json:"include_image_base64,omitempty"`
Pages []int `json:"pages,omitempty"`
ImageLimit *int `json:"image_limit,omitempty"`
ImageMinSize *int `json:"image_min_size,omitempty"`
TableFormat *string `json:"table_format,omitempty"`
ExtractHeader *bool `json:"extract_header,omitempty"`
ExtractFooter *bool `json:"extract_footer,omitempty"`
BBoxAnnotationFormat *string `json:"bbox_annotation_format,omitempty"`
DocumentAnnotationFormat *string `json:"document_annotation_format,omitempty"`
DocumentAnnotationPrompt *string `json:"document_annotation_prompt,omitempty"`
// ExtraParams is excluded from this struct's JSON encoding (json:"-").
ExtraParams map[string]interface{} `json:"-"`
}
// BifrostOCRRequest represents a request to perform OCR on a document.
type BifrostOCRRequest struct {
Provider ModelProvider `json:"provider"`
Model string `json:"model"`
ID *string `json:"id,omitempty"`
Document OCRDocument `json:"document"`
Params *OCRParameters `json:"params,omitempty"`
Fallbacks []Fallback `json:"fallbacks,omitempty"`
// RawRequestBody is excluded from JSON (json:"-") and is exposed through
// GetRawRequestBody.
RawRequestBody []byte `json:"-"`
}
// GetRawRequestBody returns the RawRequestBody field of the OCR request as
// stored; the slice is not copied.
func (r *BifrostOCRRequest) GetRawRequestBody() []byte {
	rawBody := r.RawRequestBody
	return rawBody
}
// OCRPageImage represents an extracted image from an OCR page.
type OCRPageImage struct {
ID string `json:"id"`
// The four coordinates give the image's top-left and bottom-right corners.
// NOTE(review): units and origin are not visible here — confirm with the provider docs.
TopLeftX float64 `json:"top_left_x"`
TopLeftY float64 `json:"top_left_y"`
BottomRightX float64 `json:"bottom_right_x"`
BottomRightY float64 `json:"bottom_right_y"`
// ImageBase64 is optional and omitted from JSON when nil.
ImageBase64 *string `json:"image_base64,omitempty"`
}
// OCRPageDimensions represents the dimensions of an OCR page.
// All three fields are plain ints and are always serialized.
type OCRPageDimensions struct {
DPI int `json:"dpi"`
// NOTE(review): Height and Width are presumably in pixels at the given DPI — confirm.
Height int `json:"height"`
Width int `json:"width"`
}
// OCRPage represents a single processed page from an OCR response.
type OCRPage struct {
// NOTE(review): whether Index is zero- or one-based is not visible here — confirm.
Index int `json:"index"`
Markdown string `json:"markdown"`
// Images and Dimensions are optional and omitted from JSON when empty/nil.
Images []OCRPageImage `json:"images,omitempty"`
Dimensions *OCRPageDimensions `json:"dimensions,omitempty"`
}
// OCRUsageInfo represents usage information from an OCR response.
// It is referenced from BifrostOCRResponse.UsageInfo.
type OCRUsageInfo struct {
PagesProcessed int `json:"pages_processed"`
DocSizeBytes int `json:"doc_size_bytes"`
}
// BifrostOCRResponse represents the response from an OCR request.
type BifrostOCRResponse struct {
Model string `json:"model"`
// Pages has no omitempty, so a nil slice is serialized as JSON null.
Pages []OCRPage `json:"pages"`
// UsageInfo and DocumentAnnotation are optional and omitted from JSON when nil.
UsageInfo *OCRUsageInfo `json:"usage_info,omitempty"`
DocumentAnnotation *string `json:"document_annotation,omitempty"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}

634
core/schemas/orderedmap.go Normal file
View File

@@ -0,0 +1,634 @@
package schemas
import (
"bytes"
"encoding/json"
"fmt"
"sort"
)
// OrderedMap is a map that preserves insertion order of keys.
// It stores key-value pairs and maintains the order in which keys were first inserted.
// It is NOT safe for concurrent use.
type OrderedMap struct {
// keys lists each key once, in first-insertion order; Set appends to it
// only when the key is not yet present in values.
keys []string
// values holds the value for each key. MarshalJSON encodes a nil values map
// as JSON null.
values map[string]interface{}
}
// Pair is a key-value pair for constructing OrderedMaps with order preserved.
// Pairs are consumed by NewOrderedMapFromPairs and produced by the KV helper.
type Pair struct {
Key string
Value interface{}
}
// KV builds a Pair from a key and a value. It exists to keep calls to
// NewOrderedMapFromPairs short.
func KV(key string, value interface{}) Pair {
	var pair Pair
	pair.Key = key
	pair.Value = value
	return pair
}
// NewOrderedMap returns an empty OrderedMap whose value map is already
// allocated; the key slice starts out nil.
func NewOrderedMap() *OrderedMap {
	om := new(OrderedMap)
	om.values = map[string]interface{}{}
	return om
}
// NewOrderedMapWithCapacity returns an empty OrderedMap whose key slice and
// value map are both preallocated for cap entries.
func NewOrderedMapWithCapacity(cap int) *OrderedMap {
	om := new(OrderedMap)
	om.keys = make([]string, 0, cap)
	om.values = make(map[string]interface{}, cap)
	return om
}
// NewOrderedMapFromPairs builds an OrderedMap by calling Set for each pair in
// argument order. A repeated key therefore keeps its first position and takes
// the last value supplied.
func NewOrderedMapFromPairs(pairs ...Pair) *OrderedMap {
	count := len(pairs)
	result := &OrderedMap{
		keys:   make([]string, 0, count),
		values: make(map[string]interface{}, count),
	}
	for i := range pairs {
		result.Set(pairs[i].Key, pairs[i].Value)
	}
	return result
}
// OrderedMapFromMap copies a plain map into a new OrderedMap and returns nil
// for a nil input.
// The keys are appended in Go's map iteration order, which is unspecified, so
// the resulting key order is NOT guaranteed. Use this only when insertion
// order doesn't matter (e.g., for hashing).
func OrderedMapFromMap(m map[string]interface{}) *OrderedMap {
	if m == nil {
		return nil
	}
	size := len(m)
	result := &OrderedMap{
		keys:   make([]string, 0, size),
		values: make(map[string]interface{}, size),
	}
	for key := range m {
		result.keys = append(result.keys, key)
		result.values[key] = m[key]
	}
	return result
}
// Get looks up key and returns its value together with a flag reporting
// whether the key exists. A nil receiver returns (nil, false).
func (om *OrderedMap) Get(key string) (interface{}, bool) {
	if om != nil {
		value, found := om.values[key]
		return value, found
	}
	return nil, false
}
// Set stores value under key. A key seen for the first time is appended to the
// end of the key order; an existing key keeps its position and only has its
// value replaced. The value map is allocated on first use.
func (om *OrderedMap) Set(key string, value interface{}) {
	if om.values == nil {
		om.values = map[string]interface{}{}
	}
	_, known := om.values[key]
	om.values[key] = value
	if !known {
		om.keys = append(om.keys, key)
	}
}
// Delete removes key from both the value map and the ordered key list. It does
// nothing for a nil receiver or a key that is not present.
func (om *OrderedMap) Delete(key string) {
	if om == nil {
		return
	}
	if _, present := om.values[key]; !present {
		return
	}
	delete(om.values, key)
	for idx := range om.keys {
		if om.keys[idx] != key {
			continue
		}
		// Close the gap left by the first matching entry, then stop.
		om.keys = append(om.keys[:idx], om.keys[idx+1:]...)
		return
	}
}
// Len reports how many keys the map holds; a nil receiver has length 0.
func (om *OrderedMap) Len() int {
	if om != nil {
		return len(om.keys)
	}
	return 0
}
// Keys returns a copy of the keys in insertion order, so callers may modify
// the result without affecting the map. A nil receiver returns nil.
func (om *OrderedMap) Keys() []string {
	if om == nil {
		return nil
	}
	snapshot := make([]string, 0, len(om.keys))
	return append(snapshot, om.keys...)
}
// Range calls fn for each key-value pair in insertion order and stops as soon
// as fn returns false. A nil receiver results in no calls.
func (om *OrderedMap) Range(fn func(key string, value interface{}) bool) {
	if om == nil {
		return
	}
	// Take the slice header once so the iteration bounds are fixed up front.
	ordered := om.keys
	for i := 0; i < len(ordered); i++ {
		key := ordered[i]
		if keepGoing := fn(key, om.values[key]); !keepGoing {
			return
		}
	}
}
// Clone returns a shallow copy: the key slice and the top-level value map are
// new, but the values themselves are shared with the original. A nil receiver
// returns nil.
func (om *OrderedMap) Clone() *OrderedMap {
	if om == nil {
		return nil
	}
	dup := &OrderedMap{
		keys:   append(make([]string, 0, len(om.keys)), om.keys...),
		values: make(map[string]interface{}, len(om.values)),
	}
	for key := range om.values {
		dup.values[key] = om.values[key]
	}
	return dup
}
// ToMap copies the key-value pairs into a new plain map, which has no defined
// order. A nil receiver returns nil.
func (om *OrderedMap) ToMap() map[string]interface{} {
	if om == nil {
		return nil
	}
	plain := make(map[string]interface{}, len(om.values))
	for key := range om.values {
		plain[key] = om.values[key]
	}
	return plain
}
// MarshalJSON encodes the map as a JSON object whose members appear in
// insertion order. A map with a nil value store encodes as null.
// The value receiver is deliberate: it lets both OrderedMap and *OrderedMap
// use this method (critical for []OrderedMap slices like AnyOf/OneOf/AllOf in
// ToolFunctionParameters).
func (om OrderedMap) MarshalJSON() ([]byte, error) {
	if om.values == nil {
		return []byte("null"), nil
	}

	var out bytes.Buffer
	out.WriteByte('{')
	for pos := range om.keys {
		if pos != 0 {
			out.WriteByte(',')
		}
		name := om.keys[pos]
		encodedName, err := MarshalSorted(name)
		if err != nil {
			return nil, err
		}
		// Keys and values are both encoded by the package-level MarshalSorted.
		encodedValue, err := MarshalSorted(om.values[name])
		if err != nil {
			return nil, err
		}
		out.Write(encodedName)
		out.WriteByte(':')
		out.Write(encodedValue)
	}
	out.WriteByte('}')
	return out.Bytes(), nil
}
// MarshalSorted encodes the map as a JSON object with its keys in ascending
// string order, regardless of insertion order. Use it when the output must be
// deterministic (e.g., hashing). A nil receiver encodes as null.
func (om *OrderedMap) MarshalSorted() ([]byte, error) {
	if om == nil {
		return []byte("null"), nil
	}

	// Sort a copy so the map's own insertion order is left untouched.
	sorted := append(make([]string, 0, len(om.keys)), om.keys...)
	sort.Strings(sorted)

	var out bytes.Buffer
	out.WriteByte('{')
	for pos := range sorted {
		if pos != 0 {
			out.WriteByte(',')
		}
		name := sorted[pos]
		encodedName, err := MarshalSorted(name)
		if err != nil {
			return nil, err
		}
		encodedValue, err := MarshalSorted(om.values[name])
		if err != nil {
			return nil, err
		}
		out.Write(encodedName)
		out.WriteByte(':')
		out.Write(encodedValue)
	}
	out.WriteByte('}')
	return out.Bytes(), nil
}
// UnmarshalJSON replaces the map's contents with the members of a JSON object,
// keeping the key order of the document. JSON null resets the map to its zero
// state. Values are read through decodeOrderedValue.
// Note: this uses encoding/json.Decoder (not sonic) because token-by-token
// decoding is required to preserve key order from the JSON document.
func (om *OrderedMap) UnmarshalJSON(data []byte) error {
	if string(bytes.TrimSpace(data)) == "null" {
		om.keys, om.values = nil, nil
		return nil
	}

	decoder := json.NewDecoder(bytes.NewReader(data))

	// The document must open with '{'.
	opening, err := decoder.Token()
	if err != nil {
		return fmt.Errorf("orderedmap: expected '{': %w", err)
	}
	if d, isDelim := opening.(json.Delim); !isDelim || d != '{' {
		return fmt.Errorf("orderedmap: expected '{', got %v", opening)
	}

	// Drop any previous contents while reusing the existing storage.
	om.keys = om.keys[:0]
	if om.values != nil {
		for existing := range om.values {
			delete(om.values, existing)
		}
	} else {
		om.values = make(map[string]interface{})
	}

	for decoder.More() {
		rawKey, err := decoder.Token()
		if err != nil {
			return fmt.Errorf("orderedmap: reading key: %w", err)
		}
		key, isString := rawKey.(string)
		if !isString {
			return fmt.Errorf("orderedmap: expected string key, got %T", rawKey)
		}

		value, err := decodeOrderedValue(decoder)
		if err != nil {
			return fmt.Errorf("orderedmap: reading value for key %q: %w", key, err)
		}
		om.Set(key, value)
	}

	// Consume the closing '}'.
	if _, err := decoder.Token(); err != nil {
		return fmt.Errorf("orderedmap: expected '}': %w", err)
	}
	return nil
}
// jsonSchemaPriority maps JSON Schema keywords to their preferred
// serialization position. Keys present in this map are emitted first
// (in the given order), followed by all remaining keys alphabetically.
// This matches the optimal ordering for LLM tool schemas: the model
// sees type and description before properties, constraints, etc.
//
// A lower value sorts earlier. The table is read by the comparators in
// SortKeys, SortedCopy and SortedCopyPreservingProperties.
var jsonSchemaPriority = map[string]int{
"type": 0,
"description": 1,
"properties": 2,
"required": 3,
}
// SortKeys reorders this map's keys in place: keys listed in
// jsonSchemaPriority (type, description, properties, required) come first in
// priority order, then all other keys in ascending string order.
// Nested values are sorted too: *OrderedMap values recursively, plain
// map[string]interface{} values after being replaced by a sorted *OrderedMap,
// and []interface{} values through sortOrderedMapsInSlice.
func (om *OrderedMap) SortKeys() {
	if om == nil || len(om.keys) == 0 {
		return
	}

	sort.Slice(om.keys, func(a, b int) bool {
		left, right := om.keys[a], om.keys[b]
		leftRank, leftRanked := jsonSchemaPriority[left]
		rightRank, rightRanked := jsonSchemaPriority[right]
		if leftRanked && rightRanked {
			return leftRank < rightRank
		}
		if leftRanked != rightRanked {
			// Exactly one side is a priority keyword; that side goes first.
			return leftRanked
		}
		return left < right
	})

	for key, value := range om.values {
		if child, ok := value.(*OrderedMap); ok {
			child.SortKeys()
			continue
		}
		if plain, ok := value.(map[string]interface{}); ok {
			replacement := OrderedMapFromMap(plain)
			replacement.SortKeys()
			om.values[key] = replacement
			continue
		}
		if list, ok := value.([]interface{}); ok {
			sortOrderedMapsInSlice(list)
		}
	}
}
// sortOrderedMapsInSlice applies SortKeys to every map found in s, in place.
// *OrderedMap elements are sorted directly, plain map[string]interface{}
// elements are replaced by a sorted *OrderedMap, and nested []interface{}
// elements are processed recursively. Other elements are left alone.
func sortOrderedMapsInSlice(s []interface{}) {
	for idx := range s {
		if child, ok := s[idx].(*OrderedMap); ok {
			child.SortKeys()
			continue
		}
		if plain, ok := s[idx].(map[string]interface{}); ok {
			replacement := OrderedMapFromMap(plain)
			replacement.SortKeys()
			s[idx] = replacement
			continue
		}
		if list, ok := s[idx].([]interface{}); ok {
			sortOrderedMapsInSlice(list)
		}
	}
}
// SortedCopy returns a new OrderedMap whose keys follow JSON Schema priority
// ordering (priority keywords first, the rest in ascending string order),
// leaving the receiver untouched. Nested *OrderedMap, map[string]interface{}
// and []interface{} values are copied and sorted recursively; every other
// value is shared with the original, not cloned.
// This is much cheaper than a full JSON marshal/unmarshal Clone because it
// only allocates new key slices and value maps.
func (om *OrderedMap) SortedCopy() *OrderedMap {
	if om == nil {
		return nil
	}
	if len(om.keys) == 0 {
		return &OrderedMap{values: make(map[string]interface{})}
	}

	orderedKeys := append(make([]string, 0, len(om.keys)), om.keys...)
	sort.Slice(orderedKeys, func(a, b int) bool {
		left, right := orderedKeys[a], orderedKeys[b]
		leftRank, leftRanked := jsonSchemaPriority[left]
		rightRank, rightRanked := jsonSchemaPriority[right]
		if leftRanked && rightRanked {
			return leftRank < rightRank
		}
		if leftRanked != rightRanked {
			// Exactly one side is a priority keyword; that side goes first.
			return leftRanked
		}
		return left < right
	})

	copied := make(map[string]interface{}, len(om.values))
	for key, value := range om.values {
		if child, ok := value.(*OrderedMap); ok {
			copied[key] = child.SortedCopy()
		} else if plain, ok := value.(map[string]interface{}); ok {
			copied[key] = OrderedMapFromMap(plain).SortedCopy()
		} else if list, ok := value.([]interface{}); ok {
			copied[key] = sortedCopySlice(list)
		} else {
			copied[key] = value
		}
	}
	return &OrderedMap{keys: orderedKeys, values: copied}
}
// sortedCopySlice returns a new slice of the same length as s. Map elements
// (*OrderedMap or map[string]interface{}) are replaced by their SortedCopy,
// nested []interface{} elements are copied recursively, and all other elements
// are carried over unchanged.
func sortedCopySlice(s []interface{}) []interface{} {
	result := make([]interface{}, len(s))
	for idx := range s {
		if child, ok := s[idx].(*OrderedMap); ok {
			result[idx] = child.SortedCopy()
		} else if plain, ok := s[idx].(map[string]interface{}); ok {
			result[idx] = OrderedMapFromMap(plain).SortedCopy()
		} else if list, ok := s[idx].([]interface{}); ok {
			result[idx] = sortedCopySlice(list)
		} else {
			result[idx] = s[idx]
		}
	}
	return result
}
// SortedCopyPreservingProperties is like SortedCopy, except that the value
// stored under the key "properties" keeps the key order of its user-defined
// property names.
//
// The keys of this map itself are still sorted: JSON Schema priority keywords
// (type, description, properties, required) first, all others in ascending
// string order. The "properties" value is handed to
// preserveKeysOrderedCopyWithAwareness; every other nested map or slice is
// copied recursively with the same property-aware logic.
//
// This keeps structural keys in a deterministic order for prompt caching while
// preserving the client's intended field generation order for LLM structured
// output.
func (om *OrderedMap) SortedCopyPreservingProperties() *OrderedMap {
	if om == nil {
		return nil
	}
	if len(om.keys) == 0 {
		return &OrderedMap{values: make(map[string]interface{})}
	}

	orderedKeys := append(make([]string, 0, len(om.keys)), om.keys...)
	sort.Slice(orderedKeys, func(a, b int) bool {
		left, right := orderedKeys[a], orderedKeys[b]
		leftRank, leftRanked := jsonSchemaPriority[left]
		rightRank, rightRanked := jsonSchemaPriority[right]
		if leftRanked && rightRanked {
			return leftRank < rightRank
		}
		if leftRanked != rightRanked {
			// Exactly one side is a priority keyword; that side goes first.
			return leftRanked
		}
		return left < right
	})

	copied := make(map[string]interface{}, len(om.values))
	for key, value := range om.values {
		if key == "properties" {
			// User-defined property names: preserve key order, sort nested schemas
			copied[key] = preserveKeysOrderedCopyWithAwareness(value)
			continue
		}
		switch typed := value.(type) {
		case *OrderedMap:
			copied[key] = typed.SortedCopyPreservingProperties()
		case map[string]interface{}:
			copied[key] = OrderedMapFromMap(typed).SortedCopyPreservingProperties()
		case []interface{}:
			copied[key] = sortedCopySlicePreservingProperties(typed)
		default:
			copied[key] = value
		}
	}
	return &OrderedMap{keys: orderedKeys, values: copied}
}
// preserveKeysOrderedCopyWithAwareness copies the value stored under a
// "properties" key.
//
//   - *OrderedMap: the top-level key order (user-defined property names) is
//     kept and each value is processed with SortedCopyPreservingProperties.
//   - map[string]interface{}: Go maps have no stable iteration order, so there
//     is nothing to preserve; the map is converted with OrderedMapFromMap and
//     sorted via SortedCopyPreservingProperties.
//   - anything else: returned unchanged.
func preserveKeysOrderedCopyWithAwareness(v interface{}) interface{} {
	if ordered, ok := v.(*OrderedMap); ok {
		return ordered.preserveKeysWithPropertyAwareness()
	}
	if plain, ok := v.(map[string]interface{}); ok {
		return OrderedMapFromMap(plain).SortedCopyPreservingProperties()
	}
	return v
}
// preserveKeysWithPropertyAwareness copies the receiver without reordering its
// top-level keys. Nested *OrderedMap and map[string]interface{} values are
// copied with SortedCopyPreservingProperties, []interface{} values with
// sortedCopySlicePreservingProperties, and all other values are shared as-is.
//
// A nil receiver yields nil; a receiver without keys yields an empty,
// non-nil OrderedMap.
func (om *OrderedMap) preserveKeysWithPropertyAwareness() *OrderedMap {
	if om == nil {
		return nil
	}
	if len(om.keys) == 0 {
		return &OrderedMap{values: make(map[string]interface{})}
	}

	// Same order as the receiver: these are user-defined names, not sorted.
	sameOrder := make([]string, len(om.keys))
	copy(sameOrder, om.keys)

	copied := make(map[string]interface{}, len(om.values))
	for key, raw := range om.values {
		var out interface{}
		switch val := raw.(type) {
		case *OrderedMap:
			out = val.SortedCopyPreservingProperties()
		case map[string]interface{}:
			out = OrderedMapFromMap(val).SortedCopyPreservingProperties()
		case []interface{}:
			out = sortedCopySlicePreservingProperties(val)
		default:
			out = raw
		}
		copied[key] = out
	}

	return &OrderedMap{keys: sameOrder, values: copied}
}
// sortedCopySlicePreservingProperties returns a new slice with the same length
// as s. Elements that are *OrderedMap or map[string]interface{} are replaced
// by their SortedCopyPreservingProperties result, nested []interface{}
// elements are processed recursively, and all other elements are carried
// over as-is.
func sortedCopySlicePreservingProperties(s []interface{}) []interface{} {
	result := make([]interface{}, len(s))
	for idx := range s {
		switch elem := s[idx].(type) {
		case *OrderedMap:
			result[idx] = elem.SortedCopyPreservingProperties()
		case map[string]interface{}:
			result[idx] = OrderedMapFromMap(elem).SortedCopyPreservingProperties()
		case []interface{}:
			result[idx] = sortedCopySlicePreservingProperties(elem)
		default:
			result[idx] = s[idx]
		}
	}
	return result
}
// decodeOrderedValue consumes exactly one JSON value from dec and returns it.
//
//   - Objects become *OrderedMap, with keys inserted in the order they are read.
//   - Arrays become []interface{} (never nil, even when empty), with every
//     element decoded recursively.
//   - Strings, float64 numbers, booleans and null are returned as delivered by
//     the decoder.
//   - json.Number tokens are converted to float64; if that conversion fails,
//     the number's string form is returned instead.
//
// Any error from the decoder is returned unchanged. A closing delimiter seen
// where a value is expected, or a non-string object key, yields an error.
func decodeOrderedValue(dec *json.Decoder) (interface{}, error) {
	tok, err := dec.Token()
	if err != nil {
		return nil, err
	}

	switch val := tok.(type) {
	case json.Delim:
		switch val {
		case '{':
			obj := NewOrderedMap()
			for dec.More() {
				rawKey, err := dec.Token()
				if err != nil {
					return nil, err
				}
				name, isString := rawKey.(string)
				if !isString {
					return nil, fmt.Errorf("expected string key, got %T", rawKey)
				}
				member, err := decodeOrderedValue(dec)
				if err != nil {
					return nil, err
				}
				obj.Set(name, member)
			}
			// Consume the closing '}'.
			if _, err := dec.Token(); err != nil {
				return nil, err
			}
			return obj, nil

		case '[':
			// Start non-nil so an empty JSON array is returned as an empty slice.
			elems := []interface{}{}
			for dec.More() {
				elem, err := decodeOrderedValue(dec)
				if err != nil {
					return nil, err
				}
				elems = append(elems, elem)
			}
			// Consume the closing ']'.
			if _, err := dec.Token(); err != nil {
				return nil, err
			}
			return elems, nil

		default:
			return nil, fmt.Errorf("unexpected delimiter: %v", val)
		}

	case json.Number:
		if asFloat, convErr := val.Float64(); convErr == nil {
			return asFloat, nil
		}
		return val.String(), nil

	case nil:
		return nil, nil

	case string, float64, bool:
		return val, nil

	default:
		return val, nil
	}
}

View File

@@ -0,0 +1,619 @@
package schemas
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNewOrderedMap checks that NewOrderedMap returns a non-nil map that
// reports zero length and no keys.
func TestNewOrderedMap(t *testing.T) {
	fresh := NewOrderedMap()

	assert.NotNil(t, fresh)
	assert.Equal(t, 0, fresh.Len())
	assert.Empty(t, fresh.Keys())
}
// TestNewOrderedMapFromPairs checks that pairs are stored in the order given
// and that their values can be read back.
func TestNewOrderedMapFromPairs(t *testing.T) {
	built := NewOrderedMapFromPairs(KV("b", 2), KV("a", 1), KV("c", 3))

	assert.Equal(t, 3, built.Len())
	assert.Equal(t, []string{"b", "a", "c"}, built.Keys())

	got, found := built.Get("a")
	assert.True(t, found)
	assert.Equal(t, 1, got)
}
// TestOrderedMap_SetPreservesInsertionOrder checks that Keys reports keys in
// the order they were first Set, not in sorted order.
func TestOrderedMap_SetPreservesInsertionOrder(t *testing.T) {
	insertion := []string{"z", "a", "m"}

	target := NewOrderedMap()
	for pos, key := range insertion {
		target.Set(key, pos+1)
	}

	assert.Equal(t, []string{"z", "a", "m"}, target.Keys())
}
// TestOrderedMap_SetUpdateInPlace checks that setting an existing key replaces
// its value without moving the key to the end.
func TestOrderedMap_SetUpdateInPlace(t *testing.T) {
	target := NewOrderedMap()
	target.Set("a", 1)
	target.Set("b", 2)
	target.Set("a", 10) // overwrite: position of "a" must not change

	assert.Equal(t, []string{"a", "b"}, target.Keys())

	got, found := target.Get("a")
	assert.True(t, found)
	assert.Equal(t, 10, got)
}
// TestOrderedMap_Delete checks that deleting a key removes both the key and
// its value while keeping the order of the remaining keys.
func TestOrderedMap_Delete(t *testing.T) {
	target := NewOrderedMapFromPairs(KV("a", 1), KV("b", 2), KV("c", 3))

	target.Delete("b")

	assert.Equal(t, 2, target.Len())
	assert.Equal(t, []string{"a", "c"}, target.Keys())

	_, stillThere := target.Get("b")
	assert.False(t, stillThere)
}
// TestOrderedMap_DeleteNonExistent checks that deleting an absent key neither
// panics nor changes the map.
func TestOrderedMap_DeleteNonExistent(t *testing.T) {
	target := NewOrderedMapFromPairs(KV("a", 1))

	target.Delete("b") // must be a no-op

	assert.Equal(t, 1, target.Len())
}
// TestOrderedMap_Range checks that Range visits every entry in insertion
// order when the callback keeps returning true.
func TestOrderedMap_Range(t *testing.T) {
	target := NewOrderedMapFromPairs(KV("x", 1), KV("y", 2), KV("z", 3))

	var (
		seenKeys []string
		seenVals []interface{}
	)
	target.Range(func(k string, v interface{}) bool {
		seenKeys = append(seenKeys, k)
		seenVals = append(seenVals, v)
		return true
	})

	assert.Equal(t, []string{"x", "y", "z"}, seenKeys)
	assert.Equal(t, []interface{}{1, 2, 3}, seenVals)
}
// TestOrderedMap_RangeEarlyStop checks that Range stops iterating as soon as
// the callback returns false.
func TestOrderedMap_RangeEarlyStop(t *testing.T) {
	target := NewOrderedMapFromPairs(KV("a", 1), KV("b", 2), KV("c", 3))

	var visited []string
	target.Range(func(k string, _ interface{}) bool {
		visited = append(visited, k)
		return k != "b" // "b" is the last entry that should be visited
	})

	assert.Equal(t, []string{"a", "b"}, visited)
}
// TestOrderedMap_Clone checks that a clone starts with the same keys and that
// later changes to the clone do not leak into the source map.
func TestOrderedMap_Clone(t *testing.T) {
	source := NewOrderedMapFromPairs(KV("a", 1), KV("b", 2))

	duplicate := source.Clone()
	assert.Equal(t, source.Keys(), duplicate.Keys())

	// Growing the clone must leave the source untouched.
	duplicate.Set("c", 3)
	assert.Equal(t, 2, source.Len())
	assert.Equal(t, 3, duplicate.Len())
}
// TestOrderedMap_ToMap checks that ToMap exposes the entries as a plain Go map.
func TestOrderedMap_ToMap(t *testing.T) {
	source := NewOrderedMapFromPairs(KV("a", 1), KV("b", "hello"))

	want := map[string]interface{}{"a": 1, "b": "hello"}
	plain := source.ToMap()

	assert.Equal(t, want, plain)
}
// TestOrderedMap_NilSafety checks that the read, iterate and delete methods
// are safe to call on a nil *OrderedMap.
func TestOrderedMap_NilSafety(t *testing.T) {
	var nilMap *OrderedMap

	assert.Equal(t, 0, nilMap.Len())
	assert.Nil(t, nilMap.Keys())
	assert.Nil(t, nilMap.Clone())
	assert.Nil(t, nilMap.ToMap())

	got, found := nilMap.Get("key")
	assert.Nil(t, got)
	assert.False(t, found)

	// Range on a nil map must not panic and must never invoke the callback.
	nilMap.Range(func(string, interface{}) bool {
		t.Fatal("should not be called")
		return true
	})

	// Delete on a nil map must not panic.
	nilMap.Delete("key")
}
// TestOrderedMap_MarshalJSON_PreservesOrder checks that JSON output lists keys
// in insertion order rather than alphabetically.
func TestOrderedMap_MarshalJSON_PreservesOrder(t *testing.T) {
	source := NewOrderedMapFromPairs(KV("z_last", 1), KV("a_first", 2), KV("m_middle", 3))

	encoded, err := json.Marshal(source)
	require.NoError(t, err)

	assert.Equal(t, `{"z_last":1,"a_first":2,"m_middle":3}`, string(encoded))
}
// TestOrderedMap_MarshalJSON_Empty checks that a freshly created map
// serializes as an empty JSON object.
func TestOrderedMap_MarshalJSON_Empty(t *testing.T) {
	encoded, err := json.Marshal(NewOrderedMap())
	require.NoError(t, err)

	assert.Equal(t, `{}`, string(encoded))
}
// TestOrderedMap_MarshalJSON_NilValues checks that the zero-value OrderedMap
// (whose internal values map is nil) serializes as JSON null.
func TestOrderedMap_MarshalJSON_NilValues(t *testing.T) {
	var zero OrderedMap // zero value: values map is nil

	encoded, err := json.Marshal(zero)
	require.NoError(t, err)

	assert.Equal(t, `null`, string(encoded))
}
// TestOrderedMap_UnmarshalJSON_PreservesOrder checks that decoding keeps the
// key order of the input document and yields the expected Go value types.
func TestOrderedMap_UnmarshalJSON_PreservesOrder(t *testing.T) {
	const doc = `{"z_last":1,"a_first":"two","m_middle":true}`

	var decoded OrderedMap
	require.NoError(t, json.Unmarshal([]byte(doc), &decoded))

	assert.Equal(t, []string{"z_last", "a_first", "m_middle"}, decoded.Keys())

	expectations := []struct {
		key  string
		want interface{}
	}{
		{"z_last", float64(1)}, // JSON numbers decode as float64
		{"a_first", "two"},
		{"m_middle", true},
	}
	for _, exp := range expectations {
		got, found := decoded.Get(exp.key)
		assert.True(t, found)
		assert.Equal(t, exp.want, got)
	}
}
// TestOrderedMap_UnmarshalJSON_NestedObjects checks that nested JSON objects
// decode to *OrderedMap and keep their own key order.
func TestOrderedMap_UnmarshalJSON_NestedObjects(t *testing.T) {
	const doc = `{"outer_b":{"inner_z":1,"inner_a":2},"outer_a":"simple"}`

	var decoded OrderedMap
	require.NoError(t, json.Unmarshal([]byte(doc), &decoded))

	assert.Equal(t, []string{"outer_b", "outer_a"}, decoded.Keys())

	inner, found := decoded.Get("outer_b")
	assert.True(t, found)

	innerOM, isOM := inner.(*OrderedMap)
	require.True(t, isOM, "nested object should be *OrderedMap, got %T", inner)
	assert.Equal(t, []string{"inner_z", "inner_a"}, innerOM.Keys())
}
// TestOrderedMap_UnmarshalJSON_Null checks that decoding JSON null succeeds
// and leaves the map empty.
func TestOrderedMap_UnmarshalJSON_Null(t *testing.T) {
	var decoded OrderedMap

	require.NoError(t, json.Unmarshal([]byte("null"), &decoded))

	assert.Equal(t, 0, decoded.Len())
}
// TestOrderedMap_JSONRoundTrip checks that marshalling and then unmarshalling
// a map keeps its top-level key order.
func TestOrderedMap_JSONRoundTrip(t *testing.T) {
	answer := map[string]interface{}{
		"type":        "string",
		"description": "The answer to the question",
	}
	chainOfThought := map[string]interface{}{
		"type":        "string",
		"description": "Reasoning chain",
	}
	citations := map[string]interface{}{
		"type":        "array",
		"description": "Sources",
	}
	source := NewOrderedMapFromPairs(
		KV("answer", answer),
		KV("chain_of_thought", chainOfThought),
		KV("citations", citations),
	)

	encoded, err := json.Marshal(source)
	require.NoError(t, err)

	var decoded OrderedMap
	err = json.Unmarshal(encoded, &decoded)
	require.NoError(t, err)

	assert.Equal(t, source.Keys(), decoded.Keys())
}
// TestOrderedMap_MarshalSorted checks that MarshalSorted emits keys in sorted
// order regardless of insertion order.
func TestOrderedMap_MarshalSorted(t *testing.T) {
	source := NewOrderedMapFromPairs(KV("z", 1), KV("a", 2), KV("m", 3))

	encoded, err := source.MarshalSorted()
	require.NoError(t, err)

	assert.Equal(t, `{"a":2,"m":3,"z":1}`, string(encoded))
}
// TestOrderedMap_MarshalSorted_Nil checks that MarshalSorted on a nil map
// returns JSON null without an error.
func TestOrderedMap_MarshalSorted_Nil(t *testing.T) {
	var nilMap *OrderedMap

	encoded, err := nilMap.MarshalSorted()
	require.NoError(t, err)

	assert.Equal(t, `null`, string(encoded))
}
// TestOrderedMapFromMap checks that a plain Go map is converted with all of
// its entries available.
func TestOrderedMapFromMap(t *testing.T) {
	plain := map[string]interface{}{"a": 1, "b": 2}

	converted := OrderedMapFromMap(plain)
	assert.Equal(t, 2, converted.Len())

	got, found := converted.Get("a")
	assert.True(t, found)
	assert.Equal(t, 1, got)
}
// TestOrderedMapFromMap_Nil checks that converting a nil map yields nil.
func TestOrderedMapFromMap_Nil(t *testing.T) {
	converted := OrderedMapFromMap(nil)

	assert.Nil(t, converted)
}
// TestOrderedMap_NestedOrderedMapMarshal checks that an OrderedMap nested in
// another OrderedMap keeps its own key order when marshalled.
func TestOrderedMap_NestedOrderedMapMarshal(t *testing.T) {
	child := NewOrderedMapFromPairs(KV("z_prop", "last"), KV("a_prop", "first"))
	parent := NewOrderedMapFromPairs(KV("properties", child), KV("type", "object"))

	encoded, err := json.Marshal(parent)
	require.NoError(t, err)

	assert.Equal(t, `{"properties":{"z_prop":"last","a_prop":"first"},"type":"object"}`, string(encoded))
}
// TestOrderedMap_UnmarshalThenMarshalPreservesOrder covers the core use case:
// a JSON document decoded into an OrderedMap and encoded again must come out
// byte-for-byte identical, i.e. with its key order intact.
func TestOrderedMap_UnmarshalThenMarshalPreservesOrder(t *testing.T) {
	const doc = `{"answer":"string","chain_of_thought":"string","citations":"array","is_unanswered":"boolean"}`

	var decoded OrderedMap
	require.NoError(t, json.Unmarshal([]byte(doc), &decoded))

	reencoded, err := json.Marshal(decoded)
	require.NoError(t, err)

	assert.Equal(t, doc, string(reencoded))
}
// TestOrderedMap_SortKeys_PlainMapValues checks that SortKeys sorts in place,
// converts nested plain maps to *OrderedMap at every depth, and sorts their
// keys as well.
func TestOrderedMap_SortKeys_PlainMapValues(t *testing.T) {
	plainProps := map[string]interface{}{
		"z_field": map[string]interface{}{
			"type":        "string",
			"description": "last field",
		},
		"a_field": map[string]interface{}{
			"type":        "number",
			"description": "first field",
		},
	}
	schema := NewOrderedMapFromPairs(
		KV("properties", plainProps),
		KV("type", "object"),
	)

	schema.SortKeys()

	assert.Equal(t, []string{"type", "properties"}, schema.Keys(), "top-level keys should be schema-sorted")

	rawProps, found := schema.Get("properties")
	require.True(t, found)
	sortedProps, isOM := rawProps.(*OrderedMap)
	require.True(t, isOM, "plain map should be converted to *OrderedMap, got %T", rawProps)
	assert.Equal(t, []string{"a_field", "z_field"}, sortedProps.Keys(), "nested plain map keys should be sorted")

	rawField, found := sortedProps.Get("z_field")
	require.True(t, found)
	sortedField, isOM := rawField.(*OrderedMap)
	require.True(t, isOM, "deeply nested plain map should be converted to *OrderedMap, got %T", rawField)
	assert.Equal(t, []string{"type", "description"}, sortedField.Keys(), "deeply nested keys should be schema-sorted")
}
// TestOrderedMap_SortKeys_PlainMapInSlice checks that SortKeys also converts
// and sorts plain maps that appear as slice elements.
func TestOrderedMap_SortKeys_PlainMapInSlice(t *testing.T) {
	option := map[string]interface{}{
		"description": "first option",
		"type":        "string",
	}
	schema := NewOrderedMapFromPairs(KV("anyOf", []interface{}{option}))

	schema.SortKeys()

	rawAnyOf, found := schema.Get("anyOf")
	require.True(t, found)

	options, isSlice := rawAnyOf.([]interface{})
	require.True(t, isSlice)
	require.Len(t, options, 1)

	first, isOM := options[0].(*OrderedMap)
	require.True(t, isOM, "plain map in slice should be converted to *OrderedMap, got %T", options[0])
	assert.Equal(t, []string{"type", "description"}, first.Keys())
}
// TestOrderedMap_SortedCopy_PlainMapValues checks that SortedCopy converts and
// sorts nested plain maps in the copy while leaving the source untouched.
func TestOrderedMap_SortedCopy_PlainMapValues(t *testing.T) {
	plainProps := map[string]interface{}{
		"z_field": "string",
		"a_field": "number",
	}
	source := NewOrderedMapFromPairs(
		KV("properties", plainProps),
		KV("type", "object"),
	)

	copied := source.SortedCopy()
	assert.Equal(t, []string{"type", "properties"}, copied.Keys())

	rawProps, found := copied.Get("properties")
	require.True(t, found)
	sortedProps, isOM := rawProps.(*OrderedMap)
	require.True(t, isOM, "plain map should be converted to *OrderedMap in SortedCopy, got %T", rawProps)
	assert.Equal(t, []string{"a_field", "z_field"}, sortedProps.Keys())

	sourceProps, _ := source.Get("properties")
	_, stillPlain := sourceProps.(map[string]interface{})
	assert.True(t, stillPlain, "original should be unmodified (still a plain map)")
}
// --- SortedCopyPreservingProperties tests ---
// TestOrderedMap_SortedCopyPreservingProperties_Basic uses a schema whose
// structural keys are out of priority order and whose user-defined properties
// are not alphabetical. It checks that structural keys get sorted, property
// names keep their order, and the source map is not modified.
func TestOrderedMap_SortedCopyPreservingProperties_Basic(t *testing.T) {
	userProps := NewOrderedMapFromPairs(
		KV("chain_of_thought", NewOrderedMapFromPairs(
			KV("description", "Reasoning steps"),
			KV("type", "string"),
		)),
		KV("answer", NewOrderedMapFromPairs(
			KV("description", "The answer"),
			KV("type", "string"),
		)),
		KV("citations", NewOrderedMapFromPairs(
			KV("type", "array"),
		)),
	)
	schema := NewOrderedMapFromPairs(
		KV("required", []interface{}{"chain_of_thought"}),
		KV("properties", userProps),
		KV("type", "object"),
	)

	copied := schema.SortedCopyPreservingProperties()

	// Caching: structural keys are sorted by JSON Schema priority.
	assert.Equal(t, []string{"type", "properties", "required"}, copied.Keys())

	// CoT: user-defined property names stay in their original order.
	rawProps, found := copied.Get("properties")
	require.True(t, found)
	copiedProps := rawProps.(*OrderedMap)
	assert.Equal(t, []string{"chain_of_thought", "answer", "citations"}, copiedProps.Keys())

	// Caching: structural keys inside each property are sorted too.
	rawCoT, _ := copiedProps.Get("chain_of_thought")
	copiedCoT := rawCoT.(*OrderedMap)
	assert.Equal(t, []string{"type", "description"}, copiedCoT.Keys(), "structural keys within property should be sorted")

	// Immutability: the source keeps its original ordering.
	assert.Equal(t, []string{"required", "properties", "type"}, schema.Keys())
	rawSourceProps, _ := schema.Get("properties")
	sourceProps := rawSourceProps.(*OrderedMap)
	assert.Equal(t, []string{"chain_of_thought", "answer", "citations"}, sourceProps.Keys())
}
// TestOrderedMap_SortedCopyPreservingProperties_NestedObjects checks that
// property-name order is preserved both at the top level and inside a
// property that is itself an object with its own "properties".
func TestOrderedMap_SortedCopyPreservingProperties_NestedObjects(t *testing.T) {
	addressProps := NewOrderedMapFromPairs(
		KV("street", NewOrderedMapFromPairs(KV("type", "string"))),
		KV("city", NewOrderedMapFromPairs(KV("type", "string"))),
		KV("zip", NewOrderedMapFromPairs(KV("type", "string"))),
	)
	schema := NewOrderedMapFromPairs(
		KV("type", "object"),
		KV("properties", NewOrderedMapFromPairs(
			KV("reasoning", NewOrderedMapFromPairs(
				KV("type", "string"),
			)),
			KV("address", NewOrderedMapFromPairs(
				KV("type", "object"),
				KV("properties", addressProps),
			)),
			KV("answer", NewOrderedMapFromPairs(
				KV("type", "string"),
			)),
		)),
	)

	copied := schema.SortedCopyPreservingProperties()

	// CoT: outer property names preserved.
	rawOuter, _ := copied.Get("properties")
	outerProps := rawOuter.(*OrderedMap)
	assert.Equal(t, []string{"reasoning", "address", "answer"}, outerProps.Keys())

	// CoT: inner nested property names preserved.
	rawAddress, _ := outerProps.Get("address")
	address := rawAddress.(*OrderedMap)
	rawInner, _ := address.Get("properties")
	innerProps := rawInner.(*OrderedMap)
	assert.Equal(t, []string{"street", "city", "zip"}, innerProps.Keys())
}
// TestOrderedMap_SortedCopyPreservingProperties_ThreeLevelNesting checks that
// property-name order survives across three levels of nested "properties".
func TestOrderedMap_SortedCopyPreservingProperties_ThreeLevelNesting(t *testing.T) {
	departmentProps := NewOrderedMapFromPairs(
		KV("team_lead", NewOrderedMapFromPairs(KV("type", "string"))),
		KV("team_name", NewOrderedMapFromPairs(KV("type", "string"))),
	)
	organizationProps := NewOrderedMapFromPairs(
		KV("department", NewOrderedMapFromPairs(
			KV("type", "object"),
			KV("properties", departmentProps),
		)),
		KV("budget", NewOrderedMapFromPairs(KV("type", "number"))),
	)
	schema := NewOrderedMapFromPairs(
		KV("type", "object"),
		KV("properties", NewOrderedMapFromPairs(
			KV("organization", NewOrderedMapFromPairs(
				KV("type", "object"),
				KV("properties", organizationProps),
			)),
			KV("summary", NewOrderedMapFromPairs(KV("type", "string"))),
		)),
	)

	copied := schema.SortedCopyPreservingProperties()

	// Level 1: organization, summary
	rawLevel1, _ := copied.Get("properties")
	level1 := rawLevel1.(*OrderedMap)
	assert.Equal(t, []string{"organization", "summary"}, level1.Keys())

	// Level 2: department, budget
	organization, _ := level1.Get("organization")
	rawLevel2, _ := organization.(*OrderedMap).Get("properties")
	level2 := rawLevel2.(*OrderedMap)
	assert.Equal(t, []string{"department", "budget"}, level2.Keys())

	// Level 3: team_lead, team_name
	department, _ := level2.Get("department")
	rawLevel3, _ := department.(*OrderedMap).Get("properties")
	level3 := rawLevel3.(*OrderedMap)
	assert.Equal(t, []string{"team_lead", "team_name"}, level3.Keys())
}
// TestOrderedMap_SortedCopyPreservingProperties_WithDefs checks that the
// definition names under "$defs" are sorted (for caching) while the property
// names inside each definition keep their original order.
func TestOrderedMap_SortedCopyPreservingProperties_WithDefs(t *testing.T) {
	metadataDef := NewOrderedMapFromPairs(
		KV("type", "object"),
		KV("properties", NewOrderedMapFromPairs(
			KV("latency_ms", NewOrderedMapFromPairs(KV("type", "number"))),
			KV("model_version", NewOrderedMapFromPairs(KV("type", "string"))),
		)),
	)
	citationDef := NewOrderedMapFromPairs(
		KV("type", "object"),
		KV("properties", NewOrderedMapFromPairs(
			KV("url", NewOrderedMapFromPairs(KV("type", "string"))),
			KV("text", NewOrderedMapFromPairs(KV("type", "string"))),
		)),
	)
	schema := NewOrderedMapFromPairs(
		KV("$defs", NewOrderedMapFromPairs(
			KV("Metadata", metadataDef),
			KV("Citation", citationDef),
		)),
		KV("type", "object"),
		KV("properties", NewOrderedMapFromPairs(
			KV("answer", NewOrderedMapFromPairs(KV("type", "string"))),
		)),
	)

	copied := schema.SortedCopyPreservingProperties()

	// Caching: $defs definition names are sorted alphabetically.
	rawDefs, _ := copied.Get("$defs")
	copiedDefs := rawDefs.(*OrderedMap)
	assert.Equal(t, []string{"Citation", "Metadata"}, copiedDefs.Keys())

	// CoT: properties within each $def are preserved.
	metadata, _ := copiedDefs.Get("Metadata")
	rawMetadataProps, _ := metadata.(*OrderedMap).Get("properties")
	metadataProps := rawMetadataProps.(*OrderedMap)
	assert.Equal(t, []string{"latency_ms", "model_version"}, metadataProps.Keys())

	citation, _ := copiedDefs.Get("Citation")
	rawCitationProps, _ := citation.(*OrderedMap).Get("properties")
	citationProps := rawCitationProps.(*OrderedMap)
	assert.Equal(t, []string{"url", "text"}, citationProps.Keys())
}
// TestOrderedMap_SortedCopyPreservingProperties_NilAndEmpty covers the edge
// cases: a nil receiver, an empty map, and a map without a "properties" key.
func TestOrderedMap_SortedCopyPreservingProperties_NilAndEmpty(t *testing.T) {
	// A nil receiver yields nil.
	var nilMap *OrderedMap
	assert.Nil(t, nilMap.SortedCopyPreservingProperties())

	// An empty map yields an empty, non-nil copy.
	emptyMap := &OrderedMap{keys: []string{}, values: make(map[string]interface{})}
	copied := emptyMap.SortedCopyPreservingProperties()
	assert.NotNil(t, copied)
	assert.Equal(t, 0, copied.Len())

	// Without a "properties" key the result matches SortedCopy ordering.
	withoutProps := NewOrderedMapFromPairs(
		KV("required", []interface{}{"a"}),
		KV("type", "object"),
		KV("description", "test"),
	)
	copied = withoutProps.SortedCopyPreservingProperties()
	assert.Equal(t, []string{"type", "description", "required"}, copied.Keys())
}
// TestOrderedMap_SortedCopyPreservingProperties_PlainMapInsideProperties checks
// that a plain map stored under "properties" is converted to *OrderedMap in
// the copy.
func TestOrderedMap_SortedCopyPreservingProperties_PlainMapInsideProperties(t *testing.T) {
	plainProps := map[string]interface{}{
		"field_a": map[string]interface{}{"description": "first", "type": "string"},
	}
	schema := NewOrderedMapFromPairs(
		KV("type", "object"),
		KV("properties", plainProps),
	)

	copied := schema.SortedCopyPreservingProperties()

	// The "properties" value must come back as *OrderedMap.
	rawProps, found := copied.Get("properties")
	require.True(t, found)
	_, convertedToOM := rawProps.(*OrderedMap)
	assert.True(t, convertedToOM, "plain map should be converted to *OrderedMap")
}
// TestOrderedMap_EmptyArray checks that an empty JSON array decodes to an
// empty (non-nil) slice and re-encodes as "[]" rather than "null".
func TestOrderedMap_EmptyArray(t *testing.T) {
	const doc = `{"items":[],"name":"test"}`

	var decoded OrderedMap
	require.NoError(t, json.Unmarshal([]byte(doc), &decoded))

	items, found := decoded.Get("items")
	assert.True(t, found)
	assert.Equal(t, []interface{}{}, items)

	reencoded, err := json.Marshal(decoded)
	require.NoError(t, err)
	assert.Equal(t, doc, string(reencoded))
}

View File

@@ -0,0 +1,63 @@
package schemas
import (
"encoding/base64"
"encoding/json"
"fmt"
)
// SerialCursor tracks pagination state for serial key exhaustion.
// When paginating across multiple keys, we exhaust all pages from one key
// before moving to the next, ensuring only one API call per pagination request.
//
// The JSON field names are deliberately single letters; the struct is
// serialized to JSON and base64-encoded by EncodeSerialCursor.
type SerialCursor struct {
// Version is the cursor format version; DecodeSerialCursor rejects any value other than 1.
Version int `json:"v"` // Version for compatibility
// KeyIndex identifies which key the cursor currently points at.
KeyIndex int `json:"i"` // Current key index in sorted keys array
// Cursor is the provider-native cursor for that key.
Cursor string `json:"c"` // Native cursor for current key (empty = start fresh)
}
// EncodeSerialCursor serializes cursor to JSON and returns it as a URL-safe
// base64 string for transport. It returns the empty string when cursor is nil
// or when JSON marshalling fails.
func EncodeSerialCursor(cursor *SerialCursor) string {
	if cursor != nil {
		if payload, err := json.Marshal(cursor); err == nil {
			return base64.URLEncoding.EncodeToString(payload)
		}
	}
	return ""
}
// DecodeSerialCursor decodes a URL-safe base64 string produced by
// EncodeSerialCursor back into a SerialCursor.
//
// Returns (nil, nil) when encoded is empty, meaning "no cursor". Returns an
// error when the input is not valid base64, is not valid JSON for a
// SerialCursor, carries a version other than 1, or carries a negative key
// index.
func DecodeSerialCursor(encoded string) (*SerialCursor, error) {
	if encoded == "" {
		return nil, nil
	}

	data, err := base64.URLEncoding.DecodeString(encoded)
	if err != nil {
		return nil, fmt.Errorf("failed to decode cursor: %w", err)
	}

	var cursor SerialCursor
	if err := json.Unmarshal(data, &cursor); err != nil {
		return nil, fmt.Errorf("failed to unmarshal cursor: %w", err)
	}

	// Validate version
	if cursor.Version != 1 {
		return nil, fmt.Errorf("unsupported cursor version: %d", cursor.Version)
	}

	// The encoded cursor is an opaque transport string that may have been
	// altered before it comes back. KeyIndex is documented as an index into
	// the sorted keys array, so a negative value can never be valid; reject
	// it here instead of letting it reach code that indexes with it.
	if cursor.KeyIndex < 0 {
		return nil, fmt.Errorf("invalid cursor key index: %d", cursor.KeyIndex)
	}

	return &cursor, nil
}
// NewSerialCursor returns a SerialCursor with Version set to 1 and the given
// key index and native cursor.
func NewSerialCursor(keyIndex int, cursor string) *SerialCursor {
	sc := new(SerialCursor)
	sc.Version = 1
	sc.KeyIndex = keyIndex
	sc.Cursor = cursor
	return sc
}

View File

@@ -0,0 +1,26 @@
package schemas
// BifrostPassthroughRequest carries the pieces of an HTTP request (method,
// path, query, body, headers) together with an optional provider and model.
// NOTE(review): presumably this is the request forwarded as-is to a provider
// endpoint; confirm against the passthrough handler.
type BifrostPassthroughRequest struct {
Provider ModelProvider // provider extracted from path or body, used for key selection when non-empty
Model string // model extracted from path or body, used for key selection when non-empty
Method string
Path string // stripped path, e.g. "/v1/fine-tuning/jobs"
RawQuery string // raw query string, no "?"
Body []byte
SafeHeaders map[string]string // client headers, auth already stripped
}
// BifrostPassthroughResponse carries the status code, headers and body of a
// passthrough response, plus Bifrost's extra response fields.
// NOTE(review): BodyTruncated presumably signals that Body holds only part of
// the upstream payload; confirm where it is set.
type BifrostPassthroughResponse struct {
StatusCode int
Headers map[string]string
Body []byte
BodyTruncated bool
ExtraFields BifrostResponseExtraFields
}
// PassthroughLogParams is the JSON-serializable summary of a passthrough call:
// the request method, path and query string, and the response status code.
// NOTE(review): the name suggests it is used for logging; confirm against callers.
type PassthroughLogParams struct {
Method string `json:"method"`
Path string `json:"path"` // stripped path, e.g. "/v1/fine-tuning/jobs"
RawQuery string `json:"raw_query"` // raw query string, no "?"
StatusCode int `json:"status_code"`
}

327
core/schemas/plugin.go Normal file
View File

@@ -0,0 +1,327 @@
// Package schemas defines the core schemas and types used by the Bifrost system.
package schemas
import (
"context"
"strings"
"sync"
)
// PluginStatus constants. These are untyped string constants.
// NOTE(review): presumably these are the values stored in PluginStatus.Status;
// the meaning of each state is not defined here — confirm against the plugin loader.
const (
PluginStatusActive = "active"
PluginStatusError = "error"
PluginStatusDisabled = "disabled"
PluginStatusLoading = "loading"
PluginStatusUninitialized = "uninitialized"
PluginStatusUnloaded = "unloaded"
PluginStatusLoaded = "loaded"
)
// PluginStatus represents the status of a plugin.
// It is serialized to JSON with the field names given in the struct tags.
type PluginStatus struct {
Name string `json:"name"` // Display name of the plugin
// Status is a free-form string; presumably one of the PluginStatus* constants.
Status string `json:"status"`
// Logs holds log lines associated with the plugin.
Logs []string `json:"logs"`
Types []PluginType `json:"types"` // Plugin types (LLM, MCP, HTTP)
}
// PluginType represents the type of plugin.
// It is a string type; the known values are declared below.
type PluginType string
// Known PluginType values.
const (
PluginTypeLLM PluginType = "llm"
PluginTypeMCP PluginType = "mcp"
PluginTypeHTTP PluginType = "http"
)
// HTTPRequest is a serializable representation of an HTTP request.
// Used for plugin HTTP transport interception (supports both native .so and WASM plugins).
// This type is pooled for allocation control - use AcquireHTTPRequest and ReleaseHTTPRequest.
//
// Headers, Query and PathParams hold a single string value per name. Use the
// CaseInsensitive*Lookup methods to read them without caring about key casing.
type HTTPRequest struct {
Method string `json:"method"`
Path string `json:"path"`
Headers map[string]string `json:"headers"`
Query map[string]string `json:"query"`
Body []byte `json:"body"`
PathParams map[string]string `json:"path_params"` // Path variables extracted from the URL pattern (e.g., {model})
}
// CaseInsensitiveHeaderLookup looks up a header key in a case-insensitive manner.
// It returns the header's value, or an empty string when no header matches.
// req must be non-nil: req.Headers is read without a nil check.
func (req *HTTPRequest) CaseInsensitiveHeaderLookup(key string) string {
return caseInsensitiveLookup(req.Headers, key)
}
// CaseInsensitiveQueryLookup looks up a query key in a case-insensitive manner.
// It returns the parameter's value, or an empty string when no parameter matches.
// req must be non-nil: req.Query is read without a nil check.
func (req *HTTPRequest) CaseInsensitiveQueryLookup(key string) string {
return caseInsensitiveLookup(req.Query, key)
}
// CaseInsensitivePathParamLookup looks up a path parameter key in a case-insensitive manner.
// It returns the parameter's value, or an empty string when no parameter matches.
// req must be non-nil: req.PathParams is read without a nil check.
func (req *HTTPRequest) CaseInsensitivePathParamLookup(key string) string {
return caseInsensitiveLookup(req.PathParams, key)
}
// caseInsensitiveLookup finds key in data while ignoring letter case and
// returns the associated value, or an empty string when nothing matches.
//
// The lookup tries, in order: the key exactly as given, the lower-cased key,
// and finally a scan of all entries using strings.EqualFold. An empty key or
// an empty/nil map always yields an empty string. If several entries differ
// only by case and neither direct lookup hits, which of them is returned
// depends on map iteration order.
func caseInsensitiveLookup(data map[string]string, key string) string {
	if len(data) == 0 || key == "" {
		return ""
	}

	// Fast paths: exact spelling first, then the lower-cased spelling.
	if value, found := data[key]; found {
		return value
	}
	if value, found := data[strings.ToLower(key)]; found {
		return value
	}

	// Slow path: compare every stored name case-insensitively.
	for name, value := range data {
		if strings.EqualFold(name, key) {
			return value
		}
	}
	return ""
}
// HTTPResponse is a serializable representation of an HTTP response.
// Used for short-circuit responses in plugin HTTP transport interception.
// This type is pooled for allocation control - use AcquireHTTPResponse and ReleaseHTTPResponse.
//
// Headers holds a single string value per header name.
type HTTPResponse struct {
StatusCode int `json:"status_code"`
Headers map[string]string `json:"headers"`
Body []byte `json:"body"`
}
// httpRequestPool is the pool for HTTPRequest objects to reduce allocations.
// New objects are created with all three maps pre-allocated (capacity hints:
// 16 headers, 8 query parameters, 4 path parameters).
var httpRequestPool = sync.Pool{
New: func() any {
return &HTTPRequest{
Headers: make(map[string]string, 16),
Query: make(map[string]string, 8),
PathParams: make(map[string]string, 4),
}
},
}
// AcquireHTTPRequest gets an HTTPRequest from the pool.
// The returned HTTPRequest is ready to use with pre-allocated maps.
// Call ReleaseHTTPRequest when done to return it to the pool.
//
// The pool's value is type-asserted to *HTTPRequest without a check, so this
// relies on only *HTTPRequest values ever being put into httpRequestPool.
func AcquireHTTPRequest() *HTTPRequest {
return httpRequestPool.Get().(*HTTPRequest)
}
// ReleaseHTTPRequest returns an HTTPRequest to the pool.
// The HTTPRequest is reset before being returned to the pool: its maps are
// emptied, and Method, Path and Body are cleared.
// Do not use the HTTPRequest after calling this function.
//
// A nil req is ignored. If any map field was set to nil while the request was
// in use, a fresh map is allocated for it so that the next AcquireHTTPRequest
// still returns a request whose maps can be written to.
func ReleaseHTTPRequest(req *HTTPRequest) {
	if req == nil {
		return
	}

	// resetMap empties m in place, or allocates a replacement when m is nil.
	// Writing to a nil map panics, so a nil map must never go back into the pool.
	resetMap := func(m map[string]string, sizeHint int) map[string]string {
		if m == nil {
			return make(map[string]string, sizeHint)
		}
		clear(m)
		return m
	}

	// Size hints mirror the ones used when the pool creates a new request.
	req.Headers = resetMap(req.Headers, 16)
	req.Query = resetMap(req.Query, 8)
	req.PathParams = resetMap(req.PathParams, 4)

	// Reset scalar fields and drop the body reference.
	req.Method = ""
	req.Path = ""
	req.Body = nil

	httpRequestPool.Put(req)
}
// httpResponsePool is the pool for HTTPResponse objects to reduce allocations.
// New objects are created with the Headers map pre-allocated (capacity hint: 8).
var httpResponsePool = sync.Pool{
New: func() any {
return &HTTPResponse{
Headers: make(map[string]string, 8),
}
},
}
// AcquireHTTPResponse gets an HTTPResponse from the pool.
// The returned HTTPResponse is ready to use with a pre-allocated Headers map.
// Call ReleaseHTTPResponse when done to return it to the pool.
//
// The pool's value is type-asserted to *HTTPResponse without a check, so this
// relies on only *HTTPResponse values ever being put into httpResponsePool.
func AcquireHTTPResponse() *HTTPResponse {
return httpResponsePool.Get().(*HTTPResponse)
}
// ReleaseHTTPResponse returns an HTTPResponse to the pool.
// The HTTPResponse is reset before being returned to the pool: Headers is
// emptied, StatusCode is zeroed and Body is cleared.
// Do not use the HTTPResponse after calling this function.
//
// A nil resp is ignored. If Headers was set to nil while the response was in
// use, a fresh map is allocated so that the next AcquireHTTPResponse still
// returns a response whose Headers map can be written to.
func ReleaseHTTPResponse(resp *HTTPResponse) {
	if resp == nil {
		return
	}

	// Writing to a nil map panics, so a nil map must never go back into the
	// pool. The size hint mirrors the one used when the pool creates a new
	// response.
	if resp.Headers == nil {
		resp.Headers = make(map[string]string, 8)
	} else {
		clear(resp.Headers)
	}

	// Reset scalar fields and drop the body reference.
	resp.StatusCode = 0
	resp.Body = nil

	httpResponsePool.Put(resp)
}
// BasePlugin defines the methods common to all Bifrost plugins: a name and a
// cleanup hook. The hook methods discussed below are not part of this
// interface; HTTPTransportPlugin embeds BasePlugin and adds the HTTP transport
// hooks. NOTE(review): the LLM hooks (PreLLMHook/PostLLMHook) are presumably
// declared on a sibling interface — confirm.
//
// Plugins can intercept and modify requests and responses at different stages
// of the processing pipeline.
// User can provide multiple plugins in the BifrostConfig.
// PreHooks are executed in the order they are registered.
// PostHooks are executed in the reverse order of PreHooks.
//
// Execution order:
// 1. HTTPTransportPreHook (HTTP transport only, executed in registration order)
// 2. PreLLMHook (executed in registration order)
// 3. Provider call
// 4. PostLLMHook (executed in reverse order of PreHooks)
// 5. HTTPTransportPostHook (HTTP transport only, executed in reverse order)
// 5a. HTTPTransportStreamChunkHook (for streaming responses, called per-chunk in reverse order)
//
// Common use cases: rate limiting, caching, logging, monitoring, request transformation, governance.
//
// Plugin error handling:
// - No Plugin errors are returned to the caller; they are logged as warnings by the Bifrost instance.
// - PreLLMHook and PostLLMHook can both modify the request/response and the error. Plugins can recover from errors (set error to nil and provide a response), or invalidate a response (set response to nil and provide an error).
// - PostLLMHook is always called with both the current response and error, and should handle either being nil.
// - Only truly empty errors (no message, no error, no status code, no type) are treated as recoveries by the pipeline.
// - If a PreLLMHook returns a LLMPluginShortCircuit, the provider call may be skipped and only the PostLLMHook methods of plugins that had their PreLLMHook executed are called in reverse order.
// - The plugin pipeline ensures symmetry: for every PreLLMHook executed, the corresponding PostLLMHook will be called in reverse order.
//
// IMPORTANT: When returning BifrostError from PreLLMHook or PostLLMHook:
// - You can set the AllowFallbacks field to control fallback behavior
// - AllowFallbacks = &true: Allow Bifrost to try fallback providers
// - AllowFallbacks = &false: Do not try fallbacks, return error immediately
// - AllowFallbacks = nil: Treated as true by default (allow fallbacks for resilience)
//
// Plugin authors should ensure their hooks are robust to both response and error being nil, and should not assume either is always present.
type BasePlugin interface {
// GetName returns the name of the plugin.
GetName() string
// Cleanup is called on bifrost shutdown.
// It allows plugins to clean up any resources they have allocated.
// Returns any error that occurred during cleanup, which will be logged as a warning by the Bifrost instance.
Cleanup() error
}
// HTTPTransportPlugin is implemented by plugins that intercept traffic at the HTTP
// transport layer (bifrost-http). Its hooks run before a request enters Bifrost core,
// after it exits Bifrost core, and once per chunk for streaming responses.
type HTTPTransportPlugin interface {
BasePlugin
// HTTPTransportPreHook is called at the HTTP transport layer before requests enter Bifrost core.
// It receives a serializable HTTPRequest and allows plugins to modify it in-place.
// Only invoked when using HTTP transport (bifrost-http), not when using Bifrost as a Go SDK directly.
// Works with both native .so plugins and WASM plugins due to serializable types.
//
// Return values:
// - (nil, nil): Continue to next plugin/handler, request modifications are applied
// - (*HTTPResponse, nil): Short-circuit with this response, skip remaining plugins and provider call
// - (nil, error): Short-circuit with error response
//
// Return nil for both values if the plugin doesn't need HTTP transport interception.
HTTPTransportPreHook(ctx *BifrostContext, req *HTTPRequest) (*HTTPResponse, error)
// HTTPTransportPostHook is called at the HTTP transport layer after requests exit Bifrost core.
// It receives a serializable HTTPRequest and HTTPResponse and allows plugins to modify it in-place.
// Only invoked when using HTTP transport (bifrost-http), not when using Bifrost as a Go SDK directly.
// Works with both native .so plugins and WASM plugins due to serializable types.
// NOTE: This hook is NOT called for streaming responses. Use HTTPTransportStreamChunkHook instead.
// NOTE: For large streamed responses (non-streaming APIs that switch to body streaming for memory safety),
// resp.Body may be nil by design while StatusCode and Headers remain populated.
//
// Return values:
// - nil: Continue to next plugin/handler, response modifications are applied
// - error: Short-circuit with error response and skip remaining plugins
//
// Return nil if the plugin doesn't need HTTP transport interception.
HTTPTransportPostHook(ctx *BifrostContext, req *HTTPRequest, resp *HTTPResponse) error
// HTTPTransportStreamChunkHook is called for each chunk during streaming responses.
// It receives the BifrostStreamChunk BEFORE they are written to the client.
// Only invoked for streaming responses when using HTTP transport (bifrost-http).
// Works with both native .so plugins and WASM plugins due to serializable types.
//
// Plugins can modify the chunk by returning a different BifrostStreamChunk.
// Return the original chunk unchanged if no modification is needed.
//
// Return values:
// - (*BifrostStreamChunk, nil): Continue with the (potentially modified) BifrostStreamChunk
// - (nil, nil): Skip this BifrostStreamChunk entirely (don't send to client)
// - (*BifrostStreamChunk, error): Log warning and continue with the BifrostStreamChunk
// - (nil, error): Send back error to the client and stop the streaming
//
// Return (*BifrostStreamChunk, nil) unchanged if the plugin doesn't need streaming chunk interception.
HTTPTransportStreamChunkHook(ctx *BifrostContext, req *HTTPRequest, chunk *BifrostStreamChunk) (*BifrostStreamChunk, error)
}
// LLMPlugin is implemented by plugins that hook into LLM requests around the provider call.
type LLMPlugin interface {
BasePlugin
// PreLLMHook receives the request and returns the (possibly modified) request,
// an optional LLMPluginShortCircuit that can cause the provider call to be skipped,
// and an error.
// NOTE(review): how the pipeline treats a non-nil error return is not visible here — confirm.
PreLLMHook(ctx *BifrostContext, req *BifrostRequest) (*BifrostRequest, *LLMPluginShortCircuit, error)
// PostLLMHook receives the current response and error, either of which may be nil,
// and returns the (possibly modified) response and error. It may recover from an
// error (return a response and a nil error) or invalidate a response (return a nil
// response and an error).
// NOTE(review): how the pipeline treats the trailing error return is not visible here — confirm.
PostLLMHook(ctx *BifrostContext, resp *BifrostResponse, bifrostErr *BifrostError) (*BifrostResponse, *BifrostError, error)
}
// MCPPlugin is implemented by plugins that hook into MCP requests. Its two hooks have
// the same shape as LLMPlugin's, operating on BifrostMCPRequest/BifrostMCPResponse.
// NOTE(review): pipeline semantics are presumably analogous to the LLM hooks — confirm.
type MCPPlugin interface {
BasePlugin
// PreMCPHook receives the MCP request and returns the (possibly modified) request,
// an optional MCPPluginShortCircuit, and an error.
PreMCPHook(ctx *BifrostContext, req *BifrostMCPRequest) (*BifrostMCPRequest, *MCPPluginShortCircuit, error)
// PostMCPHook receives the MCP response and error and returns the (possibly modified)
// response and error, plus an error of its own.
PostMCPHook(ctx *BifrostContext, resp *BifrostMCPResponse, bifrostErr *BifrostError) (*BifrostMCPResponse, *BifrostError, error)
}
// PluginPlacement controls where custom plugins execute relative to built-in plugins.
type PluginPlacement string
// Recognized PluginPlacement values.
const (
PluginPlacementPostBuiltin PluginPlacement = "post_builtin"
PluginPlacementPreBuiltin PluginPlacement = "pre_builtin"
PluginPlacementBuiltin PluginPlacement = "builtin"
PluginPlacementDefault PluginPlacement = PluginPlacementPostBuiltin // Default placement is "post_builtin"
)
// PluginConfig is the configuration for a plugin.
// It contains the name of the plugin, whether it is enabled, an optional path and
// version, the plugin-specific configuration, and where the plugin is placed and
// ordered relative to other plugins.
type PluginConfig struct {
Enabled bool `json:"enabled"` // Whether the plugin is enabled
Name string `json:"name"` // Name of the plugin
Path *string `json:"path,omitempty"` // Optional; presumably where the plugin is loaded from — confirm
Version *int16 `json:"version,omitempty"` // Optional; meaning is not visible here — confirm
Config any `json:"config,omitempty"` // Plugin-specific configuration
Placement *PluginPlacement `json:"placement,omitempty"` // "pre_builtin" or "post_builtin". Default: "post_builtin"
Order *int `json:"order,omitempty"` // Position within placement group. Lower = earlier. Default: 0
}
// ObservabilityPlugin is an interface for plugins that receive completed traces
// for forwarding to observability backends (e.g., OTEL collectors, Datadog, etc.)
//
// ObservabilityPlugins are called asynchronously after the HTTP response has been
// written to the wire, ensuring they don't add latency to the client response.
//
// Plugins implementing this interface will:
// 1. Continue to work as regular plugins via PreLLMHook/PostLLMHook
// 2. Additionally receive completed traces via the Inject method
//
// The interface itself only requires BasePlugin plus Inject; PreLLMHook and
// PostLLMHook are declared on LLMPlugin, not here.
//
// Example backends: OpenTelemetry collectors, Datadog, Jaeger, Maxim, etc.
//
// Note: Go type assertion (plugin.(ObservabilityPlugin)) is used to identify
// plugins implementing this interface - no marker method is needed.
type ObservabilityPlugin interface {
BasePlugin
// Inject receives a completed trace for forwarding to observability backends.
// This method is called asynchronously after the response has been written to the client.
// The trace contains all spans that were added during request processing.
//
// Implementations should:
// - Convert the trace to their backend's format
// - Send the trace to the backend (can be async, but see retention note below)
// - Handle errors gracefully (log and continue)
//
// The context passed is a fresh background context, not the request context.
//
// Retention: implementations MUST NOT retain the *Trace pointer after Inject
// returns. The caller releases the trace back to a sync.Pool immediately after
// Inject completes, so any background goroutine that still references it will
// race with pool reuse. If a plugin needs to forward the trace asynchronously,
// it must copy the data it needs before returning.
Inject(ctx context.Context, trace *Trace) error
}

View File

@@ -0,0 +1,36 @@
//go:build !tinygo && !wasm
package schemas
import (
"github.com/valyala/fasthttp"
)
// BifrostHTTPMiddleware is a middleware function for the Bifrost HTTP transport.
// It follows the standard pattern: receives the next handler and returns a new handler.
// Used internally for CORS, Auth, Tracing middleware. Plugins use HTTPTransportIntercept instead.
// NOTE(review): HTTPTransportIntercept is not among the hooks declared on HTTPTransportPlugin
// (HTTPTransportPreHook, HTTPTransportPostHook, HTTPTransportStreamChunkHook); this sentence
// may be stale — confirm and update.
type BifrostHTTPMiddleware func(next fasthttp.RequestHandler) fasthttp.RequestHandler
// EventBroadcaster is a generic callback that pushes a typed event to connected
// clients (e.g., via WebSocket). Any plugin or subsystem can use it to send
// real-time updates to the frontend.
//
// eventType identifies the message (e.g., "governance_update"); data is the
// JSON-serializable payload.
type EventBroadcaster func(eventType string, data any)
// LLMPluginShortCircuit represents a plugin's decision to short-circuit the normal flow.
// It can carry a response (success short-circuit), a stream (streaming short-circuit),
// or an error (error short-circuit). Each field is optional; a nil field means that
// kind of short-circuit is not requested.
// NOTE(review): precedence when more than one field is set is not visible here — confirm.
type LLMPluginShortCircuit struct {
Response *BifrostResponse // If set, short-circuit with this response (skips provider call)
Stream chan *BifrostStreamChunk // If set, short-circuit with this stream (skips provider call)
Error *BifrostError // If set, short-circuit with this error (can set AllowFallbacks field)
}
// MCPPluginShortCircuit represents a plugin's decision to short-circuit the normal MCP flow.
// It can carry a response (success short-circuit) or an error (error short-circuit).
// Unlike LLMPluginShortCircuit it has no stream field.
type MCPPluginShortCircuit struct {
Response *BifrostMCPResponse // If set, short-circuit with this response (skips MCP call)
Error *BifrostError // If set, short-circuit with this error (can set AllowFallbacks field)
}
// PluginShortCircuit is the legacy name for LLMPluginShortCircuit (v1.3.x compatibility).
// It is a type alias, so the two names are interchangeable.
//
// Deprecated: Use LLMPluginShortCircuit instead.
type PluginShortCircuit = LLMPluginShortCircuit

220
core/schemas/plugin_test.go Normal file
View File

@@ -0,0 +1,220 @@
package schemas
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestCaseInsensitiveLookup is a table-driven test for caseInsensitiveLookup.
// It covers nil maps, empty and missing keys, exact matches, and case-insensitive
// matches for the header names Bifrost reads credentials from.
func TestCaseInsensitiveLookup(t *testing.T) {
	type lookupCase struct {
		name     string
		data     map[string]string
		key      string
		expected string
	}

	cases := []lookupCase{
		{name: "nil map returns empty string", data: nil, key: "Content-Type", expected: ""},
		{name: "empty key returns empty string", data: map[string]string{"Content-Type": "application/json"}, key: "", expected: ""},
		{name: "key not found returns empty string", data: map[string]string{"Content-Type": "application/json"}, key: "Authorization", expected: ""},
		{name: "exact match", data: map[string]string{"Content-Type": "application/json"}, key: "Content-Type", expected: "application/json"},
		{name: "lowercase key match - map has lowercase key", data: map[string]string{"content-type": "application/json"}, key: "Content-Type", expected: "application/json"},
		{name: "lowercase key match - query is lowercase", data: map[string]string{"content-type": "application/json"}, key: "content-type", expected: "application/json"},
		{name: "case-insensitive iteration - map has mixed case", data: map[string]string{"Content-Type": "application/json"}, key: "content-type", expected: "application/json"},
		{name: "case-insensitive iteration - uppercase query", data: map[string]string{"Content-Type": "application/json"}, key: "CONTENT-TYPE", expected: "application/json"},
		{name: "multiple keys - finds correct one", data: map[string]string{"Accept": "text/html", "Content-Type": "application/json"}, key: "content-type", expected: "application/json"},

		// x-bf-vk header variations
		{name: "x-bf-vk exact match lowercase", data: map[string]string{"x-bf-vk": "sk-bf-test123"}, key: "x-bf-vk", expected: "sk-bf-test123"},
		{name: "x-bf-vk mixed case in map", data: map[string]string{"X-Bf-Vk": "sk-bf-test123"}, key: "x-bf-vk", expected: "sk-bf-test123"},
		{name: "x-bf-vk uppercase in map", data: map[string]string{"X-BF-VK": "sk-bf-test123"}, key: "x-bf-vk", expected: "sk-bf-test123"},

		// authorization header variations
		{name: "authorization exact match lowercase", data: map[string]string{"authorization": "Bearer sk-bf-test123"}, key: "authorization", expected: "Bearer sk-bf-test123"},
		{name: "authorization capitalized in map", data: map[string]string{"Authorization": "Bearer sk-bf-test123"}, key: "authorization", expected: "Bearer sk-bf-test123"},
		{name: "authorization uppercase in map", data: map[string]string{"AUTHORIZATION": "Bearer sk-bf-test123"}, key: "authorization", expected: "Bearer sk-bf-test123"},

		// x-api-key header variations
		{name: "x-api-key exact match lowercase", data: map[string]string{"x-api-key": "sk-bf-apikey123"}, key: "x-api-key", expected: "sk-bf-apikey123"},
		{name: "x-api-key mixed case in map", data: map[string]string{"X-Api-Key": "sk-bf-apikey123"}, key: "x-api-key", expected: "sk-bf-apikey123"},
		{name: "x-api-key uppercase in map", data: map[string]string{"X-API-KEY": "sk-bf-apikey123"}, key: "x-api-key", expected: "sk-bf-apikey123"},

		// x-goog-api-key header variations
		{name: "x-goog-api-key exact match lowercase", data: map[string]string{"x-goog-api-key": "sk-bf-google123"}, key: "x-goog-api-key", expected: "sk-bf-google123"},
		{name: "x-goog-api-key mixed case in map", data: map[string]string{"X-Goog-Api-Key": "sk-bf-google123"}, key: "x-goog-api-key", expected: "sk-bf-google123"},
		{name: "x-goog-api-key uppercase in map", data: map[string]string{"X-GOOG-API-KEY": "sk-bf-google123"}, key: "x-goog-api-key", expected: "sk-bf-google123"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := caseInsensitiveLookup(tc.data, tc.key)
			assert.Equal(t, tc.expected, got)
		})
	}
}
// TestHTTPRequest_CaseInsensitiveHeaderLookup checks that
// HTTPRequest.CaseInsensitiveHeaderLookup finds header values regardless of the
// casing used for the lookup key.
func TestHTTPRequest_CaseInsensitiveHeaderLookup(t *testing.T) {
	type headerCase struct {
		name     string
		headers  map[string]string
		key      string
		expected string
	}

	cases := []headerCase{
		{name: "exact match", headers: map[string]string{"Content-Type": "application/json"}, key: "Content-Type", expected: "application/json"},
		{name: "case-insensitive match", headers: map[string]string{"Content-Type": "application/json"}, key: "content-type", expected: "application/json"},
		{name: "authorization header", headers: map[string]string{"Authorization": "Bearer token123"}, key: "authorization", expected: "Bearer token123"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			request := &HTTPRequest{Headers: tc.headers}
			got := request.CaseInsensitiveHeaderLookup(tc.key)
			assert.Equal(t, tc.expected, got)
		})
	}
}
// TestHTTPRequest_CaseInsensitiveQueryLookup checks that
// HTTPRequest.CaseInsensitiveQueryLookup finds query values regardless of the
// casing used for the lookup key.
func TestHTTPRequest_CaseInsensitiveQueryLookup(t *testing.T) {
	type queryCase struct {
		name     string
		query    map[string]string
		key      string
		expected string
	}

	cases := []queryCase{
		{name: "exact match", query: map[string]string{"apiKey": "test123"}, key: "apiKey", expected: "test123"},
		{name: "case-insensitive match", query: map[string]string{"ApiKey": "test123"}, key: "apikey", expected: "test123"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			request := &HTTPRequest{Query: tc.query}
			got := request.CaseInsensitiveQueryLookup(tc.key)
			assert.Equal(t, tc.expected, got)
		})
	}
}

View File

@@ -0,0 +1,15 @@
//go:build tinygo || wasm
package schemas
// LLMPluginShortCircuit represents a plugin's decision to short-circuit the normal flow.
// In this build (tinygo/wasm) it can carry either a response (success short-circuit)
// or an error (error short-circuit).
// Streams are not supported in WASM plugins, so this variant has no Stream field.
type LLMPluginShortCircuit struct {
Response *BifrostResponse // If set, short-circuit with this response (skips provider call)
Error *BifrostError // If set, short-circuit with this error (can set AllowFallbacks field)
}
// PluginShortCircuit is the legacy name for LLMPluginShortCircuit (v1.3.x compatibility).
// It is a type alias, so the two names are interchangeable.
//
// Deprecated: Use LLMPluginShortCircuit instead.
type PluginShortCircuit = LLMPluginShortCircuit

614
core/schemas/provider.go Normal file
View File

@@ -0,0 +1,614 @@
// Package schemas defines the core schemas and types used by the Bifrost system.
package schemas
import (
"context"
"encoding/json"
"maps"
"time"
)
// Default values and bounds for provider network, concurrency and buffer configuration.
// ProviderConfig.CheckAndSetDefaults applies several of these when a field is left unset.
const (
DefaultMaxRetries = 0
DefaultRetryBackoffInitial = 500 * time.Millisecond
DefaultRetryBackoffMax = 5 * time.Second
DefaultRequestTimeoutInSeconds = 30
DefaultMaxConnDurationInSeconds = 300 // 5 minutes — forces connection recycling to prevent stale connections from NAT/LB silent drops
DefaultBufferSize = 5000
DefaultConcurrency = 1000
DefaultStreamBufferSize = 256
DefaultStreamIdleTimeoutInSeconds = 60 // Idle timeout per stream chunk — if no data for this many seconds, bifrost closes the connection
DefaultMaxConnsPerHost = 5000
MaxConnsPerHostUpperBound = 10000 // Upper clamp applied to NetworkConfig.MaxConnsPerHost by CheckAndSetDefaults
DefaultMaxIdleConnsPerHost = 40
)
// Pre-defined error messages for provider operations.
// These are untyped string constants (message text), not error values.
const (
// The "30 seconds" in this message mirrors DefaultRequestTimeoutInSeconds; keep the two in sync.
ErrProviderRequestTimedOut = "request timed out (default is 30 seconds). You can increase it by setting the default_request_timeout_in_seconds in the network_config or in UI - Providers > Provider Name > Network Config."
ErrRequestCancelled = "request cancelled by caller"
ErrRequestBodyConversion = "failed to convert bifrost request to the expected provider request body"
ErrProviderRequestMarshal = "failed to marshal request body to JSON"
ErrProviderCreateRequest = "failed to create HTTP request to provider API"
ErrProviderDoRequest = "failed to execute HTTP request to provider API"
ErrProviderNetworkError = "network error occurred while connecting to provider API (DNS lookup, connection refused, etc.)"
ErrProviderResponseDecode = "failed to decode response body from provider API"
ErrProviderResponseUnmarshal = "failed to unmarshal response from provider API"
ErrProviderResponseEmpty = "empty response received from provider"
ErrProviderResponseHTML = "HTML response received from provider"
ErrProviderRawRequestUnmarshal = "failed to unmarshal raw request from provider API"
ErrProviderRawResponseUnmarshal = "failed to unmarshal raw response from provider API"
ErrProviderResponseDecompress = "failed to decompress provider's response"
)
// NetworkConfig represents the network configuration for provider connections.
// ProviderConfig.CheckAndSetDefaults replaces ExtraHeaders and BetaHeaderOverrides
// with copies to prevent data races.
//
// RetryBackoffInitial and RetryBackoffMax are stored internally as time.Duration (nanoseconds),
// but are serialized/deserialized to/from JSON as milliseconds (integers).
// This means:
//   - In JSON: values are represented as milliseconds (e.g., 1000 means 1000ms)
//   - In Go: values are time.Duration (e.g., 1000ms = 1000000000 nanoseconds)
//   - When unmarshaling from JSON: a value of 1000 is interpreted as 1000ms, not 1000ns
//   - When marshaling to JSON: a time.Duration is converted to milliseconds
type NetworkConfig struct {
	// BaseURL is supported for OpenAI, Anthropic, Cohere, Mistral, and Ollama providers (required for Ollama)
	BaseURL                        string            `json:"base_url,omitempty"`                       // Base URL for the provider (optional)
	ExtraHeaders                   map[string]string `json:"extra_headers,omitempty"`                  // Additional headers to include in requests (optional)
	DefaultRequestTimeoutInSeconds int               `json:"default_request_timeout_in_seconds"`       // Default timeout for requests
	MaxRetries                     int               `json:"max_retries"`                              // Maximum number of retries
	RetryBackoffInitial            time.Duration     `json:"retry_backoff_initial"`                    // Initial backoff duration (stored as nanoseconds, JSON as milliseconds)
	RetryBackoffMax                time.Duration     `json:"retry_backoff_max"`                        // Maximum backoff duration (stored as nanoseconds, JSON as milliseconds)
	InsecureSkipVerify             bool              `json:"insecure_skip_verify,omitempty"`           // Disables TLS certificate verification for provider connections
	CACertPEM                      string            `json:"ca_cert_pem,omitempty"`                    // PEM-encoded CA certificate to trust for provider endpoint connections
	StreamIdleTimeoutInSeconds     int               `json:"stream_idle_timeout_in_seconds,omitempty"` // Idle timeout per stream chunk (0 = use default 60s)
	MaxConnsPerHost                int               `json:"max_conns_per_host,omitempty"`             // Max TCP connections per provider host (default: 5000)
	EnforceHTTP2                   bool              `json:"enforce_http2,omitempty"`                  // Force HTTP/2 on provider connections (relevant for net/http-based providers like Bedrock)
	BetaHeaderOverrides            map[string]bool   `json:"beta_header_overrides,omitempty"`          // Override default beta header support per provider (keys are prefixes like "redact-thinking-")
}

// networkConfigJSON is the wire representation of NetworkConfig. It mirrors
// NetworkConfig field-for-field (same order, same tags) except that the two
// backoff fields are plain integers holding milliseconds.
//
// It is a distinct type with no JSON methods, which is what keeps
// NetworkConfig's MarshalJSON/UnmarshalJSON from recursing. Both methods share
// this single definition so the encode and decode sides cannot drift apart.
type networkConfigJSON struct {
	BaseURL                        string            `json:"base_url,omitempty"`
	ExtraHeaders                   map[string]string `json:"extra_headers,omitempty"`
	DefaultRequestTimeoutInSeconds int               `json:"default_request_timeout_in_seconds"`
	MaxRetries                     int               `json:"max_retries"`
	RetryBackoffInitial            int64             `json:"retry_backoff_initial"` // milliseconds in JSON
	RetryBackoffMax                int64             `json:"retry_backoff_max"`     // milliseconds in JSON
	InsecureSkipVerify             bool              `json:"insecure_skip_verify,omitempty"`
	CACertPEM                      string            `json:"ca_cert_pem,omitempty"`
	StreamIdleTimeoutInSeconds     int               `json:"stream_idle_timeout_in_seconds,omitempty"`
	MaxConnsPerHost                int               `json:"max_conns_per_host,omitempty"`
	EnforceHTTP2                   bool              `json:"enforce_http2,omitempty"`
	BetaHeaderOverrides            map[string]bool   `json:"beta_header_overrides,omitempty"`
}

// UnmarshalJSON customizes JSON unmarshaling for NetworkConfig.
// RetryBackoffInitial and RetryBackoffMax are interpreted as milliseconds in JSON,
// but stored as time.Duration (nanoseconds) internally.
//
// Every field other than the two backoffs is overwritten with the decoded value
// (or its zero value when absent from the JSON). A backoff is only assigned when
// the decoded value is greater than 0; otherwise the receiver keeps whatever
// value it already held. Returns the decoding error unchanged on malformed input,
// in which case the receiver is not modified.
func (nc *NetworkConfig) UnmarshalJSON(data []byte) error {
	var wire networkConfigJSON
	if err := json.Unmarshal(data, &wire); err != nil {
		return err
	}

	nc.BaseURL = wire.BaseURL
	nc.ExtraHeaders = wire.ExtraHeaders
	nc.DefaultRequestTimeoutInSeconds = wire.DefaultRequestTimeoutInSeconds
	nc.MaxRetries = wire.MaxRetries
	nc.InsecureSkipVerify = wire.InsecureSkipVerify
	nc.CACertPEM = wire.CACertPEM
	nc.StreamIdleTimeoutInSeconds = wire.StreamIdleTimeoutInSeconds
	nc.MaxConnsPerHost = wire.MaxConnsPerHost
	nc.EnforceHTTP2 = wire.EnforceHTTP2
	nc.BetaHeaderOverrides = wire.BetaHeaderOverrides

	// Convert milliseconds to time.Duration (nanoseconds).
	// Zero, negative, or absent values leave the existing backoff untouched.
	if wire.RetryBackoffInitial > 0 {
		nc.RetryBackoffInitial = time.Duration(wire.RetryBackoffInitial) * time.Millisecond
	}
	if wire.RetryBackoffMax > 0 {
		nc.RetryBackoffMax = time.Duration(wire.RetryBackoffMax) * time.Millisecond
	}
	return nil
}

// MarshalJSON customizes JSON marshaling for NetworkConfig.
// RetryBackoffInitial and RetryBackoffMax are converted from time.Duration (nanoseconds)
// to whole milliseconds (integers) in JSON; any sub-millisecond remainder is
// truncated by the integer division.
//
// The value receiver is deliberate: it makes the custom encoding apply to both
// NetworkConfig and *NetworkConfig.
func (nc NetworkConfig) MarshalJSON() ([]byte, error) {
	return json.Marshal(networkConfigJSON{
		BaseURL:                        nc.BaseURL,
		ExtraHeaders:                   nc.ExtraHeaders,
		DefaultRequestTimeoutInSeconds: nc.DefaultRequestTimeoutInSeconds,
		MaxRetries:                     nc.MaxRetries,
		RetryBackoffInitial:            int64(nc.RetryBackoffInitial / time.Millisecond),
		RetryBackoffMax:                int64(nc.RetryBackoffMax / time.Millisecond),
		InsecureSkipVerify:             nc.InsecureSkipVerify,
		CACertPEM:                      nc.CACertPEM,
		StreamIdleTimeoutInSeconds:     nc.StreamIdleTimeoutInSeconds,
		MaxConnsPerHost:                nc.MaxConnsPerHost,
		EnforceHTTP2:                   nc.EnforceHTTP2,
		BetaHeaderOverrides:            nc.BetaHeaderOverrides,
	})
}
// Redacted returns a copy of the network configuration in which a non-empty
// CACertPEM is replaced by the "<REDACTED>" placeholder. A nil receiver yields nil.
//
// The copy is shallow: the ExtraHeaders and BetaHeaderOverrides maps are shared
// with the receiver, and no other field is masked.
func (nc *NetworkConfig) Redacted() *NetworkConfig {
	if nc == nil {
		return nil
	}
	masked := new(NetworkConfig)
	*masked = *nc
	if masked.CACertPEM != "" {
		masked.CACertPEM = "<REDACTED>"
	}
	return masked
}
// DefaultNetworkConfig is the default network configuration for provider connections.
// Only the timeout, retry, stream-idle and connection-limit fields are populated from
// the Default* constants; every other field (BaseURL, ExtraHeaders, TLS options, ...)
// is left at its zero value.
var DefaultNetworkConfig = NetworkConfig{
DefaultRequestTimeoutInSeconds: DefaultRequestTimeoutInSeconds,
MaxRetries: DefaultMaxRetries,
RetryBackoffInitial: DefaultRetryBackoffInitial,
RetryBackoffMax: DefaultRetryBackoffMax,
StreamIdleTimeoutInSeconds: DefaultStreamIdleTimeoutInSeconds,
MaxConnsPerHost: DefaultMaxConnsPerHost,
}
// ConcurrencyAndBufferSize represents configuration for concurrent operations and buffer sizes.
type ConcurrencyAndBufferSize struct {
Concurrency int `json:"concurrency"` // Number of concurrent operations. Also used as the initial pool size for the provider responses.
BufferSize int `json:"buffer_size"` // Size of the buffer
}
// DefaultConcurrencyAndBufferSize is the default concurrency and buffer size for provider operations.
// It is built from the DefaultConcurrency and DefaultBufferSize constants.
var DefaultConcurrencyAndBufferSize = ConcurrencyAndBufferSize{
Concurrency: DefaultConcurrency,
BufferSize: DefaultBufferSize,
}
// ProxyType defines the type of proxy to use for connections.
// Its string value is what appears in the "type" field of a serialized ProxyConfig.
type ProxyType string
// Recognized ProxyType values.
const (
// NoProxy indicates no proxy should be used
NoProxy ProxyType = "none"
// HTTPProxy indicates an HTTP proxy should be used
HTTPProxy ProxyType = "http"
// Socks5Proxy indicates a SOCKS5 proxy should be used
Socks5Proxy ProxyType = "socks5"
// EnvProxy indicates the proxy should be read from environment variables
EnvProxy ProxyType = "environment"
)
// ProxyConfig holds the configuration for proxy settings.
// Password and CACertPEM are the fields masked by Redacted.
type ProxyConfig struct {
Type ProxyType `json:"type"` // Type of proxy to use
URL string `json:"url"` // URL of the proxy server
Username string `json:"username"` // Username for proxy authentication
Password string `json:"password"` // Password for proxy authentication
CACertPEM string `json:"ca_cert_pem"` // PEM-encoded CA certificate to trust for TLS connections through the proxy
}
// IsRedactedValue reports whether value is one of the two placeholder strings
// treated as redacted: "<REDACTED>" or "********". The comparison is exact and
// does not read any field of the receiver.
func (pc *ProxyConfig) IsRedactedValue(value string) bool {
	switch value {
	case "<REDACTED>", "********":
		return true
	}
	return false
}
// Redacted returns a redacted copy of the proxy configuration.
//
// Type, URL and Username are copied verbatim. Password and CACertPEM are
// replaced by the "<REDACTED>" placeholder when non-empty and left empty
// otherwise. The receiver is never modified.
//
// A nil receiver yields nil, matching NetworkConfig.Redacted, so callers can
// redact an optional (nil) proxy configuration without a prior nil check.
func (pc *ProxyConfig) Redacted() *ProxyConfig {
	if pc == nil {
		return nil
	}

	// Start from the non-sensitive fields only; secrets are never copied across.
	redactedConfig := ProxyConfig{
		Type:     pc.Type,
		URL:      pc.URL,
		Username: pc.Username,
	}
	if pc.Password != "" {
		redactedConfig.Password = "<REDACTED>"
	}
	if pc.CACertPEM != "" {
		redactedConfig.CACertPEM = "<REDACTED>"
	}
	return &redactedConfig
}
// AllowedRequests controls which operations are permitted.
// A nil *AllowedRequests means "all operations allowed."
// A non-nil value only allows fields set to true; omitted or false fields are disallowed.
//
// Each field gates one RequestType; IsOperationAllowed holds the exact
// RequestType-to-field mapping and must be extended whenever a field is added here.
type AllowedRequests struct {
ListModels bool `json:"list_models"`
TextCompletion bool `json:"text_completion"`
TextCompletionStream bool `json:"text_completion_stream"`
ChatCompletion bool `json:"chat_completion"`
ChatCompletionStream bool `json:"chat_completion_stream"`
Responses bool `json:"responses"`
ResponsesStream bool `json:"responses_stream"`
CountTokens bool `json:"count_tokens"`
Embedding bool `json:"embedding"`
Rerank bool `json:"rerank"`
OCR bool `json:"ocr"`
Speech bool `json:"speech"`
SpeechStream bool `json:"speech_stream"`
Transcription bool `json:"transcription"`
TranscriptionStream bool `json:"transcription_stream"`
ImageGeneration bool `json:"image_generation"`
ImageGenerationStream bool `json:"image_generation_stream"`
ImageEdit bool `json:"image_edit"`
ImageEditStream bool `json:"image_edit_stream"`
ImageVariation bool `json:"image_variation"`
VideoGeneration bool `json:"video_generation"`
VideoRetrieve bool `json:"video_retrieve"`
VideoDownload bool `json:"video_download"`
VideoDelete bool `json:"video_delete"`
VideoList bool `json:"video_list"`
VideoRemix bool `json:"video_remix"`
BatchCreate bool `json:"batch_create"`
BatchList bool `json:"batch_list"`
BatchRetrieve bool `json:"batch_retrieve"`
BatchCancel bool `json:"batch_cancel"`
BatchDelete bool `json:"batch_delete"`
BatchResults bool `json:"batch_results"`
FileUpload bool `json:"file_upload"`
FileList bool `json:"file_list"`
FileRetrieve bool `json:"file_retrieve"`
FileDelete bool `json:"file_delete"`
FileContent bool `json:"file_content"`
ContainerCreate bool `json:"container_create"`
ContainerList bool `json:"container_list"`
ContainerRetrieve bool `json:"container_retrieve"`
ContainerDelete bool `json:"container_delete"`
ContainerFileCreate bool `json:"container_file_create"`
ContainerFileList bool `json:"container_file_list"`
ContainerFileRetrieve bool `json:"container_file_retrieve"`
ContainerFileContent bool `json:"container_file_content"`
ContainerFileDelete bool `json:"container_file_delete"`
Passthrough bool `json:"passthrough"`
PassthroughStream bool `json:"passthrough_stream"`
WebSocketResponses bool `json:"websocket_responses"`
Realtime bool `json:"realtime"`
}
// IsOperationAllowed reports whether operation is permitted by this allow-list.
//
// A nil receiver means no restrictions are configured, so every operation is
// allowed. Otherwise the result is the boolean field that corresponds to
// operation; a RequestType with no corresponding field is rejected.
func (ar *AllowedRequests) IsOperationAllowed(operation RequestType) bool {
	if ar == nil {
		return true // no restrictions configured
	}

	// Stays false for any operation the switch does not recognise.
	allowed := false
	switch operation {
	case ListModelsRequest:
		allowed = ar.ListModels
	case TextCompletionRequest:
		allowed = ar.TextCompletion
	case TextCompletionStreamRequest:
		allowed = ar.TextCompletionStream
	case ChatCompletionRequest:
		allowed = ar.ChatCompletion
	case ChatCompletionStreamRequest:
		allowed = ar.ChatCompletionStream
	case ResponsesRequest:
		allowed = ar.Responses
	case ResponsesStreamRequest:
		allowed = ar.ResponsesStream
	case CountTokensRequest:
		allowed = ar.CountTokens
	case EmbeddingRequest:
		allowed = ar.Embedding
	case RerankRequest:
		allowed = ar.Rerank
	case OCRRequest:
		allowed = ar.OCR
	case SpeechRequest:
		allowed = ar.Speech
	case SpeechStreamRequest:
		allowed = ar.SpeechStream
	case TranscriptionRequest:
		allowed = ar.Transcription
	case TranscriptionStreamRequest:
		allowed = ar.TranscriptionStream
	case ImageGenerationRequest:
		allowed = ar.ImageGeneration
	case ImageGenerationStreamRequest:
		allowed = ar.ImageGenerationStream
	case ImageEditRequest:
		allowed = ar.ImageEdit
	case ImageEditStreamRequest:
		allowed = ar.ImageEditStream
	case ImageVariationRequest:
		allowed = ar.ImageVariation
	case VideoGenerationRequest:
		allowed = ar.VideoGeneration
	case VideoRetrieveRequest:
		allowed = ar.VideoRetrieve
	case VideoDownloadRequest:
		allowed = ar.VideoDownload
	case VideoDeleteRequest:
		allowed = ar.VideoDelete
	case VideoListRequest:
		allowed = ar.VideoList
	case VideoRemixRequest:
		allowed = ar.VideoRemix
	case BatchCreateRequest:
		allowed = ar.BatchCreate
	case BatchListRequest:
		allowed = ar.BatchList
	case BatchRetrieveRequest:
		allowed = ar.BatchRetrieve
	case BatchCancelRequest:
		allowed = ar.BatchCancel
	case BatchDeleteRequest:
		allowed = ar.BatchDelete
	case BatchResultsRequest:
		allowed = ar.BatchResults
	case FileUploadRequest:
		allowed = ar.FileUpload
	case FileListRequest:
		allowed = ar.FileList
	case FileRetrieveRequest:
		allowed = ar.FileRetrieve
	case FileDeleteRequest:
		allowed = ar.FileDelete
	case FileContentRequest:
		allowed = ar.FileContent
	case ContainerCreateRequest:
		allowed = ar.ContainerCreate
	case ContainerListRequest:
		allowed = ar.ContainerList
	case ContainerRetrieveRequest:
		allowed = ar.ContainerRetrieve
	case ContainerDeleteRequest:
		allowed = ar.ContainerDelete
	case ContainerFileCreateRequest:
		allowed = ar.ContainerFileCreate
	case ContainerFileListRequest:
		allowed = ar.ContainerFileList
	case ContainerFileRetrieveRequest:
		allowed = ar.ContainerFileRetrieve
	case ContainerFileContentRequest:
		allowed = ar.ContainerFileContent
	case ContainerFileDeleteRequest:
		allowed = ar.ContainerFileDelete
	case PassthroughRequest:
		allowed = ar.Passthrough
	case PassthroughStreamRequest:
		allowed = ar.PassthroughStream
	case WebSocketResponsesRequest:
		allowed = ar.WebSocketResponses
	case RealtimeRequest:
		allowed = ar.Realtime
	}
	return allowed
}
// CustomProviderConfig describes a custom provider layered on top of a base provider type:
// whether it works without a key, which operations it allows, and any per-request-type
// path overrides.
type CustomProviderConfig struct {
CustomProviderKey string `json:"-"` // Custom provider key, internally set by Bifrost
IsKeyLess bool `json:"is_key_less"` // Whether the custom provider works without a key (not allowed for Bedrock). NOTE(review): the earlier wording said "requires a key", which reads inverted relative to the field name — confirm polarity against callers
BaseProviderType ModelProvider `json:"base_provider_type"` // Base provider type
AllowedRequests *AllowedRequests `json:"allowed_requests,omitempty"` // Allowed requests for the custom provider; nil means all operations are allowed
RequestPathOverrides map[RequestType]string `json:"request_path_overrides,omitempty"` // Mapping of request type to its custom path which will override the default path of the provider (not allowed for Bedrock)
}
// IsOperationAllowed reports whether operation is permitted for this custom provider.
//
// When the receiver is nil or carries no AllowedRequests, nothing is restricted
// and the result is true. Otherwise the decision is delegated to
// AllowedRequests.IsOperationAllowed.
func (cpc *CustomProviderConfig) IsOperationAllowed(operation RequestType) bool {
	if cpc != nil && cpc.AllowedRequests != nil {
		return cpc.AllowedRequests.IsOperationAllowed(operation)
	}
	return true // no restrictions configured
}
// ProviderConfig represents the complete configuration for a provider.
// An array of ProviderConfig needs to be provided in GetConfigForProvider
// in your account interface implementation.
//
// Call CheckAndSetDefaults to fill unset network and concurrency fields with the
// package defaults before use.
type ProviderConfig struct {
NetworkConfig NetworkConfig `json:"network_config"` // Network configuration
ConcurrencyAndBufferSize ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` // Concurrency settings
// Logger instance, can be provided by the user or bifrost default logger is used if not provided
Logger Logger `json:"-"`
ProxyConfig *ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration
SendBackRawRequest bool `json:"send_back_raw_request"` // Send raw request back in the bifrost response (default: false)
SendBackRawResponse bool `json:"send_back_raw_response"` // Send raw response back in the bifrost response (default: false)
StoreRawRequestResponse bool `json:"store_raw_request_response"` // Capture raw request/response for internal logging only; strip from API responses returned to clients (default: false)
CustomProviderConfig *CustomProviderConfig `json:"custom_provider_config,omitempty"` // Optional custom-provider settings
OpenAIConfig *OpenAIConfig `json:"openai_config,omitempty"` // Optional OpenAI-specific settings
}
// OpenAIConfig holds OpenAI-specific provider configuration.
// It is attached to a provider through ProviderConfig.OpenAIConfig.
type OpenAIConfig struct {
DisableStore bool `json:"disable_store"` // When true, forces store=false on all outgoing OpenAI requests (default: false)
}
// CheckAndSetDefaults normalizes the config in place.
//
// Unset (zero) concurrency, buffer-size, retry and backoff values are replaced
// with the package defaults; non-positive request/stream timeouts are replaced
// as well. MaxConnsPerHost falls back to its default when non-positive and is
// capped at MaxConnsPerHostUpperBound. Finally, the ExtraHeaders and
// BetaHeaderOverrides maps are replaced by fresh copies so the config no
// longer shares them with whoever supplied the originals.
func (config *ProviderConfig) CheckAndSetDefaults() {
	sizing := &config.ConcurrencyAndBufferSize
	if sizing.Concurrency == 0 {
		sizing.Concurrency = DefaultConcurrency
	}
	if sizing.BufferSize == 0 {
		sizing.BufferSize = DefaultBufferSize
	}

	network := &config.NetworkConfig
	if network.DefaultRequestTimeoutInSeconds <= 0 {
		network.DefaultRequestTimeoutInSeconds = DefaultRequestTimeoutInSeconds
	}
	if network.MaxRetries == 0 {
		network.MaxRetries = DefaultMaxRetries
	}
	if network.RetryBackoffInitial == 0 {
		network.RetryBackoffInitial = DefaultRetryBackoffInitial
	}
	if network.RetryBackoffMax == 0 {
		network.RetryBackoffMax = DefaultRetryBackoffMax
	}
	if network.StreamIdleTimeoutInSeconds <= 0 {
		network.StreamIdleTimeoutInSeconds = DefaultStreamIdleTimeoutInSeconds
	}

	// Connection limit: default when unset/invalid, clamp when too large.
	switch {
	case network.MaxConnsPerHost <= 0:
		network.MaxConnsPerHost = DefaultMaxConnsPerHost
	case network.MaxConnsPerHost > MaxConnsPerHostUpperBound:
		network.MaxConnsPerHost = MaxConnsPerHostUpperBound
	}

	// Detach the header maps from the caller's instances to prevent data races.
	// Nil maps stay nil.
	if original := network.ExtraHeaders; original != nil {
		detached := make(map[string]string, len(original))
		maps.Copy(detached, original)
		network.ExtraHeaders = detached
	}
	if original := network.BetaHeaderOverrides; original != nil {
		detached := make(map[string]bool, len(original))
		maps.Copy(detached, original)
		network.BetaHeaderOverrides = detached
	}
}
// PostHookRunner is the function type that the streaming methods of Provider
// accept as their postHookRunner argument. It takes the request context together
// with a response/error pair and returns a response/error pair.
// NOTE(review): judging by the name it runs plugin post-hooks over the pair —
// confirm against the implementation that supplies it.
type PostHookRunner func(ctx *BifrostContext, result *BifrostResponse, err *BifrostError) (*BifrostResponse, *BifrostError)
// Provider defines the interface for AI model providers.
//
// Methods come in two shapes. Non-streaming methods return a typed response and
// a *BifrostError. The *Stream variants return a channel of *BifrostStreamChunk
// and a *BifrostError, and additionally take a postHookRunner and a
// postHookSpanFinalizer (described on TextCompletionStream).
//
// Most methods receive a single Key. ListModels and the batch, file and
// container operations other than create/upload receive a []Key instead.
type Provider interface {
// GetProviderKey returns the provider's identifier
GetProviderKey() ModelProvider
// ListModels performs a list models request
ListModels(ctx *BifrostContext, keys []Key, request *BifrostListModelsRequest) (*BifrostListModelsResponse, *BifrostError)
// TextCompletion performs a text completion request
TextCompletion(ctx *BifrostContext, key Key, request *BifrostTextCompletionRequest) (*BifrostTextCompletionResponse, *BifrostError)
// TextCompletionStream performs a text completion stream request.
// postHookSpanFinalizer is invoked by the provider's stream goroutine on stream completion
// (or on its panic-recovery defer) to finalize aggregated post-hook spans and release the
// per-attempt plugin pipeline. Pass nil if the caller does not need finalization.
TextCompletionStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostTextCompletionRequest) (chan *BifrostStreamChunk, *BifrostError)
// ChatCompletion performs a chat completion request
ChatCompletion(ctx *BifrostContext, key Key, request *BifrostChatRequest) (*BifrostChatResponse, *BifrostError)
// ChatCompletionStream performs a chat completion stream request
ChatCompletionStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostChatRequest) (chan *BifrostStreamChunk, *BifrostError)
// Responses performs a completion request using the Responses API (uses chat completion request internally for non-openai providers)
Responses(ctx *BifrostContext, key Key, request *BifrostResponsesRequest) (*BifrostResponsesResponse, *BifrostError)
// ResponsesStream performs a completion request using the Responses API stream (uses chat completion stream request internally for non-openai providers)
ResponsesStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostResponsesRequest) (chan *BifrostStreamChunk, *BifrostError)
// CountTokens performs a count tokens request; its input is a Responses API request
CountTokens(ctx *BifrostContext, key Key, request *BifrostResponsesRequest) (*BifrostCountTokensResponse, *BifrostError)
// Embedding performs an embedding request
Embedding(ctx *BifrostContext, key Key, request *BifrostEmbeddingRequest) (*BifrostEmbeddingResponse, *BifrostError)
// Rerank performs a rerank request to reorder documents by relevance to a query
Rerank(ctx *BifrostContext, key Key, request *BifrostRerankRequest) (*BifrostRerankResponse, *BifrostError)
// OCR performs an optical character recognition request on a document
OCR(ctx *BifrostContext, key Key, request *BifrostOCRRequest) (*BifrostOCRResponse, *BifrostError)
// Speech performs a text to speech request
Speech(ctx *BifrostContext, key Key, request *BifrostSpeechRequest) (*BifrostSpeechResponse, *BifrostError)
// SpeechStream performs a text to speech stream request
SpeechStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostSpeechRequest) (chan *BifrostStreamChunk, *BifrostError)
// Transcription performs a transcription request
Transcription(ctx *BifrostContext, key Key, request *BifrostTranscriptionRequest) (*BifrostTranscriptionResponse, *BifrostError)
// TranscriptionStream performs a transcription stream request
TranscriptionStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostTranscriptionRequest) (chan *BifrostStreamChunk, *BifrostError)
// ImageGeneration performs an image generation request
ImageGeneration(ctx *BifrostContext, key Key, request *BifrostImageGenerationRequest) (
*BifrostImageGenerationResponse, *BifrostError)
// ImageGenerationStream performs an image generation stream request
ImageGenerationStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key,
request *BifrostImageGenerationRequest) (chan *BifrostStreamChunk, *BifrostError)
// ImageEdit performs an image edit request (the result uses the image generation response type)
ImageEdit(ctx *BifrostContext, key Key, request *BifrostImageEditRequest) (*BifrostImageGenerationResponse, *BifrostError)
// ImageEditStream performs an image edit stream request
ImageEditStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key,
request *BifrostImageEditRequest) (chan *BifrostStreamChunk, *BifrostError)
// ImageVariation performs an image variation request (the result uses the image generation response type)
ImageVariation(ctx *BifrostContext, key Key, request *BifrostImageVariationRequest) (*BifrostImageGenerationResponse, *BifrostError)
// VideoGeneration performs a video generation request
VideoGeneration(ctx *BifrostContext, key Key, request *BifrostVideoGenerationRequest) (*BifrostVideoGenerationResponse, *BifrostError)
// VideoRetrieve retrieves a video from the provider
VideoRetrieve(ctx *BifrostContext, key Key, request *BifrostVideoRetrieveRequest) (*BifrostVideoGenerationResponse, *BifrostError)
// VideoDownload downloads a video from the provider
VideoDownload(ctx *BifrostContext, key Key, request *BifrostVideoDownloadRequest) (*BifrostVideoDownloadResponse, *BifrostError)
// VideoDelete deletes a video from the provider
VideoDelete(ctx *BifrostContext, key Key, request *BifrostVideoDeleteRequest) (*BifrostVideoDeleteResponse, *BifrostError)
// VideoList lists videos from the provider
VideoList(ctx *BifrostContext, key Key, request *BifrostVideoListRequest) (*BifrostVideoListResponse, *BifrostError)
// VideoRemix remixes a video from the provider
VideoRemix(ctx *BifrostContext, key Key, request *BifrostVideoRemixRequest) (*BifrostVideoGenerationResponse, *BifrostError)
// BatchCreate creates a new batch job for asynchronous processing
BatchCreate(ctx *BifrostContext, key Key, request *BifrostBatchCreateRequest) (*BifrostBatchCreateResponse, *BifrostError)
// BatchList lists batch jobs
BatchList(ctx *BifrostContext, keys []Key, request *BifrostBatchListRequest) (*BifrostBatchListResponse, *BifrostError)
// BatchRetrieve retrieves a specific batch job
BatchRetrieve(ctx *BifrostContext, keys []Key, request *BifrostBatchRetrieveRequest) (*BifrostBatchRetrieveResponse, *BifrostError)
// BatchCancel cancels a batch job
BatchCancel(ctx *BifrostContext, keys []Key, request *BifrostBatchCancelRequest) (*BifrostBatchCancelResponse, *BifrostError)
// BatchDelete deletes a batch job
BatchDelete(ctx *BifrostContext, keys []Key, request *BifrostBatchDeleteRequest) (*BifrostBatchDeleteResponse, *BifrostError)
// BatchResults retrieves results from a completed batch job
BatchResults(ctx *BifrostContext, keys []Key, request *BifrostBatchResultsRequest) (*BifrostBatchResultsResponse, *BifrostError)
// FileUpload uploads a file to the provider
FileUpload(ctx *BifrostContext, key Key, request *BifrostFileUploadRequest) (*BifrostFileUploadResponse, *BifrostError)
// FileList lists files from the provider
FileList(ctx *BifrostContext, keys []Key, request *BifrostFileListRequest) (*BifrostFileListResponse, *BifrostError)
// FileRetrieve retrieves file metadata from the provider
FileRetrieve(ctx *BifrostContext, keys []Key, request *BifrostFileRetrieveRequest) (*BifrostFileRetrieveResponse, *BifrostError)
// FileDelete deletes a file from the provider
FileDelete(ctx *BifrostContext, keys []Key, request *BifrostFileDeleteRequest) (*BifrostFileDeleteResponse, *BifrostError)
// FileContent downloads file content from the provider
FileContent(ctx *BifrostContext, keys []Key, request *BifrostFileContentRequest) (*BifrostFileContentResponse, *BifrostError)
// ContainerCreate creates a new container
ContainerCreate(ctx *BifrostContext, key Key, request *BifrostContainerCreateRequest) (*BifrostContainerCreateResponse, *BifrostError)
// ContainerList lists containers
ContainerList(ctx *BifrostContext, keys []Key, request *BifrostContainerListRequest) (*BifrostContainerListResponse, *BifrostError)
// ContainerRetrieve retrieves a specific container
ContainerRetrieve(ctx *BifrostContext, keys []Key, request *BifrostContainerRetrieveRequest) (*BifrostContainerRetrieveResponse, *BifrostError)
// ContainerDelete deletes a container
ContainerDelete(ctx *BifrostContext, keys []Key, request *BifrostContainerDeleteRequest) (*BifrostContainerDeleteResponse, *BifrostError)
// ContainerFileCreate creates a file in a container
ContainerFileCreate(ctx *BifrostContext, key Key, request *BifrostContainerFileCreateRequest) (*BifrostContainerFileCreateResponse, *BifrostError)
// ContainerFileList lists files in a container
ContainerFileList(ctx *BifrostContext, keys []Key, request *BifrostContainerFileListRequest) (*BifrostContainerFileListResponse, *BifrostError)
// ContainerFileRetrieve retrieves a file from a container
ContainerFileRetrieve(ctx *BifrostContext, keys []Key, request *BifrostContainerFileRetrieveRequest) (*BifrostContainerFileRetrieveResponse, *BifrostError)
// ContainerFileContent retrieves the content of a file from a container
ContainerFileContent(ctx *BifrostContext, keys []Key, request *BifrostContainerFileContentRequest) (*BifrostContainerFileContentResponse, *BifrostError)
// ContainerFileDelete deletes a file from a container
ContainerFileDelete(ctx *BifrostContext, keys []Key, request *BifrostContainerFileDeleteRequest) (*BifrostContainerFileDeleteResponse, *BifrostError)
// Passthrough executes a non-streaming passthrough; body is fully buffered.
Passthrough(ctx *BifrostContext, key Key, req *BifrostPassthroughRequest) (*BifrostPassthroughResponse, *BifrostError)
// PassthroughStream executes a streaming passthrough, forwarding raw response bytes as BifrostStreamChunks.
PassthroughStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, req *BifrostPassthroughRequest) (chan *BifrostStreamChunk, *BifrostError)
}
// WebSocketCapableProvider is an optional interface that providers can implement
// to indicate support for the OpenAI Responses API WebSocket Mode.
// Checked via type assertion: provider.(WebSocketCapableProvider).
// Providers that implement this interface will have native WS upstream connections
// instead of the HTTP bridge fallback for Responses WS mode.
//
// Implementing the interface is not sufficient on its own: SupportsWebSocketMode
// is a separate runtime answer that implementations may return false from.
type WebSocketCapableProvider interface {
// SupportsWebSocketMode returns true if the provider supports the Responses API WebSocket Mode.
SupportsWebSocketMode() bool
// WebSocketResponsesURL returns the WebSocket URL for the Responses API.
// The key is passed in so the URL can depend on it.
WebSocketResponsesURL(key Key) string
// WebSocketHeaders returns the headers required for the upstream WebSocket connection.
// The key is passed in so the headers can depend on it.
WebSocketHeaders(key Key) map[string]string
}

310
core/schemas/realtime.go Normal file
View File

@@ -0,0 +1,310 @@
package schemas
import "encoding/json"
// RealtimeEventType represents the type of a Bifrost unified Realtime event.
// Its string value is what BifrostRealtimeEvent carries in its "type" JSON field;
// the known values are the RTEvent* constants.
type RealtimeEventType string
// Client-to-server event types (sent by the client through Bifrost)
const (
RTEventSessionUpdate RealtimeEventType = "session.update"
// RTEventConversationItemCreate is the only client-to-server type that
// IsRealtimeConversationItemEventType treats as carrying a conversation item.
RTEventConversationItemCreate RealtimeEventType = "conversation.item.create"
RTEventConversationItemDelete RealtimeEventType = "conversation.item.delete"
RTEventResponseCreate RealtimeEventType = "response.create"
RTEventResponseCancel RealtimeEventType = "response.cancel"
RTEventInputAudioAppend RealtimeEventType = "input_audio_buffer.append"
RTEventInputAudioCommit RealtimeEventType = "input_audio_buffer.commit"
RTEventInputAudioClear RealtimeEventType = "input_audio_buffer.clear"
)
// Server-to-client event types (received from the provider, forwarded to client)
const (
// Session and conversation lifecycle.
RTEventSessionCreated RealtimeEventType = "session.created"
RTEventSessionUpdated RealtimeEventType = "session.updated"
RTEventConversationCreated RealtimeEventType = "conversation.created"
// Conversation item events; all four are recognized by
// IsRealtimeConversationItemEventType.
RTEventConversationItemAdded RealtimeEventType = "conversation.item.added"
RTEventConversationItemCreated RealtimeEventType = "conversation.item.created"
RTEventConversationItemRetrieved RealtimeEventType = "conversation.item.retrieved"
RTEventConversationItemDone RealtimeEventType = "conversation.item.done"
// Response lifecycle and incremental output.
RTEventResponseCreated RealtimeEventType = "response.created"
RTEventResponseDone RealtimeEventType = "response.done"
RTEventResponseTextDelta RealtimeEventType = "response.text.delta"
RTEventResponseTextDone RealtimeEventType = "response.text.done"
RTEventResponseAudioDelta RealtimeEventType = "response.audio.delta"
RTEventResponseAudioDone RealtimeEventType = "response.audio.done"
RTEventResponseAudioTransDelta RealtimeEventType = "response.audio_transcript.delta"
RTEventResponseAudioTransDone RealtimeEventType = "response.audio_transcript.done"
RTEventResponseOutputItemAdded RealtimeEventType = "response.output_item.added"
RTEventResponseOutputItemDone RealtimeEventType = "response.output_item.done"
RTEventResponseContentPartAdded RealtimeEventType = "response.content_part.added"
RTEventResponseContentPartDone RealtimeEventType = "response.content_part.done"
RTEventRateLimitsUpdated RealtimeEventType = "rate_limits.updated"
// Input-audio transcription. The "completed" type is the one matched by
// IsRealtimeInputTranscriptEvent.
RTEventInputAudioTransCompleted RealtimeEventType = "conversation.item.input_audio_transcription.completed"
RTEventInputAudioTransDelta RealtimeEventType = "conversation.item.input_audio_transcription.delta"
RTEventInputAudioTransFailed RealtimeEventType = "conversation.item.input_audio_transcription.failed"
// Input-audio buffer state.
RTEventInputAudioBufferCommitted RealtimeEventType = "input_audio_buffer.committed"
RTEventInputAudioBufferCleared RealtimeEventType = "input_audio_buffer.cleared"
RTEventInputAudioSpeechStarted RealtimeEventType = "input_audio_buffer.speech_started"
RTEventInputAudioSpeechStopped RealtimeEventType = "input_audio_buffer.speech_stopped"
RTEventError RealtimeEventType = "error"
)
// IsRealtimeConversationItemEventType reports whether eventType is one of the
// conversation-item events (create, added, created, retrieved, done), i.e. an
// event that carries a canonical conversation item payload after provider
// translation. Every other event type yields false.
func IsRealtimeConversationItemEventType(eventType RealtimeEventType) bool {
	return eventType == RTEventConversationItemCreate ||
		eventType == RTEventConversationItemAdded ||
		eventType == RTEventConversationItemCreated ||
		eventType == RTEventConversationItemRetrieved ||
		eventType == RTEventConversationItemDone
}
// IsRealtimeUserInputEvent reports whether event is a conversation-item event
// (per IsRealtimeConversationItemEventType) whose item has the role "user" in
// the canonical Bifrost realtime schema. A nil event, or an event without an
// item, is never a user input event.
func IsRealtimeUserInputEvent(event *BifrostRealtimeEvent) bool {
	if event == nil || event.Item == nil {
		return false
	}
	if event.Item.Role != "user" {
		return false
	}
	return IsRealtimeConversationItemEventType(event.Type)
}
// IsRealtimeToolOutputEvent reports whether event is a conversation-item event
// (per IsRealtimeConversationItemEventType) whose item has the type
// "function_call_output" in the canonical Bifrost realtime schema. A nil
// event, or an event without an item, is never a tool output event.
func IsRealtimeToolOutputEvent(event *BifrostRealtimeEvent) bool {
	if event == nil || event.Item == nil {
		return false
	}
	if event.Item.Type != "function_call_output" {
		return false
	}
	return IsRealtimeConversationItemEventType(event.Type)
}
// IsRealtimeInputTranscriptEvent reports whether event is the completed
// input-audio transcription event (RTEventInputAudioTransCompleted) of the
// canonical Bifrost realtime schema. Delta/failed transcription events and a
// nil event yield false.
func IsRealtimeInputTranscriptEvent(event *BifrostRealtimeEvent) bool {
	if event == nil {
		return false
	}
	return event.Type == RTEventInputAudioTransCompleted
}
// BifrostRealtimeEvent is the unified Bifrost envelope for all Realtime events.
// Provider converters translate between this format and the provider-native protocol.
//
// Every field except Type is optional and omitted from JSON when empty.
type BifrostRealtimeEvent struct {
// Type identifies the event; see the RTEvent* constants.
Type RealtimeEventType `json:"type"`
EventID string `json:"event_id,omitempty"`
// Typed payloads; which ones are set depends on the event.
Session *RealtimeSession `json:"session,omitempty"`
Item *RealtimeItem `json:"item,omitempty"`
Delta *RealtimeDelta `json:"delta,omitempty"`
Audio []byte `json:"audio,omitempty"`
Error *RealtimeError `json:"error,omitempty"`
// ExtraParams preserves provider-specific top-level event fields that are not
// promoted into the common Bifrost schema.
ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"`
// RawData preserves the original provider event for pass-through or debugging.
// ParseRealtimeEvent does not populate it.
RawData json.RawMessage `json:"raw_data,omitempty"`
}
// RealtimeSession describes session configuration for the Realtime connection.
// All fields are optional and omitted from JSON when empty.
type RealtimeSession struct {
ID string `json:"id,omitempty"`
Model string `json:"model,omitempty"`
Modalities []string `json:"modalities,omitempty"`
Instructions string `json:"instructions,omitempty"`
Voice string `json:"voice,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
// MaxOutputTokens, TurnDetection and Tools are kept as raw JSON rather than
// decoded into typed values.
MaxOutputTokens json.RawMessage `json:"max_output_tokens,omitempty"`
TurnDetection json.RawMessage `json:"turn_detection,omitempty"`
InputAudioFormat string `json:"input_audio_format,omitempty"`
OutputAudioType string `json:"output_audio_type,omitempty"`
Tools json.RawMessage `json:"tools,omitempty"`
// ExtraParams holds session keys outside the fields above; ParseRealtimeEvent
// fills it with whatever unknown keys the session object contained.
ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"`
}
// RealtimeItem represents a conversation item in the Realtime protocol.
// All fields are optional and omitted from JSON when empty.
type RealtimeItem struct {
ID string `json:"id,omitempty"`
// Type is compared against "function_call_output" by IsRealtimeToolOutputEvent.
Type string `json:"type,omitempty"`
// Role is compared against "user" by IsRealtimeUserInputEvent.
Role string `json:"role,omitempty"`
Status string `json:"status,omitempty"`
// Content is kept as raw JSON rather than decoded into a typed value.
Content json.RawMessage `json:"content,omitempty"`
Name string `json:"name,omitempty"`
CallID string `json:"call_id,omitempty"`
Arguments string `json:"arguments,omitempty"`
Output string `json:"output,omitempty"`
// ExtraParams holds item keys outside the fields above; ParseRealtimeEvent
// fills it with whatever unknown keys the item object contained.
ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"`
}
// RealtimeDelta carries incremental content for streaming events.
// All fields are optional and omitted from JSON when empty.
type RealtimeDelta struct {
Text string `json:"text,omitempty"`
Audio string `json:"audio,omitempty"`
Transcript string `json:"transcript,omitempty"`
ItemID string `json:"item_id,omitempty"`
// The indices are pointers so that an explicit 0 is still serialized; only a
// nil pointer is dropped by omitempty.
OutputIdx *int `json:"output_index,omitempty"`
ContentIdx *int `json:"content_index,omitempty"`
ResponseID string `json:"response_id,omitempty"`
}
// RealtimeError describes an error from the Realtime API.
// All fields are optional and omitted from JSON when empty.
type RealtimeError struct {
Type string `json:"type,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
Param string `json:"param,omitempty"`
// ExtraParams holds error keys outside the fields above; ParseRealtimeEvent
// fills it with whatever unknown keys the error object contained.
ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"`
}
// RealtimeSessionEndpointType identifies the public ephemeral-token endpoint
// shape the client called so providers can preserve versioned behavior.
// It is passed to RealtimeSessionProvider.CreateRealtimeClientSecret and stored
// on RealtimeSessionRoute.
type RealtimeSessionEndpointType string
// Known endpoint shapes.
const (
RealtimeSessionEndpointClientSecrets RealtimeSessionEndpointType = "client_secrets"
RealtimeSessionEndpointSessions RealtimeSessionEndpointType = "sessions"
)
// RealtimeSessionRoute describes a provider-registered public route for
// ephemeral-token creation.
type RealtimeSessionRoute struct {
// Path is the route's path.
Path string
// EndpointType is the endpoint shape this route exposes.
EndpointType RealtimeSessionEndpointType
// DefaultProvider is the provider associated with the route.
// NOTE(review): presumably used when the request does not name a provider — confirm.
DefaultProvider ModelProvider
}
// RealtimeProvider is an optional interface that providers can implement to
// indicate support for bidirectional Realtime API (audio/text streaming).
// Checked via type assertion: provider.(RealtimeProvider).
type RealtimeProvider interface {
// SupportsRealtimeAPI reports whether the provider supports the Realtime API.
SupportsRealtimeAPI() bool
// RealtimeWebSocketURL returns the WebSocket URL to use for the given key and model.
RealtimeWebSocketURL(key Key, model string) string
// RealtimeHeaders returns headers for the given key.
// NOTE(review): presumably the headers for the upstream realtime connection — confirm.
RealtimeHeaders(key Key) map[string]string
// SupportsRealtimeWebRTC reports whether the provider supports WebRTC SDP exchange.
SupportsRealtimeWebRTC() bool
// ExchangeRealtimeWebRTCSDP performs the provider-specific SDP signaling exchange.
// The provider owns the HTTP specifics (URL, headers, body format).
// session may be nil if the signaling format doesn't include session config.
ExchangeRealtimeWebRTCSDP(ctx *BifrostContext, key Key, model string, sdp string, session json.RawMessage) (string, *BifrostError)
// ToBifrostRealtimeEvent converts a provider-native event (raw JSON) into the
// unified BifrostRealtimeEvent.
ToBifrostRealtimeEvent(providerEvent json.RawMessage) (*BifrostRealtimeEvent, error)
// ToProviderRealtimeEvent converts a unified BifrostRealtimeEvent into the
// provider-native event as raw JSON.
ToProviderRealtimeEvent(bifrostEvent *BifrostRealtimeEvent) (json.RawMessage, error)
// ShouldStartRealtimeTurn reports whether the canonical client-side event
// should start pre-hooks. Providers without an explicit turn-start signal
// return false and rely on finalize-time fallback hooks.
ShouldStartRealtimeTurn(event *BifrostRealtimeEvent) bool
// RealtimeTurnFinalEvent returns the canonical provider event that completes
// a turn and should trigger post-hooks.
RealtimeTurnFinalEvent() RealtimeEventType
// RealtimeWebRTCDataChannelLabel returns a WebRTC data channel label.
// NOTE(review): presumably the label of the channel carrying realtime events — confirm.
RealtimeWebRTCDataChannelLabel() string
// RealtimeWebSocketSubprotocol returns a WebSocket subprotocol name.
// NOTE(review): presumably the subprotocol negotiated on the realtime socket — confirm.
RealtimeWebSocketSubprotocol() string
// ShouldForwardRealtimeEvent reports whether the given event should be forwarded.
// NOTE(review): the forwarding direction is not visible here — confirm with callers.
ShouldForwardRealtimeEvent(event *BifrostRealtimeEvent) bool
// ShouldAccumulateRealtimeOutput reports whether events of the given type
// should be accumulated as realtime output.
// NOTE(review): what the accumulated output feeds is not visible here — confirm with callers.
ShouldAccumulateRealtimeOutput(eventType RealtimeEventType) bool
}
// RealtimeLegacyWebRTCProvider is an optional interface for providers that
// support the beta WebRTC handshake (e.g., OpenAI's /v1/realtime).
// Only checked for legacy integration routes via type assertion.
// Takes SDP offer + optional session JSON, same as ExchangeRealtimeWebRTCSDP
// but targets the provider's legacy/beta endpoint.
type RealtimeLegacyWebRTCProvider interface {
// ExchangeLegacyRealtimeWebRTCSDP performs the legacy SDP exchange.
// Note the parameter order differs from RealtimeProvider.ExchangeRealtimeWebRTCSDP:
// here model is the last argument, after sdp and session.
ExchangeLegacyRealtimeWebRTCSDP(ctx *BifrostContext, key Key, sdp string, session json.RawMessage, model string) (string, *BifrostError)
}
// RealtimeUsageExtractor lets providers parse terminal-turn usage/output from
// their native wire payloads without coupling handlers to a specific protocol.
type RealtimeUsageExtractor interface {
// ExtractRealtimeTurnUsage derives usage from the raw bytes of a terminal turn event.
// NOTE(review): presumably returns nil when the payload carries no usage — confirm.
ExtractRealtimeTurnUsage(terminalEventRaw []byte) *BifrostLLMUsage
// ExtractRealtimeTurnOutput derives the turn's output message from the raw bytes
// of a terminal turn event.
// NOTE(review): presumably returns nil when the payload carries no output — confirm.
ExtractRealtimeTurnOutput(terminalEventRaw []byte) *ChatMessage
}
// RealtimeSessionProvider is an optional interface for providers that can mint
// short-lived client secrets for browser/client-side Realtime connections.
// Checked via type assertion: provider.(RealtimeSessionProvider).
type RealtimeSessionProvider interface {
// CreateRealtimeClientSecret creates a client secret from the raw request body.
// endpointType tells the provider which public endpoint shape the client called.
// The result is returned as a passthrough response.
CreateRealtimeClientSecret(ctx *BifrostContext, key Key, endpointType RealtimeSessionEndpointType, rawRequest json.RawMessage) (*BifrostPassthroughResponse, *BifrostError)
}
// ParseRealtimeEvent decodes a client/provider realtime event while preserving
// unknown fields for provider-specific round-tripping.
//
// The payload is decoded twice: once into the typed fields of
// BifrostRealtimeEvent, and once into a generic key/raw-value map. Keys from
// the map that are not part of the known top-level schema end up in
// event.ExtraParams. The same treatment is applied one level down to the
// "session", "item" and "error" objects, whose unknown keys are stored in the
// corresponding nested ExtraParams. ExtraParams fields are only assigned when
// at least one unknown key was found. RawData is not populated.
//
// An error is returned only when one of the two top-level decodes fails; a
// nested object that cannot be re-decoded as a map is silently left without
// extras.
func ParseRealtimeEvent(raw []byte) (*BifrostRealtimeEvent, error) {
	// Typed view of the known top-level fields (no ExtraParams/RawData).
	type knownEventFields struct {
		Type    RealtimeEventType `json:"type"`
		EventID string            `json:"event_id,omitempty"`
		Session *RealtimeSession  `json:"session,omitempty"`
		Item    *RealtimeItem     `json:"item,omitempty"`
		Delta   *RealtimeDelta    `json:"delta,omitempty"`
		Audio   []byte            `json:"audio,omitempty"`
		Error   *RealtimeError    `json:"error,omitempty"`
	}

	var typed knownEventFields
	if err := Unmarshal(raw, &typed); err != nil {
		return nil, err
	}

	// Generic view of the same payload, used to find the keys the typed view dropped.
	var topLevel map[string]json.RawMessage
	if err := Unmarshal(raw, &topLevel); err != nil {
		return nil, err
	}

	event := &BifrostRealtimeEvent{
		Type:    typed.Type,
		EventID: typed.EventID,
		Session: typed.Session,
		Item:    typed.Item,
		Delta:   typed.Delta,
		Audio:   typed.Audio,
		Error:   typed.Error,
	}

	// Grab the nested objects before the known keys are stripped from the map.
	sessionRaw, itemRaw, errorRaw := topLevel["session"], topLevel["item"], topLevel["error"]

	for _, known := range []string{"type", "event_id", "session", "item", "delta", "audio", "error", "raw_data"} {
		delete(topLevel, known)
	}
	if len(topLevel) > 0 {
		event.ExtraParams = topLevel
	}

	// unknownKeys decodes a nested JSON object and returns the entries whose
	// keys are not listed in known. It returns nil when the object is absent,
	// is not decodable as a map, or has no unknown keys.
	unknownKeys := func(nested json.RawMessage, known ...string) map[string]json.RawMessage {
		if len(nested) == 0 {
			return nil
		}
		var fields map[string]json.RawMessage
		if Unmarshal(nested, &fields) != nil {
			return nil
		}
		for _, name := range known {
			delete(fields, name)
		}
		if len(fields) == 0 {
			return nil
		}
		return fields
	}

	if event.Session != nil {
		extras := unknownKeys(sessionRaw,
			"id", "model", "modalities", "instructions", "voice", "temperature",
			"max_output_tokens", "turn_detection", "input_audio_format", "output_audio_type", "tools",
		)
		if len(extras) > 0 {
			event.Session.ExtraParams = extras
		}
	}

	if event.Item != nil {
		extras := unknownKeys(itemRaw,
			"id", "type", "role", "status", "content", "name", "call_id", "arguments", "output",
		)
		if len(extras) > 0 {
			event.Item.ExtraParams = extras
		}
	}

	if event.Error != nil {
		extras := unknownKeys(errorRaw, "type", "code", "message", "param")
		if len(extras) > 0 {
			event.Error.ExtraParams = extras
		}
	}

	return event, nil
}

View File

@@ -0,0 +1,66 @@
package schemas
import (
"bytes"
"encoding/json"
"strings"
)
// ParseRealtimeClientSecretBody decodes a realtime client-secret request body
// into a mutable map of top-level key to raw JSON value, so unknown fields are
// preserved untouched. A body that cannot be decoded yields a 400
// "invalid_request_error" BifrostError wrapping the decode error.
func ParseRealtimeClientSecretBody(raw json.RawMessage) (map[string]json.RawMessage, *BifrostError) {
	var fields map[string]json.RawMessage
	decodeErr := Unmarshal(raw, &fields)
	if decodeErr == nil {
		return fields, nil
	}
	return nil, NewRealtimeClientSecretBodyError(400, "invalid_request_error", "invalid JSON body", decodeErr)
}
// ExtractRealtimeClientSecretModel returns the model named by a parsed
// client-secret request body, with surrounding whitespace trimmed.
//
// session.model takes precedence; when it is missing or blank the legacy
// top-level model field is used instead. A "session" entry that is absent,
// empty or JSON null is skipped. A 400 "invalid_request_error" BifrostError is
// returned when session is not an object, when either model value is not a
// string, or when no non-blank model is found in either place.
func ExtractRealtimeClientSecretModel(root map[string]json.RawMessage) (string, *BifrostError) {
	const invalidRequest = "invalid_request_error"

	sessionJSON, hasSession := root["session"]
	sessionIsNull := bytes.Equal(sessionJSON, []byte("null"))
	if hasSession && len(sessionJSON) > 0 && !sessionIsNull {
		var session map[string]json.RawMessage
		if err := Unmarshal(sessionJSON, &session); err != nil {
			return "", NewRealtimeClientSecretBodyError(400, invalidRequest, "session must be an object", err)
		}
		if modelJSON, found := session["model"]; found {
			var name string
			if err := Unmarshal(modelJSON, &name); err != nil {
				return "", NewRealtimeClientSecretBodyError(400, invalidRequest, "session.model must be a string", err)
			}
			// A blank session.model falls through to the legacy field below.
			if trimmed := strings.TrimSpace(name); trimmed != "" {
				return trimmed, nil
			}
		}
	}

	if modelJSON, found := root["model"]; found {
		var name string
		if err := Unmarshal(modelJSON, &name); err != nil {
			return "", NewRealtimeClientSecretBodyError(400, invalidRequest, "model must be a string", err)
		}
		if trimmed := strings.TrimSpace(name); trimmed != "" {
			return trimmed, nil
		}
	}

	return "", NewRealtimeClientSecretBodyError(400, invalidRequest, "session.model or model is required", nil)
}
// NewRealtimeClientSecretBodyError builds the BifrostError used when parsing or
// validating an HTTP realtime client-secret request body fails.
//
// status becomes the error's StatusCode, errorType and message fill the nested
// ErrorField, and err (which may be nil) is attached as the underlying cause.
// The result is marked as not being a Bifrost-internal error and is tagged
// with the RealtimeRequest request type.
func NewRealtimeClientSecretBodyError(status int, errorType, message string, err error) *BifrostError {
	detail := &ErrorField{
		Type:    Ptr(errorType),
		Message: message,
		Error:   err,
	}

	bodyErr := &BifrostError{
		IsBifrostError: false,
		StatusCode:     Ptr(status),
		Error:          detail,
	}
	bodyErr.ExtraFields.RequestType = RealtimeRequest
	return bodyErr
}

View File

@@ -0,0 +1,40 @@
package schemas
import (
"encoding/json"
"testing"
)
// TestExtractRealtimeClientSecretModel verifies that the model is read from
// session.model when the body nests it under "session".
func TestExtractRealtimeClientSecretModel(t *testing.T) {
	t.Parallel()

	const want = "openai/gpt-4o-realtime-preview"
	body := json.RawMessage(`{"session":{"model":"openai/gpt-4o-realtime-preview"}}`)

	root, parseErr := ParseRealtimeClientSecretBody(body)
	if parseErr != nil {
		t.Fatalf("ParseRealtimeClientSecretBody() error = %v", parseErr)
	}

	got, extractErr := ExtractRealtimeClientSecretModel(root)
	if extractErr != nil {
		t.Fatalf("ExtractRealtimeClientSecretModel() error = %v", extractErr)
	}

	if got != want {
		t.Fatalf("model = %q, want %q", got, want)
	}
}
// TestExtractRealtimeClientSecretModelFallbackTopLevel verifies that the legacy
// top-level "model" field is used when the body has no session object.
func TestExtractRealtimeClientSecretModelFallbackTopLevel(t *testing.T) {
	t.Parallel()

	const want = "gpt-4o-realtime-preview"
	body := json.RawMessage(`{"model":"gpt-4o-realtime-preview"}`)

	root, parseErr := ParseRealtimeClientSecretBody(body)
	if parseErr != nil {
		t.Fatalf("ParseRealtimeClientSecretBody() error = %v", parseErr)
	}

	got, extractErr := ExtractRealtimeClientSecretModel(root)
	if extractErr != nil {
		t.Fatalf("ExtractRealtimeClientSecretModel() error = %v", extractErr)
	}

	if got != want {
		t.Fatalf("model = %q, want %q", got, want)
	}
}

View File

@@ -0,0 +1,68 @@
package schemas
import "testing"
// TestIsRealtimeConversationItemEventType checks that each of the five
// conversation-item event types is recognized and that an unrelated type
// (response.done) is rejected. Each case runs as a parallel subtest.
func TestIsRealtimeConversationItemEventType(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name      string
		eventType RealtimeEventType
		want      bool
	}
	cases := []testCase{
		{"create", RTEventConversationItemCreate, true},
		{"added", RTEventConversationItemAdded, true},
		{"created", RTEventConversationItemCreated, true},
		{"retrieved", RTEventConversationItemRetrieved, true},
		{"done", RTEventConversationItemDone, true},
		{"response done", RTEventResponseDone, false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			got := IsRealtimeConversationItemEventType(tc.eventType)
			if got != tc.want {
				t.Fatalf("IsRealtimeConversationItemEventType(%q) = %v, want %v", tc.eventType, got, tc.want)
			}
		})
	}
}
func TestRealtimeCanonicalEventClassifiers(t *testing.T) {
t.Parallel()
userEvent := &BifrostRealtimeEvent{
Type: RTEventConversationItemAdded,
Item: &RealtimeItem{
Role: "user",
Type: "message",
},
}
if !IsRealtimeUserInputEvent(userEvent) {
t.Fatal("expected conversation.item.added user event to be classified as realtime user input")
}
if IsRealtimeToolOutputEvent(userEvent) {
t.Fatal("did not expect conversation.item.added user event to be classified as realtime tool output")
}
toolEvent := &BifrostRealtimeEvent{
Type: RTEventConversationItemRetrieved,
Item: &RealtimeItem{
Type: "function_call_output",
},
}
if !IsRealtimeToolOutputEvent(toolEvent) {
t.Fatal("expected function_call_output item to be classified as realtime tool output")
}
if IsRealtimeUserInputEvent(toolEvent) {
t.Fatal("did not expect function_call_output item to be classified as realtime user input")
}
transcriptEvent := &BifrostRealtimeEvent{Type: RTEventInputAudioTransCompleted}
if !IsRealtimeInputTranscriptEvent(transcriptEvent) {
t.Fatal("expected input audio transcription completion to be classified as transcript event")
}
if IsRealtimeInputTranscriptEvent(&BifrostRealtimeEvent{Type: RTEventInputAudioTransDelta}) {
t.Fatal("did not expect input audio transcription delta to be classified as transcript event")
}
}

49
core/schemas/rerank.go Normal file
View File

@@ -0,0 +1,49 @@
package schemas
// RerankDocument represents a document to be reranked.
// It is also the type of the optional document attached to a RerankResult.
type RerankDocument struct {
Text string `json:"text"` // Document text; always serialized
ID *string `json:"id,omitempty"` // Optional identifier; omitted from JSON when nil
Meta map[string]interface{} `json:"meta,omitempty"` // Optional free-form metadata; omitted from JSON when empty
}
// RerankParameters contains optional parameters for a rerank request.
// Every typed parameter is a pointer, so an unset value (nil) is omitted from
// JSON while an explicit zero or false is still sent.
type RerankParameters struct {
TopN *int `json:"top_n,omitempty"`
MaxTokensPerDoc *int `json:"max_tokens_per_doc,omitempty"`
Priority *int `json:"priority,omitempty"`
ReturnDocuments *bool `json:"return_documents,omitempty"`
// ExtraParams is excluded from JSON by its "-" tag.
// NOTE(review): presumably carries provider-specific parameters — confirm with callers.
ExtraParams map[string]interface{} `json:"-"`
}
// BifrostRerankRequest represents a request to rerank documents by relevance to a query.
type BifrostRerankRequest struct {
Provider ModelProvider `json:"provider"`
Model string `json:"model"`
Query string `json:"query"`
Documents []RerankDocument `json:"documents"`
Params *RerankParameters `json:"params,omitempty"`
Fallbacks []Fallback `json:"fallbacks,omitempty"`
// RawRequestBody is excluded from JSON by its "-" tag and is exposed through GetRawRequestBody.
RawRequestBody []byte `json:"-"`
}
// GetRawRequestBody returns the raw request body for the rerank request.
// A nil receiver yields nil rather than panicking, matching the nil-tolerant
// style of the other request/response helpers in this package.
func (r *BifrostRerankRequest) GetRawRequestBody() []byte {
	if r == nil {
		return nil
	}
	return r.RawRequestBody
}
// RerankResult represents a single reranked document with its relevance score.
type RerankResult struct {
// Index is serialized as "index".
// NOTE(review): presumably the position of the document in the request's Documents slice — confirm.
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
// Document is optional; a nil pointer omits the "document" key.
Document *RerankDocument `json:"document,omitempty"`
}
// BifrostRerankResponse represents the response from a rerank request.
type BifrostRerankResponse struct {
// ID is omitted from JSON when empty.
ID string `json:"id,omitempty"`
Results []RerankResult `json:"results"`
Model string `json:"model"`
// Usage is optional; a nil pointer omits the "usage" key.
Usage *BifrostLLMUsage `json:"usage,omitempty"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}

2345
core/schemas/responses.go Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

169
core/schemas/speech.go Normal file
View File

@@ -0,0 +1,169 @@
package schemas
import (
"fmt"
"unicode/utf8"
)
// BifrostSpeechRequest is the request struct for speech requests.
// Input, Params and Fallbacks are optional and omitted from JSON when unset.
type BifrostSpeechRequest struct {
Provider ModelProvider `json:"provider"`
Model string `json:"model"`
Input *SpeechInput `json:"input,omitempty"`
Params *SpeechParameters `json:"params,omitempty"`
Fallbacks []Fallback `json:"fallbacks,omitempty"`
RawRequestBody []byte `json:"-"` // set bifrost-use-raw-request-body to true in ctx to use the raw request body. Bifrost will directly send this to the downstream provider.
}
// GetRawRequestBody returns the raw request body for the speech request.
// A nil receiver yields nil rather than panicking, matching the nil checks
// performed by BackfillParams on the speech response types.
func (r *BifrostSpeechRequest) GetRawRequestBody() []byte {
	if r == nil {
		return nil
	}
	return r.RawRequestBody
}
// BifrostSpeechResponse is the response struct for speech requests.
// Audio, Usage and ExtraFields are always serialized; the alignment and
// base64 fields are omitted from JSON when unset.
type BifrostSpeechResponse struct {
Audio []byte `json:"audio"`
Usage *SpeechUsage `json:"usage"`
Alignment *SpeechAlignment `json:"alignment,omitempty"` // Character-level timing information
NormalizedAlignment *SpeechAlignment `json:"normalized_alignment,omitempty"` // Character-level timing information for normalized text
AudioBase64 *string `json:"audio_base64,omitempty"` // Base64-encoded audio (when timestamps are requested)
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// BackfillParams records the request's input length on the response usage.
// It is a no-op when the receiver, the request, or the request input is nil.
// The length is counted in runes (not bytes), and a Usage value is allocated
// when the response does not have one yet.
func (r *BifrostSpeechResponse) BackfillParams(request *BifrostSpeechRequest) {
	if r == nil || request == nil {
		return
	}
	input := request.Input
	if input == nil {
		return
	}
	runeCount := utf8.RuneCountInString(input.Input)
	if r.Usage == nil {
		r.Usage = &SpeechUsage{InputChars: runeCount}
		return
	}
	r.Usage.InputChars = runeCount
}
// SpeechAlignment represents character-level timing information for audio-text synchronization.
// NOTE(review): the three slices look like parallel arrays indexed by character position — confirm.
type SpeechAlignment struct {
CharStartTimesMs []float64 `json:"char_start_times_ms"` // Start time in milliseconds for each character
CharEndTimesMs []float64 `json:"char_end_times_ms"` // End time in milliseconds for each character
Characters []string `json:"characters"` // Characters corresponding to timing info
}
// SpeechInput represents the input for a speech request.
type SpeechInput struct {
// Input is serialized as "input"; BackfillParams counts its runes to fill SpeechUsage.InputChars.
Input string `json:"input"`
}
// SpeechParameters contains the parameters for a speech request.
// VoiceConfig is always serialized under "voice"; every other serialized
// field is omitted from JSON when unset.
type SpeechParameters struct {
// VoiceConfig is encoded by SpeechVoiceInput's custom JSON methods.
VoiceConfig *SpeechVoiceInput `json:"voice"`
Instructions string `json:"instructions,omitempty"`
ResponseFormat string `json:"response_format,omitempty"` // Default is "mp3"
Speed *float64 `json:"speed,omitempty"`
LanguageCode *string `json:"language_code,omitempty"`
PronunciationDictionaryLocators []SpeechPronunciationDictionaryLocator `json:"pronunciation_dictionary_locators,omitempty"`
EnableLogging *bool `json:"enable_logging,omitempty"`
OptimizeStreamingLatency *bool `json:"optimize_streaming_latency,omitempty"`
WithTimestamps *bool `json:"with_timestamps,omitempty"` // Returns character-level timing information
// Dynamic parameters that can be provider-specific, they are directly
// added to the request as is. Never serialized with this struct (json:"-").
ExtraParams map[string]interface{} `json:"-"`
}
// SpeechPronunciationDictionaryLocator identifies a pronunciation dictionary by ID,
// optionally pinned to a version (VersionID is omitted from JSON when nil).
type SpeechPronunciationDictionaryLocator struct {
PronunciationDictionaryID string `json:"pronunciation_dictionary_id"`
VersionID *string `json:"version_id,omitempty"`
}
// SpeechVoiceInput holds either a single voice or a multi-speaker voice configuration.
// The fields carry no JSON tags because the type's MarshalJSON/UnmarshalJSON methods
// encode it as a bare string (Voice) or a bare array (MultiVoiceConfig).
// MarshalJSON rejects values where both fields are populated.
type SpeechVoiceInput struct {
// Voice is the single-voice form.
Voice *string
// MultiVoiceConfig is the multi-speaker form.
MultiVoiceConfig []VoiceConfig
}
// VoiceConfig pairs a speaker with a voice.
// SpeechVoiceInput.UnmarshalJSON rejects entries whose Voice is empty.
type VoiceConfig struct {
Speaker string `json:"speaker"`
Voice string `json:"voice"`
}
// MarshalJSON implements custom JSON marshalling for SpeechVoiceInput.
// The value is written unwrapped: a JSON string when Voice is set, a JSON
// array when MultiVoiceConfig is non-empty, and null when neither is set.
// Having both populated is an error.
func (vi *SpeechVoiceInput) MarshalJSON() ([]byte, error) {
	hasVoice := vi.Voice != nil
	hasMulti := len(vi.MultiVoiceConfig) > 0

	switch {
	case hasVoice && hasMulti:
		// Exactly one representation may be populated at a time.
		return nil, fmt.Errorf("both Voice and MultiVoiceConfig are set; only one should be non-nil")
	case hasVoice:
		return MarshalSorted(*vi.Voice)
	case hasMulti:
		return MarshalSorted(vi.MultiVoiceConfig)
	default:
		// Nothing set: encode as null.
		return MarshalSorted(nil)
	}
}
// UnmarshalJSON implements custom JSON unmarshalling for SpeechVoiceInput.
// The payload is tried first as a bare string (stored in Voice) and then as
// an array of VoiceConfig objects (stored in MultiVoiceConfig). Array entries
// with an empty voice are rejected, and any other payload shape is an error.
func (vi *SpeechVoiceInput) UnmarshalJSON(data []byte) error {
	// Clear both fields up front so no decode path can leave stale data behind.
	vi.Voice, vi.MultiVoiceConfig = nil, nil

	// Attempt 1: a plain string.
	var single string
	if Unmarshal(data, &single) == nil {
		vi.Voice = &single
		return nil
	}

	// Attempt 2: an array of VoiceConfig objects.
	var decoded []VoiceConfig
	if Unmarshal(data, &decoded) != nil {
		return fmt.Errorf("voice field is neither a string, nor an array of VoiceConfig objects")
	}

	// Validate every entry before committing anything to the receiver.
	for i := range decoded {
		if decoded[i].Voice == "" {
			return fmt.Errorf("voice config has empty voice field")
		}
	}

	// Store a fresh, always non-nil copy of the validated entries.
	configs := make([]VoiceConfig, len(decoded))
	copy(configs, decoded)
	vi.MultiVoiceConfig = configs
	return nil
}
// SpeechStreamResponseType is the value of the "type" field on a streamed speech response.
type SpeechStreamResponseType string
// Known speech stream response types.
const (
SpeechStreamResponseTypeDelta SpeechStreamResponseType = "speech.audio.delta"
SpeechStreamResponseTypeDone SpeechStreamResponseType = "speech.audio.done"
)
// BifrostSpeechStreamResponse is a single streamed speech response.
// Type holds one of the SpeechStreamResponseType values; all four fields are
// always serialized (none use omitempty).
type BifrostSpeechStreamResponse struct {
Type SpeechStreamResponseType `json:"type"`
Audio []byte `json:"audio"`
Usage *SpeechUsage `json:"usage"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// BackfillParams copies the request's input length onto the stream response usage.
// Nothing happens when the receiver, the request, or the request input is nil.
// The length is measured in runes, and Usage is created on demand.
func (r *BifrostSpeechStreamResponse) BackfillParams(request *BifrostSpeechRequest) {
	switch {
	case r == nil, request == nil, request.Input == nil:
		// Case expressions are evaluated left to right and stop at the first
		// match, so request.Input is only read once request is known non-nil.
		return
	}

	usage := r.Usage
	if usage == nil {
		usage = &SpeechUsage{}
		r.Usage = usage
	}
	usage.InputChars = utf8.RuneCountInString(request.Input.Input)
}
// SpeechUsageInputTokenDetails breaks input tokens down into text and audio tokens.
// Zero counts are omitted from JSON.
type SpeechUsageInputTokenDetails struct {
TextTokens int `json:"text_tokens,omitempty"`
AudioTokens int `json:"audio_tokens,omitempty"`
}
// SpeechUsage holds usage counters for a speech request.
// The token counters are always serialized; InputChars and InputTokenDetails
// are omitted from JSON when zero/nil.
type SpeechUsage struct {
InputTokens int `json:"input_tokens"`
// InputChars is filled by BackfillParams with the rune count of the request input.
InputChars int `json:"input_chars,omitempty"`
InputTokenDetails *SpeechUsageInputTokenDetails `json:"input_token_details,omitempty"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}

View File

@@ -0,0 +1,146 @@
package schemas
import (
"fmt"
)
// BifrostTextCompletionRequest is the request struct for text completion requests.
// Input, Params and Fallbacks are optional and omitted from JSON when unset.
type BifrostTextCompletionRequest struct {
Provider ModelProvider `json:"provider"`
Model string `json:"model"`
Input *TextCompletionInput `json:"input,omitempty"`
Params *TextCompletionParameters `json:"params,omitempty"`
Fallbacks []Fallback `json:"fallbacks,omitempty"`
RawRequestBody []byte `json:"-"` // set bifrost-use-raw-request-body to true in ctx to use the raw request body. Bifrost will directly send this to the downstream provider.
}
// GetRawRequestBody returns the raw request body for the text completion request.
// A nil receiver yields nil rather than panicking, matching the nil check
// performed by ToBifrostChatRequest on the same type.
func (r *BifrostTextCompletionRequest) GetRawRequestBody() []byte {
	if r == nil {
		return nil
	}
	return r.RawRequestBody
}
// ToBifrostChatRequest converts a Bifrost text completion request to a Bifrost chat completion request.
// Using this method is discouraged, but it is useful for litellm fallback flows.
//
// It returns nil when the receiver or its Input is nil. Otherwise the prompt
// becomes a single user message: a string prompt is used as the message
// content directly, and a non-empty prompt array becomes one text content
// block per entry. The parameters that have a chat equivalent are copied
// over (MaxTokens maps to MaxCompletionTokens); the returned request always
// carries a non-nil Params, even when the source Params is nil.
func (r *BifrostTextCompletionRequest) ToBifrostChatRequest() *BifrostChatRequest {
	if r == nil || r.Input == nil {
		return nil
	}

	message := ChatMessage{Role: ChatMessageRoleUser}
	switch {
	case r.Input.PromptStr != nil:
		message.Content = &ChatMessageContent{
			ContentStr: r.Input.PromptStr,
		}
	case len(r.Input.PromptArray) > 0:
		blocks := make([]ChatContentBlock, 0, len(r.Input.PromptArray))
		for _, prompt := range r.Input.PromptArray {
			// Take the address of a per-iteration copy. Before Go 1.22 the
			// range variable is shared across iterations, so &prompt would
			// make every block point at the last prompt.
			text := prompt
			blocks = append(blocks, ChatContentBlock{
				Type: ChatContentBlockTypeText,
				Text: &text,
			})
		}
		message.Content = &ChatMessageContent{
			ContentBlocks: blocks,
		}
	}

	params := ChatParameters{}
	if r.Params != nil {
		params.MaxCompletionTokens = r.Params.MaxTokens
		params.Temperature = r.Params.Temperature
		params.TopP = r.Params.TopP
		params.Stop = r.Params.Stop
		params.ExtraParams = r.Params.ExtraParams
		params.StreamOptions = r.Params.StreamOptions
		params.User = r.Params.User
		params.FrequencyPenalty = r.Params.FrequencyPenalty
		params.LogitBias = r.Params.LogitBias
		params.PresencePenalty = r.Params.PresencePenalty
		params.Seed = r.Params.Seed
	}

	return &BifrostChatRequest{
		Provider:  r.Provider,
		Model:     r.Model,
		Fallbacks: r.Fallbacks,
		Input:     []ChatMessage{message},
		Params:    &params,
	}
}
// BifrostTextCompletionResponse is the response struct for text completion requests.
// All fields are always serialized (none use omitempty).
type BifrostTextCompletionResponse struct {
ID string `json:"id"`
Choices []BifrostResponseChoice `json:"choices"`
Model string `json:"model"`
Object string `json:"object"` // "text_completion" (same for text completion stream)
SystemFingerprint string `json:"system_fingerprint"`
Usage *BifrostLLMUsage `json:"usage"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// TextCompletionInput holds the prompt of a text completion request as either a
// single string or an array of strings. The fields carry no JSON tags because
// the type's MarshalJSON/UnmarshalJSON methods encode it as a bare string or
// a bare array; MarshalJSON requires exactly one of the two to be non-nil.
type TextCompletionInput struct {
// PromptStr is the single-string form.
PromptStr *string
// PromptArray is the multi-string form.
PromptArray []string
}
// MarshalJSON encodes the input as a bare JSON string (PromptStr) or a bare
// JSON array (PromptArray). Exactly one of the two must be non-nil: having
// neither, or both, is an error. A non-nil but empty PromptArray is encoded.
func (t *TextCompletionInput) MarshalJSON() ([]byte, error) {
	hasStr := t.PromptStr != nil
	hasArr := t.PromptArray != nil

	switch {
	case !hasStr && !hasArr:
		return nil, fmt.Errorf("text completion input is empty")
	case hasStr && hasArr:
		return nil, fmt.Errorf("text completion input must set exactly one of: prompt_str or prompt_array")
	case hasStr:
		return MarshalSorted(*t.PromptStr)
	default:
		return MarshalSorted(t.PromptArray)
	}
}
// UnmarshalJSON decodes the input from either a JSON string or a JSON array
// of strings, trying the string form first. Whichever form decodes populates
// its field and clears the other; any other payload shape is an error and
// leaves the receiver untouched.
func (t *TextCompletionInput) UnmarshalJSON(data []byte) error {
	var asString string
	if Unmarshal(data, &asString) == nil {
		t.PromptStr, t.PromptArray = &asString, nil
		return nil
	}

	var asArray []string
	if Unmarshal(data, &asArray) == nil {
		t.PromptStr, t.PromptArray = nil, asArray
		return nil
	}

	return fmt.Errorf("invalid text completion input")
}
// TextCompletionParameters contains the optional parameters for a text completion request.
// Every serialized field uses omitempty, so unset values are left out of the JSON payload.
// ToBifrostChatRequest forwards the subset that has a chat equivalent
// (MaxTokens, Temperature, TopP, Stop, StreamOptions, User, FrequencyPenalty,
// LogitBias, PresencePenalty, Seed and ExtraParams).
type TextCompletionParameters struct {
BestOf *int `json:"best_of,omitempty"`
Echo *bool `json:"echo,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
LogitBias *map[string]float64 `json:"logit_bias,omitempty"`
LogProbs *int `json:"logprobs,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"`
N *int `json:"n,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
Seed *int `json:"seed,omitempty"`
Stop []string `json:"stop,omitempty"`
Suffix *string `json:"suffix,omitempty"`
StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
User *string `json:"user,omitempty"`
// Dynamic parameters that can be provider-specific, they are directly
// added to the request as is. Never serialized with this struct (json:"-").
ExtraParams map[string]interface{} `json:"-"`
}
// TextCompletionLogProb represents log probability information for text completion.
// NOTE(review): the four slices look like parallel arrays with one entry per token — confirm.
type TextCompletionLogProb struct {
TextOffset []int `json:"text_offset"`
TokenLogProbs []float64 `json:"token_logprobs"`
Tokens []string `json:"tokens"`
TopLogProbs []map[string]float64 `json:"top_logprobs"`
}

390
core/schemas/trace.go Normal file
View File

@@ -0,0 +1,390 @@
// Package schemas defines the core schemas and types used by the Bifrost system.
package schemas
import (
"sync"
"time"
)
// Trace represents a distributed trace that captures the full lifecycle of a request.
// The methods defined on *Trace take mu before reading or writing its fields;
// code that touches the exported fields directly bypasses that lock.
// Because the struct embeds a sync.Mutex it must not be copied; use *Trace.
type Trace struct {
RequestID string // Request ID for the trace
TraceID string // Unique identifier for this trace
ParentID string // Parent trace ID from incoming W3C traceparent header
RootSpan *Span // The root span of this trace
Spans []*Span // All spans in this trace
StartTime time.Time // When the trace started
EndTime time.Time // When the trace completed
Attributes map[string]any // Additional attributes for the trace
PluginLogs []PluginLogEntry // Plugin log entries accumulated during request processing
mu sync.Mutex // Mutex for thread-safe span operations
}
// AddSpan appends span to the trace's span list while holding the trace mutex.
// The span is stored as given; nil is not filtered out.
func (t *Trace) AddSpan(span *Span) {
	t.mu.Lock()
	t.Spans = append(t.Spans, span)
	t.mu.Unlock()
}
// GetSpan returns the first span whose SpanID equals spanID, or nil when no
// span matches. The lookup is a linear scan performed under the trace mutex.
func (t *Trace) GetSpan(spanID string) *Span {
	t.mu.Lock()
	defer t.mu.Unlock()

	for i := range t.Spans {
		if candidate := t.Spans[i]; candidate.SpanID == spanID {
			return candidate
		}
	}
	return nil
}
// GetRequestID returns the trace's request ID, read under the trace mutex.
func (t *Trace) GetRequestID() string {
	t.mu.Lock()
	id := t.RequestID
	t.mu.Unlock()
	return id
}
// SetRequestID stores requestID on the trace, written under the trace mutex.
func (t *Trace) SetRequestID(requestID string) {
	t.mu.Lock()
	t.RequestID = requestID
	t.mu.Unlock()
}
// Reset returns the trace to an empty state so it can be reused from a pool.
// Identifier, time and attribute fields go back to their zero values. The
// Spans and PluginLogs slices keep their capacity, but their elements are
// zeroed first so the retained backing arrays hold no stale references.
func (t *Trace) Reset() {
	t.mu.Lock()
	defer t.mu.Unlock()

	t.RequestID, t.TraceID, t.ParentID = "", "", ""
	t.RootSpan = nil
	t.Attributes = nil
	t.StartTime, t.EndTime = time.Time{}, time.Time{}

	// Zero, then truncate: length becomes 0 while capacity is kept for reuse.
	clear(t.Spans)
	t.Spans = t.Spans[:0]
	clear(t.PluginLogs)
	t.PluginLogs = t.PluginLogs[:0]
}
// AppendPluginLogs adds the given plugin log entries to the trace under the
// trace mutex. An empty or nil logs slice returns immediately without
// taking the lock.
func (t *Trace) AppendPluginLogs(logs []PluginLogEntry) {
	if len(logs) > 0 {
		t.mu.Lock()
		defer t.mu.Unlock()
		t.PluginLogs = append(t.PluginLogs, logs...)
	}
}
// Span represents a single operation within a trace.
// SetAttribute, AddEvent and End take mu before touching the span; code that
// accesses the exported fields directly bypasses that lock.
// Because the struct embeds a sync.Mutex it must not be copied; use *Span.
type Span struct {
SpanID string // Unique identifier for this span
ParentID string // Parent span ID (empty for root span)
TraceID string // The trace this span belongs to
Name string // Name of the operation
Kind SpanKind // Type of span (LLM call, plugin, etc.)
StartTime time.Time // When the span started
EndTime time.Time // When the span completed
Status SpanStatus // Status of the operation
StatusMsg string // Optional status message (for errors)
Attributes map[string]any // Additional attributes for the span
Events []SpanEvent // Events that occurred during the span
mu sync.Mutex // Mutex for thread-safe attribute operations
}
// SetAttribute stores value under key in the span's attributes while holding
// the span mutex. A nil value is ignored (nothing is stored or deleted).
// The attribute map is created on first use.
func (s *Span) SetAttribute(key string, value any) {
	if value == nil {
		return
	}

	s.mu.Lock()
	if s.Attributes == nil {
		s.Attributes = map[string]any{key: value}
	} else {
		s.Attributes[key] = value
	}
	s.mu.Unlock()
}
// AddEvent appends event to the span's event list while holding the span mutex.
func (s *Span) AddEvent(event SpanEvent) {
	s.mu.Lock()
	s.Events = append(s.Events, event)
	s.mu.Unlock()
}
// End marks the span as complete: it stamps EndTime with the current time
// and records the given status and status message, all under the span mutex.
func (s *Span) End(status SpanStatus, statusMsg string) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.EndTime, s.Status, s.StatusMsg = time.Now(), status, statusMsg
}
// Reset clears the span for reuse from pool.
//
// Every field returns to its initial value (Kind to SpanKindUnspecified,
// Status to SpanStatusUnset). The Events slice keeps its capacity, but its
// elements are zeroed before truncation so the retained backing array no
// longer holds old event names or attribute maps; Trace.Reset scrubs its
// slices the same way. Without this, a pooled span keeps those values
// reachable until they happen to be overwritten.
//
// Reset does not take the span mutex; the caller must have exclusive access
// to the span.
func (s *Span) Reset() {
	s.SpanID = ""
	s.ParentID = ""
	s.TraceID = ""
	s.Name = ""
	s.Kind = SpanKindUnspecified
	s.StartTime = time.Time{}
	s.EndTime = time.Time{}
	s.Status = SpanStatusUnset
	s.StatusMsg = ""
	s.Attributes = nil
	// Zero the elements first; truncating alone keeps them reachable.
	clear(s.Events)
	s.Events = s.Events[:0]
}
// SpanEvent represents a time-stamped event within a span.
// Values are stored by value in Span.Events via Span.AddEvent.
type SpanEvent struct {
Name string // Name of the event
Timestamp time.Time // When the event occurred
Attributes map[string]any // Additional attributes for the event
}
// SpanKind represents the type of operation a span represents.
// These are LLM-specific kinds designed for AI gateway observability.
type SpanKind string
// Known span kinds. SpanKindUnspecified is the empty string, so it is also
// the zero value of SpanKind and the value Span.Reset restores.
const (
// SpanKindUnspecified is the default span kind
SpanKindUnspecified SpanKind = ""
// SpanKindLLMCall represents a call to an LLM provider
SpanKindLLMCall SpanKind = "llm.call"
// SpanKindPlugin represents plugin execution (PreLLMHook/PostLLMHook)
SpanKindPlugin SpanKind = "plugin"
// SpanKindMCPTool represents an MCP tool invocation
SpanKindMCPTool SpanKind = "mcp.tool"
// SpanKindRetry represents a retry attempt
SpanKindRetry SpanKind = "retry"
// SpanKindFallback represents a fallback to another provider
SpanKindFallback SpanKind = "fallback"
// SpanKindHTTPRequest represents the root HTTP request span
SpanKindHTTPRequest SpanKind = "http.request"
// SpanKindEmbedding represents an embedding request
SpanKindEmbedding SpanKind = "embedding"
// SpanKindSpeech represents a text-to-speech request
SpanKindSpeech SpanKind = "speech"
// SpanKindTranscription represents a speech-to-text request
SpanKindTranscription SpanKind = "transcription"
// SpanKindInternal represents internal operations (key selection, etc.)
SpanKindInternal SpanKind = "internal"
)
// SpanStatus represents the status of a span's operation.
type SpanStatus string
// Known span statuses. Note that SpanStatusUnset is the string "unset", not
// the empty string: a zero-value Span has Status "" until Span.Reset or
// Span.End assigns one of these values.
const (
// SpanStatusUnset indicates status has not been set
SpanStatusUnset SpanStatus = "unset"
// SpanStatusOk indicates the operation completed successfully
SpanStatusOk SpanStatus = "ok"
// SpanStatusError indicates the operation failed
SpanStatusError SpanStatus = "error"
)
// LLM Attribute Keys (gen_ai.* namespace)
// These follow the OpenTelemetry semantic conventions for GenAI
// and are compatible with both OTEL and Datadog backends.
//
// All constants are untyped string keys. Most live under "gen_ai."; the
// plugin aggregation keys are the exception and live under "plugin.".
const (
// Provider and Model Attributes
AttrProviderName = "gen_ai.provider.name"
AttrRequestModel = "gen_ai.request.model"
// Request Parameter Attributes
AttrMaxTokens = "gen_ai.request.max_tokens"
AttrTemperature = "gen_ai.request.temperature"
AttrTopP = "gen_ai.request.top_p"
AttrStopSequences = "gen_ai.request.stop_sequences"
AttrPresencePenalty = "gen_ai.request.presence_penalty"
AttrFrequencyPenalty = "gen_ai.request.frequency_penalty"
AttrParallelToolCall = "gen_ai.request.parallel_tool_calls"
AttrRequestUser = "gen_ai.request.user"
AttrBestOf = "gen_ai.request.best_of"
AttrEcho = "gen_ai.request.echo"
AttrLogitBias = "gen_ai.request.logit_bias"
AttrLogProbs = "gen_ai.request.logprobs"
AttrN = "gen_ai.request.n"
AttrSeed = "gen_ai.request.seed"
AttrSuffix = "gen_ai.request.suffix"
AttrDimensions = "gen_ai.request.dimensions"
AttrEncodingFormat = "gen_ai.request.encoding_format"
AttrLanguage = "gen_ai.request.language"
AttrPrompt = "gen_ai.request.prompt"
AttrResponseFormat = "gen_ai.request.response_format"
AttrFormat = "gen_ai.request.format"
AttrVoice = "gen_ai.request.voice"
AttrMultiVoiceConfig = "gen_ai.request.multi_voice_config"
AttrInstructions = "gen_ai.request.instructions"
AttrSpeed = "gen_ai.request.speed"
AttrMessageCount = "gen_ai.request.message_count"
// Response Attributes
AttrResponseID = "gen_ai.response.id"
AttrResponseModel = "gen_ai.response.model"
AttrFinishReason = "gen_ai.response.finish_reason"
AttrSystemFprint = "gen_ai.response.system_fingerprint"
AttrServiceTier = "gen_ai.response.service_tier"
AttrCreated = "gen_ai.response.created"
AttrObject = "gen_ai.response.object"
AttrTimeToFirstToken = "gen_ai.response.time_to_first_token"
AttrTotalChunks = "gen_ai.response.total_chunks"
// Plugin Attributes (for aggregated streaming post-hook spans)
// These are the only keys in this block outside the gen_ai.* namespace.
AttrPluginInvocations = "plugin.invocation_count"
AttrPluginAvgDurationMs = "plugin.avg_duration_ms"
AttrPluginTotalDurationMs = "plugin.total_duration_ms"
AttrPluginErrorCount = "plugin.error_count"
// Usage Attributes
AttrPromptTokens = "gen_ai.usage.prompt_tokens"
AttrCompletionTokens = "gen_ai.usage.completion_tokens"
AttrTotalTokens = "gen_ai.usage.total_tokens"
AttrInputTokens = "gen_ai.usage.input_tokens"
AttrOutputTokens = "gen_ai.usage.output_tokens"
AttrUsageCost = "gen_ai.usage.cost"
// Chat completion usage detail attributes
AttrPromptTokenDetailsText = "gen_ai.usage.prompt_token_details.text_tokens"
AttrPromptTokenDetailsAudio = "gen_ai.usage.prompt_token_details.audio_tokens"
AttrPromptTokenDetailsImage = "gen_ai.usage.prompt_token_details.image_tokens"
AttrPromptTokenDetailsCachedRead = "gen_ai.usage.prompt_token_details.cached_read_tokens"
AttrPromptTokenDetailsCachedWrite = "gen_ai.usage.prompt_token_details.cached_write_tokens"
AttrCompletionTokenDetailsText = "gen_ai.usage.completion_token_details.text_tokens"
AttrCompletionTokenDetailsAudio = "gen_ai.usage.completion_token_details.audio_tokens"
AttrCompletionTokenDetailsImage = "gen_ai.usage.completion_token_details.image_tokens"
AttrCompletionTokenDetailsReason = "gen_ai.usage.completion_token_details.reasoning_tokens"
AttrCompletionTokenDetailsAccept = "gen_ai.usage.completion_token_details.accepted_prediction_tokens"
AttrCompletionTokenDetailsReject = "gen_ai.usage.completion_token_details.rejected_prediction_tokens"
AttrCompletionTokenDetailsCite = "gen_ai.usage.completion_token_details.citation_tokens"
AttrCompletionTokenDetailsSearch = "gen_ai.usage.completion_token_details.num_search_queries"
// Error Attributes
AttrError = "gen_ai.error"
AttrErrorType = "gen_ai.error.type"
AttrErrorCode = "gen_ai.error.code"
// Input/Output Attributes
AttrInputText = "gen_ai.input.text"
AttrInputMessages = "gen_ai.input.messages"
AttrInputSpeech = "gen_ai.input.speech"
AttrInputEmbedding = "gen_ai.input.embedding"
AttrOutputMessages = "gen_ai.output.messages"
// Bifrost Context Attributes
AttrVirtualKeyID = "gen_ai.virtual_key_id"
AttrVirtualKeyName = "gen_ai.virtual_key_name"
AttrSelectedKeyID = "gen_ai.selected_key_id"
AttrSelectedKeyName = "gen_ai.selected_key_name"
AttrRoutingRuleID = "gen_ai.routing_rule_id"
AttrRoutingRuleName = "gen_ai.routing_rule_name"
AttrTeamID = "gen_ai.team_id"
AttrTeamName = "gen_ai.team_name"
AttrCustomerID = "gen_ai.customer_id"
AttrCustomerName = "gen_ai.customer_name"
AttrNumberOfRetries = "gen_ai.number_of_retries"
AttrFallbackIndex = "gen_ai.fallback_index"
// Responses API Request Attributes
// These share the gen_ai.request.* prefix with the request parameter keys above.
AttrPromptCacheKey = "gen_ai.request.prompt_cache_key"
AttrReasoningEffort = "gen_ai.request.reasoning_effort"
AttrReasoningSummary = "gen_ai.request.reasoning_summary"
AttrReasoningGenSummary = "gen_ai.request.reasoning_generate_summary"
AttrSafetyIdentifier = "gen_ai.request.safety_identifier"
AttrStore = "gen_ai.request.store"
AttrTextVerbosity = "gen_ai.request.text_verbosity"
AttrTextFormatType = "gen_ai.request.text_format_type"
AttrTopLogProbs = "gen_ai.request.top_logprobs"
AttrToolChoiceType = "gen_ai.request.tool_choice_type"
AttrToolChoiceName = "gen_ai.request.tool_choice_name"
AttrTools = "gen_ai.request.tools"
AttrTruncation = "gen_ai.request.truncation"
// Responses API Response Attributes
// Note the plural "gen_ai.responses." prefix, distinct from "gen_ai.response." above.
AttrRespInclude = "gen_ai.responses.include"
AttrRespMaxOutputTokens = "gen_ai.responses.max_output_tokens"
AttrRespMaxToolCalls = "gen_ai.responses.max_tool_calls"
AttrRespMetadata = "gen_ai.responses.metadata"
AttrRespPreviousRespID = "gen_ai.responses.previous_response_id"
AttrRespPromptCacheKey = "gen_ai.responses.prompt_cache_key"
AttrRespReasoningText = "gen_ai.responses.reasoning"
AttrRespReasoningEffort = "gen_ai.responses.reasoning_effort"
AttrRespReasoningGenSum = "gen_ai.responses.reasoning_generate_summary"
AttrRespSafetyIdentifier = "gen_ai.responses.safety_identifier"
AttrRespStore = "gen_ai.responses.store"
AttrRespTemperature = "gen_ai.responses.temperature"
AttrRespTextVerbosity = "gen_ai.responses.text_verbosity"
AttrRespTextFormatType = "gen_ai.responses.text_format_type"
AttrRespTopLogProbs = "gen_ai.responses.top_logprobs"
AttrRespTopP = "gen_ai.responses.top_p"
AttrRespToolChoiceType = "gen_ai.responses.tool_choice_type"
AttrRespToolChoiceName = "gen_ai.responses.tool_choice_name"
AttrRespTruncation = "gen_ai.responses.truncation"
AttrRespTools = "gen_ai.responses.tools"
// Batch Operation Attributes
AttrBatchID = "gen_ai.batch.id"
AttrBatchStatus = "gen_ai.batch.status"
AttrBatchObject = "gen_ai.batch.object"
AttrBatchEndpoint = "gen_ai.batch.endpoint"
AttrBatchInputFileID = "gen_ai.batch.input_file_id"
AttrBatchOutputFileID = "gen_ai.batch.output_file_id"
AttrBatchErrorFileID = "gen_ai.batch.error_file_id"
AttrBatchCompletionWin = "gen_ai.batch.completion_window"
AttrBatchCreatedAt = "gen_ai.batch.created_at"
AttrBatchExpiresAt = "gen_ai.batch.expires_at"
AttrBatchRequestsCount = "gen_ai.batch.requests_count"
AttrBatchDataCount = "gen_ai.batch.data_count"
AttrBatchResultsCount = "gen_ai.batch.results_count"
AttrBatchHasMore = "gen_ai.batch.has_more"
AttrBatchMetadata = "gen_ai.batch.metadata"
AttrBatchLimit = "gen_ai.batch.limit"
AttrBatchAfter = "gen_ai.batch.after"
AttrBatchBeforeID = "gen_ai.batch.before_id"
AttrBatchAfterID = "gen_ai.batch.after_id"
AttrBatchPageToken = "gen_ai.batch.page_token"
AttrBatchPageSize = "gen_ai.batch.page_size"
AttrBatchCountTotal = "gen_ai.batch.request_counts.total"
AttrBatchCountCompleted = "gen_ai.batch.request_counts.completed"
AttrBatchCountFailed = "gen_ai.batch.request_counts.failed"
AttrBatchFirstID = "gen_ai.batch.first_id"
AttrBatchLastID = "gen_ai.batch.last_id"
AttrBatchInProgressAt = "gen_ai.batch.in_progress_at"
AttrBatchFinalizingAt = "gen_ai.batch.finalizing_at"
AttrBatchCompletedAt = "gen_ai.batch.completed_at"
AttrBatchFailedAt = "gen_ai.batch.failed_at"
AttrBatchExpiredAt = "gen_ai.batch.expired_at"
AttrBatchCancellingAt = "gen_ai.batch.cancelling_at"
AttrBatchCancelledAt = "gen_ai.batch.cancelled_at"
AttrBatchNextCursor = "gen_ai.batch.next_cursor"
// Transcription Response Attributes
// These sit under gen_ai.usage.input_token_details.*, alongside the usage keys above.
AttrInputTokenDetailsText = "gen_ai.usage.input_token_details.text_tokens"
AttrInputTokenDetailsAudio = "gen_ai.usage.input_token_details.audio_tokens"
// File Operation Attributes
AttrFileID = "gen_ai.file.id"
AttrFileObject = "gen_ai.file.object"
AttrFileFilename = "gen_ai.file.filename"
AttrFilePurpose = "gen_ai.file.purpose"
AttrFileBytes = "gen_ai.file.bytes"
AttrFileCreatedAt = "gen_ai.file.created_at"
AttrFileStatus = "gen_ai.file.status"
AttrFileStorageBackend = "gen_ai.file.storage_backend"
AttrFileDataCount = "gen_ai.file.data_count"
AttrFileHasMore = "gen_ai.file.has_more"
AttrFileDeleted = "gen_ai.file.deleted"
AttrFileContentType = "gen_ai.file.content_type"
AttrFileContentBytes = "gen_ai.file.content_bytes"
AttrFileLimit = "gen_ai.file.limit"
AttrFileAfter = "gen_ai.file.after"
AttrFileOrder = "gen_ai.file.order"
)

204
core/schemas/tracer.go Normal file
View File

@@ -0,0 +1,204 @@
// Package schemas defines the core schemas and types used by the Bifrost system.
package schemas
import (
"context"
"time"
)
// SpanHandle is an opaque handle to a span, implementation-specific.
// Different Tracer implementations can use their own concrete types.
// It is the empty interface, so a handle may be nil; NoOpTracer.StartSpan
// returns a nil handle.
type SpanHandle interface{}
// StreamAccumulatorResult contains the accumulated data from streaming chunks.
// This is the return type for tracer's streaming accumulation methods
// (Tracer.ProcessStreamingChunk returns a pointer to it).
// The output fields are per request type: OutputMessage, OutputMessages,
// AudioOutput, TranscriptionOutput and ImageGenerationOutput.
type StreamAccumulatorResult struct {
RequestID string // Request ID
RequestedModel string // Original model requested by the caller
ResolvedModel string // Actual model used by the provider (equals RequestedModel when no alias mapping exists)
Provider ModelProvider // Provider used
Status string // Status of the stream
Latency int64 // Latency in milliseconds
TimeToFirstToken int64 // Time to first token in milliseconds
OutputMessage *ChatMessage // Accumulated output message
OutputMessages []ResponsesMessage // For responses API
TokenUsage *BifrostLLMUsage // Token usage
Cost *float64 // Cost in dollars
ErrorDetails *BifrostError // Error details if any
AudioOutput *BifrostSpeechResponse // For speech streaming
TranscriptionOutput *BifrostTranscriptionResponse // For transcription streaming
ImageGenerationOutput *BifrostImageGenerationResponse // For image generation streaming
FinishReason *string // Finish reason
RawResponse *string // Raw response
RawRequest interface{} // Raw request
}
// Tracer defines the interface for distributed tracing in Bifrost.
// Implementations can be injected via BifrostConfig to enable automatic instrumentation.
// The interface is designed to be minimal and implementation-agnostic.
// NoOpTracer is the do-nothing implementation used when tracing is disabled.
type Tracer interface {
// CreateTrace creates a new trace with optional parent ID and returns the trace ID.
// The parentID can be extracted from W3C traceparent headers for distributed tracing.
// The requestID is optional and can be used to identify the request.
CreateTrace(parentID string, requestID ...string) string
// EndTrace completes a trace and returns the trace data for observation/export.
// After this call, the trace is removed from active tracking and returned for cleanup.
// Returns nil if trace not found.
EndTrace(traceID string) *Trace
// StartSpan creates a new span as a child of the current span in context.
// Returns updated context with new span and a handle for the span.
// The context should be used for subsequent operations to maintain span hierarchy.
StartSpan(ctx context.Context, name string, kind SpanKind) (context.Context, SpanHandle)
// EndSpan completes a span with status and optional message.
// Should be called when the operation represented by the span is complete.
EndSpan(handle SpanHandle, status SpanStatus, statusMsg string)
// SetAttribute sets an attribute on the span.
// Attributes provide additional context about the operation.
SetAttribute(handle SpanHandle, key string, value any)
// AddEvent adds a timestamped event to the span.
// Events represent discrete occurrences during the span's lifetime.
AddEvent(handle SpanHandle, name string, attrs map[string]any)
// PopulateLLMRequestAttributes populates all LLM-specific request attributes on the span.
// This includes model parameters, input messages, temperature, max tokens, etc.
PopulateLLMRequestAttributes(handle SpanHandle, req *BifrostRequest)
// PopulateLLMResponseAttributes populates all LLM-specific response attributes on the span.
// This includes output messages, tokens, usage stats, and error information if present.
PopulateLLMResponseAttributes(ctx *BifrostContext, handle SpanHandle, resp *BifrostResponse, err *BifrostError)
// StoreDeferredSpan stores a span handle for later completion (used for streaming requests).
// The span handle is stored keyed by trace ID so it can be retrieved when the stream completes.
StoreDeferredSpan(traceID string, handle SpanHandle)
// GetDeferredSpanHandle retrieves a deferred span handle by trace ID.
// Returns nil if no deferred span exists for the given trace ID.
GetDeferredSpanHandle(traceID string) SpanHandle
// ClearDeferredSpan removes the deferred span handle for a trace ID.
// Should be called after the deferred span has been completed.
ClearDeferredSpan(traceID string)
// GetDeferredSpanID returns the span ID for the deferred span.
// Returns empty string if no deferred span exists.
GetDeferredSpanID(traceID string) string
// AddStreamingChunk accumulates a streaming chunk for the deferred span.
// Pass the full BifrostResponse to capture content, tool calls, reasoning, etc.
// This is called for each streaming chunk to build up the complete response.
AddStreamingChunk(traceID string, response *BifrostResponse)
// GetAccumulatedChunks returns the accumulated response, TTFT, and chunk count for a deferred span.
// The response is built from the streaming accumulator during the final ProcessStreamingChunk call.
// Returns nil response if no plugin has called ProcessStreamingChunk (callers should nil-check).
// Returns nil, 0, 0 if no accumulated data exists.
// NOTE(review): the ttftNs result name suggests nanoseconds, whereas
// StreamAccumulatorResult.TimeToFirstToken is documented in milliseconds — confirm units.
GetAccumulatedChunks(traceID string) (response *BifrostResponse, ttftNs int64, chunkCount int)
// CreateStreamAccumulator creates a new stream accumulator for the given trace ID.
// This should be called at the start of a streaming request.
CreateStreamAccumulator(traceID string, startTime time.Time)
// CleanupStreamAccumulator removes the stream accumulator for the given trace ID.
// This should be called after the streaming request is complete.
CleanupStreamAccumulator(traceID string)
// ProcessStreamingChunk processes a streaming chunk and accumulates it.
// Returns the accumulated result.
// This method is used by plugins to access accumulated streaming data.
// The caller signals the end of the stream through the isFinalChunk argument.
// NOTE(review): earlier wording referred to a ctx parameter and an IsFinal result
// field; neither exists in this signature or in StreamAccumulatorResult — confirm
// the intended contract.
ProcessStreamingChunk(traceID string, isFinalChunk bool, result *BifrostResponse, err *BifrostError) *StreamAccumulatorResult
// AttachPluginLogs appends plugin log entries to the trace identified by traceID.
// Thread-safe. Should be called after plugin hooks complete, before trace completion.
AttachPluginLogs(traceID string, logs []PluginLogEntry)
// CompleteAndFlushTrace ends a trace, exports it to observability plugins, and
// releases the trace resources. Used by transports that bypass normal HTTP trace completion.
CompleteAndFlushTrace(traceID string)
// Stop releases resources associated with the tracer.
// Should be called during shutdown to stop background goroutines.
Stop()
}
// NoOpTracer is a tracer that does nothing (default when tracing disabled).
// It satisfies the Tracer interface but performs no actual tracing operations.
// Every method ignores its arguments. Methods with results return "" / nil / 0,
// except StartSpan, which hands back the context it was given; callers must
// therefore nil-check any handle, trace, or accumulator result they receive.
// The struct has no fields, so a NoOpTracer carries no state.
type NoOpTracer struct{}
// CreateTrace returns an empty string (no trace created).
func (n *NoOpTracer) CreateTrace(_ string, _ ...string) string { return "" }
// EndTrace returns nil (no trace to end).
func (n *NoOpTracer) EndTrace(_ string) *Trace { return nil }
// StartSpan returns the context unchanged and a nil handle.
func (n *NoOpTracer) StartSpan(ctx context.Context, _ string, _ SpanKind) (context.Context, SpanHandle) {
return ctx, nil
}
// EndSpan does nothing.
func (n *NoOpTracer) EndSpan(_ SpanHandle, _ SpanStatus, _ string) {}
// SetAttribute does nothing.
func (n *NoOpTracer) SetAttribute(_ SpanHandle, _ string, _ any) {}
// AddEvent does nothing.
func (n *NoOpTracer) AddEvent(_ SpanHandle, _ string, _ map[string]any) {}
// PopulateLLMRequestAttributes does nothing.
func (n *NoOpTracer) PopulateLLMRequestAttributes(_ SpanHandle, _ *BifrostRequest) {}
// PopulateLLMResponseAttributes does nothing.
func (n *NoOpTracer) PopulateLLMResponseAttributes(_ *BifrostContext, _ SpanHandle, _ *BifrostResponse, _ *BifrostError) {
}
// StoreDeferredSpan does nothing; the handle is not retained.
func (n *NoOpTracer) StoreDeferredSpan(_ string, _ SpanHandle) {}
// GetDeferredSpanHandle returns nil (nothing is ever stored).
func (n *NoOpTracer) GetDeferredSpanHandle(_ string) SpanHandle { return nil }
// ClearDeferredSpan does nothing.
func (n *NoOpTracer) ClearDeferredSpan(_ string) {}
// GetDeferredSpanID returns empty string.
func (n *NoOpTracer) GetDeferredSpanID(_ string) string { return "" }
// AddStreamingChunk does nothing; chunks are not accumulated.
func (n *NoOpTracer) AddStreamingChunk(_ string, _ *BifrostResponse) {}
// GetAccumulatedChunks returns nil, 0, 0.
func (n *NoOpTracer) GetAccumulatedChunks(_ string) (*BifrostResponse, int64, int) { return nil, 0, 0 }
// CreateStreamAccumulator does nothing.
func (n *NoOpTracer) CreateStreamAccumulator(_ string, _ time.Time) {}
// CleanupStreamAccumulator does nothing.
func (n *NoOpTracer) CleanupStreamAccumulator(_ string) {}
// ProcessStreamingChunk returns nil; callers must nil-check the result.
func (n *NoOpTracer) ProcessStreamingChunk(_ string, _ bool, _ *BifrostResponse, _ *BifrostError) *StreamAccumulatorResult {
return nil
}
// AttachPluginLogs does nothing; the logs are discarded.
func (n *NoOpTracer) AttachPluginLogs(_ string, _ []PluginLogEntry) {}
// CompleteAndFlushTrace does nothing.
func (n *NoOpTracer) CompleteAndFlushTrace(_ string) {}
// Stop does nothing.
func (n *NoOpTracer) Stop() {}
// DefaultTracer provides the Tracer to use when tracing is disabled: a newly
// allocated NoOpTracer, whose methods all do nothing.
func DefaultTracer() Tracer {
	return new(NoOpTracer)
}
// Compile-time assertion that *NoOpTracer implements the Tracer interface;
// the blank identifier means no value is kept.
var _ Tracer = (*NoOpTracer)(nil)

View File

@@ -0,0 +1,156 @@
package schemas
// BifrostTranscriptionRequest is the request payload for a transcription call.
// Input, Params and Fallbacks are tagged omitempty; RawRequestBody is excluded
// from JSON entirely.
type BifrostTranscriptionRequest struct {
Provider ModelProvider `json:"provider"`
Model string `json:"model"`
Input *TranscriptionInput `json:"input,omitempty"`
Params *TranscriptionParameters `json:"params,omitempty"`
Fallbacks []Fallback `json:"fallbacks,omitempty"`
RawRequestBody []byte `json:"-"` // set bifrost-use-raw-request-body to true in ctx to use the raw request body. Bifrost will directly send this to the downstream provider.
}
// GetRawRequestBody returns the raw request body stored on the request.
// It returns nil when the receiver is nil or when no raw body was set.
//
// The nil-receiver guard keeps this accessor consistent with the nil-safe
// BackfillParams method in this file, so calling it through a nil
// *BifrostTranscriptionRequest does not panic.
func (r *BifrostTranscriptionRequest) GetRawRequestBody() []byte {
	if r == nil {
		return nil
	}
	return r.RawRequestBody
}
// BifrostTranscriptionResponse is the transcription result in Bifrost format.
// Text and ExtraFields are always serialized; the remaining JSON fields are
// tagged omitempty. ResponseFormat is excluded from JSON and is filled either
// by the provider or by BackfillParams.
type BifrostTranscriptionResponse struct {
Duration *float64 `json:"duration,omitempty"` // Duration in seconds
Language *string `json:"language,omitempty"` // e.g., "english"
LogProbs []TranscriptionLogProb `json:"logprobs,omitempty"`
Segments []TranscriptionSegment `json:"segments,omitempty"`
Task *string `json:"task,omitempty"` // e.g., "transcribe"
Text string `json:"text"`
Usage *TranscriptionUsage `json:"usage,omitempty"`
Words []TranscriptionWord `json:"words,omitempty"`
ResponseFormat *string `json:"-"` // Set by provider for non-JSON formats (text, srt, vtt); used by integration response converters
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// BackfillParams copies the response format requested by the caller onto the
// response. It is a no-op when the receiver, the request, the request's
// Params, or the requested ResponseFormat is nil.
func (r *BifrostTranscriptionResponse) BackfillParams(req *BifrostTranscriptionRequest) {
	if r == nil || req == nil {
		return
	}
	params := req.Params
	if params == nil {
		return
	}
	if requested := params.ResponseFormat; requested != nil {
		r.ResponseFormat = requested
	}
}
// IsPlainTextTranscriptionFormat reports whether the given response format
// yields a plain-text (non-JSON) response body. Only the exact values "text",
// "srt" and "vtt" qualify; a nil format is not plain text.
func IsPlainTextTranscriptionFormat(format *string) bool {
	if format == nil {
		return false
	}
	name := *format
	return name == "text" || name == "srt" || name == "vtt"
}
// TranscriptionInput carries the file bytes to transcribe together with an
// optional original filename.
type TranscriptionInput struct {
File []byte `json:"file"`
Filename string `json:"filename,omitempty"` // Original filename, used to preserve file format extension
}
// TranscriptionParameters holds the optional tuning parameters of a
// transcription request. All JSON-visible fields are tagged omitempty;
// ExtraParams is excluded from JSON.
type TranscriptionParameters struct {
Language *string `json:"language,omitempty"`
Prompt *string `json:"prompt,omitempty"`
ResponseFormat *string `json:"response_format,omitempty"` // Default is "json"
Temperature *float64 `json:"temperature,omitempty"` // Sampling temperature (0.0-1.0)
TimestampGranularities []string `json:"timestamp_granularities,omitempty"` // "word" and/or "segment"; requires response_format=verbose_json
Include []string `json:"include,omitempty"` // Additional response info (e.g., logprobs)
Format *string `json:"file_format,omitempty"` // Type of file, not required in openai, but required in gemini
MaxLength *int `json:"max_length,omitempty"` // Maximum length of the transcription used by HuggingFace
MinLength *int `json:"min_length,omitempty"` // Minimum length of the transcription used by HuggingFace
MaxNewTokens *int `json:"max_new_tokens,omitempty"` // Maximum new tokens to generate used by HuggingFace
MinNewTokens *int `json:"min_new_tokens,omitempty"` // Minimum new tokens to generate used by HuggingFace
// Elevenlabs-specific fields
AdditionalFormats []TranscriptionAdditionalFormat `json:"additional_formats,omitempty"`
WebhookMetadata interface{} `json:"webhook_metadata,omitempty"`
// Dynamic parameters that can be provider-specific, they are directly
// added to the request as is.
ExtraParams map[string]interface{} `json:"-"`
}
// TranscriptionAdditionalFormat describes one extra export format requested
// through TranscriptionParameters.AdditionalFormats. Format is required; every
// other field is a pointer tagged omitempty.
type TranscriptionAdditionalFormat struct {
Format TranscriptionExportOptions `json:"format"`
IncludeSpeakers *bool `json:"include_speakers,omitempty"`
IncludeTimestamps *bool `json:"include_timestamps,omitempty"`
SegmentOnSilenceLongerThanS *float64 `json:"segment_on_silence_longer_than_s,omitempty"`
MaxSegmentDurationS *float64 `json:"max_segment_duration_s,omitempty"`
MaxSegmentChars *int `json:"max_segment_chars,omitempty"`
MaxCharactersPerLine *int `json:"max_characters_per_line,omitempty"`
}
// TranscriptionExportOptions names an export format usable as
// TranscriptionAdditionalFormat.Format.
type TranscriptionExportOptions string
// Export format values declared for TranscriptionExportOptions.
const (
TranscriptionExportOptionsSegmentedJson TranscriptionExportOptions = "segmented_json"
TranscriptionExportOptionsDocx TranscriptionExportOptions = "docx"
TranscriptionExportOptionsPdf TranscriptionExportOptions = "pdf"
TranscriptionExportOptionsTxt TranscriptionExportOptions = "txt"
TranscriptionExportOptionsHtml TranscriptionExportOptions = "html"
TranscriptionExportOptionsSrt TranscriptionExportOptions = "srt"
)
// TranscriptionLogProb represents log probability information for transcription.
// All three fields are always serialized (no omitempty).
type TranscriptionLogProb struct {
Token string `json:"token"`
LogProb float64 `json:"logprob"`
Bytes []int `json:"bytes"`
}
// TranscriptionWord represents word-level timing information: the word plus
// its start and end offsets.
// NOTE(review): Start/End are presumably seconds from the start of the audio — confirm.
type TranscriptionWord struct {
Word string `json:"word"`
Start float64 `json:"start"`
End float64 `json:"end"`
}
// TranscriptionSegment represents segment-level transcription information.
// Every field is always serialized (no omitempty).
// NOTE(review): Start/End are presumably seconds from the start of the audio — confirm.
type TranscriptionSegment struct {
ID int `json:"id"`
Seek int `json:"seek"`
Start float64 `json:"start"`
End float64 `json:"end"`
Text string `json:"text"`
Tokens []int `json:"tokens"`
Temperature float64 `json:"temperature"`
AvgLogProb float64 `json:"avg_logprob"`
CompressionRatio float64 `json:"compression_ratio"`
NoSpeechProb float64 `json:"no_speech_prob"`
}
// TranscriptionUsage represents usage information for transcription.
// Type selects the accounting mode; the token fields and Seconds are pointers
// tagged omitempty, so only the ones that were set are serialized.
type TranscriptionUsage struct {
Type string `json:"type"` // "tokens" or "duration"
InputTokens *int `json:"input_tokens,omitempty"`
InputTokenDetails *TranscriptionUsageInputTokenDetails `json:"input_token_details,omitempty"`
OutputTokens *int `json:"output_tokens,omitempty"`
TotalTokens *int `json:"total_tokens,omitempty"`
Seconds *int `json:"seconds,omitempty"` // For duration-based usage
}
// TranscriptionUsageInputTokenDetails splits the input token count into text
// and audio tokens (see TranscriptionUsage.InputTokenDetails).
type TranscriptionUsageInputTokenDetails struct {
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
}
// TranscriptionStreamResponseType is the event type carried in
// BifrostTranscriptionStreamResponse.Type.
type TranscriptionStreamResponseType string
// Event type values declared for streaming transcription responses.
const (
TranscriptionStreamResponseTypeDelta TranscriptionStreamResponseType = "transcript.text.delta"
TranscriptionStreamResponseTypeDone TranscriptionStreamResponseType = "transcript.text.done"
)
// BifrostTranscriptionStreamResponse represents streaming specific fields only.
// Text, Type and ExtraFields are always serialized; Delta, LogProbs and Usage
// are tagged omitempty.
type BifrostTranscriptionStreamResponse struct {
Delta *string `json:"delta,omitempty"` // For delta events
LogProbs []TranscriptionLogProb `json:"logprobs,omitempty"`
Text string `json:"text"`
Type TranscriptionStreamResponseType `json:"type"`
Usage *TranscriptionUsage `json:"usage,omitempty"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}

View File

@@ -0,0 +1,72 @@
package schemas
import "strings"
// UserAgentIdentifiers holds the substrings that can identify one integration
// in a User-Agent header. Different releases of the same client may send
// different strings, so Matches accepts any of them.
type UserAgentIdentifiers []string

// Known client integrations and the User-Agent substrings that identify them.
var (
	// ClaudeCLI — Anthropic Claude Code / Claude CLI (identifiers vary by release).
	ClaudeCLI   = UserAgentIdentifiers{"claude-cli", "claude-code", "claude-vscode"}
	GeminiCLI   = UserAgentIdentifiers{"geminicli"}
	CodexCLI    = UserAgentIdentifiers{"codex-tui"}
	QwenCodeCLI = UserAgentIdentifiers{"qwencode"}
	OpenCode    = UserAgentIdentifiers{"opencode"}
	Cursor      = UserAgentIdentifiers{"cursor"}
)

// integrationUserAgents is the set of known client User-Agent patterns we persist on the context.
var integrationUserAgents = []UserAgentIdentifiers{
	ClaudeCLI, GeminiCLI, CodexCLI, QwenCodeCLI, OpenCode, Cursor,
}

// Matches reports whether userAgent contains at least one of the identifiers,
// compared as a case-insensitive substring. Empty identifiers are ignored, and
// an empty userAgent or an empty identifier list never matches.
func (ids UserAgentIdentifiers) Matches(userAgent string) bool {
	if userAgent == "" {
		return false
	}
	haystack := strings.ToLower(userAgent)
	for i := range ids {
		needle := ids[i]
		if needle != "" && strings.Contains(haystack, strings.ToLower(needle)) {
			return true
		}
	}
	return false
}

// String returns the first identifier, or "" for an empty list. It serves as a
// canonical sample value for logging and tests.
func (ids UserAgentIdentifiers) String() string {
	if len(ids) > 0 {
		return ids[0]
	}
	return ""
}
// ExtractAndSetUserAgentFromHeaders looks up the User-Agent header (the key is
// matched case-insensitively) and, when its first value matches one of the
// known integration identifiers, stores that value on bifrostCtx under
// BifrostContextKeyUserAgent.
//
// It does nothing when headers is empty, bifrostCtx is nil, the header is
// absent or has no values, or the first value matches no known integration.
func ExtractAndSetUserAgentFromHeaders(headers map[string][]string, bifrostCtx *BifrostContext) {
	if len(headers) == 0 || bifrostCtx == nil {
		return
	}

	var values []string
	for name, vals := range headers {
		if strings.EqualFold(name, "user-agent") {
			values = vals
			break
		}
	}
	if len(values) == 0 {
		return
	}

	// Only the first header value is inspected.
	first := values[0]
	for _, known := range integrationUserAgents {
		if known.Matches(first) {
			bifrostCtx.SetValue(BifrostContextKeyUserAgent, first)
			return
		}
	}
}

1398
core/schemas/utils.go Normal file

File diff suppressed because it is too large Load Diff

247
core/schemas/videos.go Normal file
View File

@@ -0,0 +1,247 @@
package schemas
// VideoStatus is the lifecycle status of a video job.
type VideoStatus string
// Status values declared for VideoStatus.
const (
VideoStatusQueued VideoStatus = "queued"
VideoStatusInProgress VideoStatus = "in_progress"
VideoStatusCompleted VideoStatus = "completed"
VideoStatusFailed VideoStatus = "failed"
)
// VideoOutputType tells how a VideoOutput carries its payload: as base64 data
// or as a URL.
type VideoOutputType string
// Output type values declared for VideoOutputType.
const (
VideoOutputTypeBase64 VideoOutputType = "base64"
VideoOutputTypeURL VideoOutputType = "url"
)
// VideoCreateError is the error payload when video generation fails.
// Both fields are tagged omitempty.
type VideoCreateError struct {
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
// ContentFilterInfo contains information about content that was filtered due to safety policies.
// This is a provider-agnostic structure for representing content filtering results.
// Both fields are tagged omitempty.
type ContentFilterInfo struct {
FilteredCount int `json:"filtered_count,omitempty"` // Number of items filtered
Reasons []string `json:"reasons,omitempty"` // Human-readable reasons for filtering
}
// VideoOutput is one generated video. Type says which payload field is used;
// URL and Base64Data are pointers tagged omitempty.
// NOTE(review): presumably exactly one of URL / Base64Data is set, matching Type — confirm against providers.
type VideoOutput struct {
Type VideoOutputType `json:"type"` // "url" | "base64"
URL *string `json:"url,omitempty"`
Base64Data *string `json:"base64,omitempty"`
ContentType string `json:"content_type"`
}
// VideoReferenceInput represents a reference image for video generation:
// the raw image bytes plus an optional reference type.
type VideoReferenceInput struct {
Image []byte `json:"image"` // Image bytes
ReferenceType string `json:"reference_type,omitempty"` // "style" or "asset" (Gemini: "REFERENCE_TYPE_STYLE" or "REFERENCE_TYPE_ASSET")
}
// VideoObject describes a single video job; it is the element type of
// BifrostVideoListResponse.Data.
// Note that Seconds is a pointer without omitempty, unlike the other pointer
// fields here.
// NOTE(review): CreatedAt/CompletedAt/ExpiresAt are presumably Unix seconds, as on BifrostVideoGenerationResponse — confirm.
type VideoObject struct {
ID string `json:"id"`
Object string `json:"object"` // always "video"
Model string `json:"model"`
Status VideoStatus `json:"status"`
CreatedAt int64 `json:"created_at"`
CompletedAt *int64 `json:"completed_at,omitempty"`
ExpiresAt *int64 `json:"expires_at,omitempty"`
Progress *float64 `json:"progress,omitempty"`
Prompt string `json:"prompt"`
RemixedFromVideoID *string `json:"remixed_from_video_id,omitempty"`
Seconds *string `json:"seconds"`
Size string `json:"size"`
Error *VideoCreateError `json:"error,omitempty"`
}
// --- Video Generation ---
// BifrostVideoGenerationRequest is the request payload for creating a video.
// Params and Fallbacks are tagged omitempty; RawRequestBody is excluded from
// JSON and exposed through GetRawRequestBody.
type BifrostVideoGenerationRequest struct {
Provider ModelProvider `json:"provider"`
Model string `json:"model"`
Input *VideoGenerationInput `json:"input"`
Params *VideoGenerationParameters `json:"params,omitempty"`
Fallbacks []Fallback `json:"fallbacks,omitempty"`
RawRequestBody []byte `json:"-"`
}
// GetRawRequestBody returns the raw request body stored on the request.
// It returns nil when the receiver is nil or when no raw body was set.
//
// The nil-receiver guard matches GetExtraParams on the same type, so both
// accessors are safe to call through a nil *BifrostVideoGenerationRequest.
func (b *BifrostVideoGenerationRequest) GetRawRequestBody() []byte {
	if b == nil {
		return nil
	}
	return b.RawRequestBody
}
// GetExtraParams returns Params.ExtraParams, or nil when either the receiver
// or its Params is nil.
func (b *BifrostVideoGenerationRequest) GetExtraParams() map[string]interface{} {
	if b != nil && b.Params != nil {
		return b.Params.ExtraParams
	}
	return nil
}
// VideoGenerationInput is the input of a video generation or remix request:
// a text prompt and an optional reference image.
type VideoGenerationInput struct {
Prompt string `json:"prompt"`
InputReference *string `json:"input_reference,omitempty"` // Primary image for image-to-video (OpenAI-compatible)
}
// VideoGenerationParameters holds the optional parameters of a video
// generation request. All JSON-visible fields are tagged omitempty;
// ExtraParams is excluded from JSON and exposed through
// BifrostVideoGenerationRequest.GetExtraParams.
type VideoGenerationParameters struct {
// Seconds is the requested duration, carried as a string (compare DefaultVideoDuration).
Seconds *string `json:"seconds,omitempty"`
Size string `json:"size,omitempty"`
NegativePrompt *string `json:"negative_prompt,omitempty"`
Seed *int `json:"seed,omitempty"`
VideoURI *string `json:"video_uri,omitempty"` // for video to video generation
Audio *bool `json:"audio,omitempty"`
ExtraParams map[string]any `json:"-"`
}
// DefaultVideoDuration is the default video duration in seconds for Gemini/Vertex when not specified.
// It is a string because the Seconds fields it backfills are *string.
const DefaultVideoDuration = "8"
// BifrostVideoGenerationResponse represents the video generation job response in bifrost format.
// Every field is tagged omitempty. Seconds and Model may be filled from the
// original request by BackfillParams.
type BifrostVideoGenerationResponse struct {
ID string `json:"id,omitempty"`
CompletedAt *int64 `json:"completed_at,omitempty"` // Unix timestamp (seconds) when the job completed
CreatedAt int64 `json:"created_at,omitempty"` // Unix timestamp (seconds) when the job was created
Error *VideoCreateError `json:"error,omitempty"` // Error payload if generation failed
ExpiresAt *int64 `json:"expires_at,omitempty"` // Unix timestamp (seconds) when downloadable assets expire
Model string `json:"model,omitempty"` // Video generation model that produced the job
Object string `json:"object,omitempty"` // Object type, always "video"
Progress *float64 `json:"progress,omitempty"` // Approximate completion percentage (0-100)
Prompt string `json:"prompt,omitempty"` // Prompt used to generate the video
RemixedFromVideoID *string `json:"remixed_from_video_id,omitempty"` // Source video ID if this is a remix
Seconds *string `json:"seconds,omitempty"` // Duration of the generated clip in seconds
Size string `json:"size,omitempty"` // Resolution of the generated video
Status VideoStatus `json:"status,omitempty"` // Current lifecycle status of the video job
Videos []VideoOutput `json:"videos,omitempty"` // Generated videos (supports multiple videos)
ContentFilter *ContentFilterInfo `json:"content_filter,omitempty"` // Information about content filtering (if applicable)
ExtraFields BifrostResponseExtraFields `json:"extra_fields,omitempty"`
}
// getSecondsFromVideoRequest resolves the clip duration for a video request.
//
// For a generation request it returns Params.Seconds when set; otherwise it
// falls back to DefaultVideoDuration for Gemini and Vertex and to nil for any
// other provider. For a remix request it returns DefaultVideoDuration for
// Gemini and Vertex and nil otherwise. A nil request, or one carrying neither
// kind of video request, yields nil.
func getSecondsFromVideoRequest(req *BifrostRequest) *string {
	if req == nil {
		return nil
	}

	// Only these providers get a default duration when none was requested.
	hasDefaultDuration := func(provider ModelProvider) bool {
		return provider == Gemini || provider == Vertex
	}

	if gen := req.VideoGenerationRequest; gen != nil {
		if gen.Params != nil && gen.Params.Seconds != nil {
			return gen.Params.Seconds
		}
		if hasDefaultDuration(gen.Provider) {
			return Ptr(DefaultVideoDuration)
		}
		return nil
	}

	if remix := req.VideoRemixRequest; remix != nil && hasDefaultDuration(remix.Provider) {
		return Ptr(DefaultVideoDuration)
	}
	return nil
}
// BackfillParams fills response fields from the original request. These are
// values needed for cost calculation that the provider may not return:
//   - Seconds: set whenever getSecondsFromVideoRequest yields a value
//     (duration from the request params, or the provider default).
//   - Model: copied from the generation request only when the response has none.
//
// It is a no-op when the receiver or req is nil.
func (r *BifrostVideoGenerationResponse) BackfillParams(req *BifrostRequest) {
	if r == nil || req == nil {
		return
	}
	if duration := getSecondsFromVideoRequest(req); duration != nil {
		r.Seconds = duration
	}
	if gen := req.VideoGenerationRequest; gen != nil && r.Model == "" {
		r.Model = gen.Model
	}
}
// --- Video Remix ---
// BifrostVideoRemixRequest is the request payload for remixing a video.
// ExtraParams and RawRequestBody are excluded from JSON and exposed through
// GetExtraParams and GetRawRequestBody.
// NOTE(review): ID is presumably the ID of the source video being remixed — confirm.
type BifrostVideoRemixRequest struct {
ID string `json:"id"`
Provider ModelProvider `json:"provider"`
Input *VideoGenerationInput `json:"input"`
ExtraParams map[string]any `json:"-"`
RawRequestBody []byte `json:"-"`
}
// GetRawRequestBody returns the raw request body stored on the request.
// It returns nil when the receiver is nil or when no raw body was set.
//
// The nil-receiver guard matches GetExtraParams on the same type, so both
// accessors are safe to call through a nil *BifrostVideoRemixRequest.
func (b *BifrostVideoRemixRequest) GetRawRequestBody() []byte {
	if b == nil {
		return nil
	}
	return b.RawRequestBody
}
// GetExtraParams returns the request's ExtraParams map, or nil when the
// receiver is nil.
func (b *BifrostVideoRemixRequest) GetExtraParams() map[string]interface{} {
	if b != nil {
		return b.ExtraParams
	}
	return nil
}
// --- Video List ---
// BifrostVideoListRequest is the request payload for listing videos of a
// provider. After, Limit and Order are pointers tagged omitempty.
// NOTE(review): After/Limit/Order look like cursor-pagination parameters — confirm semantics per provider.
type BifrostVideoListRequest struct {
Provider ModelProvider `json:"provider"`
After *string `json:"after,omitempty"`
Limit *int `json:"limit,omitempty"`
Order *string `json:"order,omitempty"`
}
// BifrostVideoListResponse is the response of a video list call: the returned
// VideoObject entries plus optional paging fields (FirstID, HasMore, LastID,
// all tagged omitempty).
type BifrostVideoListResponse struct {
Object string `json:"object"` // "list"
Data []VideoObject `json:"data"`
FirstID *string `json:"first_id,omitempty"`
HasMore *bool `json:"has_more,omitempty"`
LastID *string `json:"last_id,omitempty"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// --- Video Retrieve / Delete ---
// BifrostVideoReferenceRequest identifies a single video by provider and ID.
type BifrostVideoReferenceRequest struct {
Provider ModelProvider `json:"provider"`
ID string `json:"id"`
}
// BifrostVideoDeleteRequest is a type alias of BifrostVideoReferenceRequest.
type BifrostVideoDeleteRequest = BifrostVideoReferenceRequest
// BifrostVideoRetrieveRequest is a type alias of BifrostVideoReferenceRequest.
type BifrostVideoRetrieveRequest = BifrostVideoReferenceRequest
// BifrostVideoDeleteResponse is the response of a video delete call: the video
// ID and a flag saying whether it was deleted.
type BifrostVideoDeleteResponse struct {
ID string `json:"id"`
Deleted bool `json:"deleted"`
Object string `json:"object,omitempty"` // "video.deleted"
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// --- Video Download ---
// BifrostVideoDownloadRequest is the request payload for downloading content
// of a video. Variant is optional (pointer, omitempty); ExtraParams is
// excluded from JSON.
type BifrostVideoDownloadRequest struct {
Provider ModelProvider `json:"provider"`
ID string `json:"id"`
Variant *VideoDownloadVariant `json:"variant,omitempty"`
ExtraParams map[string]any `json:"-"`
}
// VideoDownloadVariant selects which asset a BifrostVideoDownloadRequest asks for.
type VideoDownloadVariant string
// Variant values declared for VideoDownloadVariant.
const (
VideoDownloadVariantVideo VideoDownloadVariant = "video"
VideoDownloadVariantThumbnail VideoDownloadVariant = "thumbnail"
VideoDownloadVariantSpriteSheet VideoDownloadVariant = "sprite_sheet"
)
// BifrostVideoDownloadResponse is the response of a video download call.
// Content holds the raw bytes and is excluded from JSON; only the video ID,
// content type and extra fields are serialized.
type BifrostVideoDownloadResponse struct {
VideoID string `json:"video_id"`
Content []byte `json:"-"` // Raw video content (not serialized)
ContentType string `json:"content_type,omitempty"` // MIME type (e.g., "video/mp4", "image/png" for thumbnails)
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// VideoLogParams carries the ID of a video, serialized as "video_id".
// NOTE(review): presumably used when logging video operations — confirm against callers.
type VideoLogParams struct {
VideoID string `json:"video_id"`
}

103
core/schemas/websocket.go Normal file
View File

@@ -0,0 +1,103 @@
package schemas
import "encoding/json"
// WebSocketEventType represents event types in the Responses API WebSocket protocol.
type WebSocketEventType string
// Event type values declared for WebSocketEventType.
const (
WSEventResponseCreate WebSocketEventType = "response.create"
WSEventError WebSocketEventType = "error"
)
// WebSocketResponsesEvent represents a client-sent event over the Responses WebSocket connection.
// The payload mirrors the Responses API create body, with transport-specific fields
// (stream, background) omitted since they're implicit in the WebSocket context.
// Input, Tools, ToolChoice, Reasoning, Metadata and Text are kept as
// json.RawMessage, i.e. they are not decoded by this type. Every field except
// Type is tagged omitempty.
type WebSocketResponsesEvent struct {
Type WebSocketEventType `json:"type"`
Model string `json:"model,omitempty"`
Store *bool `json:"store,omitempty"`
Input json.RawMessage `json:"input,omitempty"`
Instructions string `json:"instructions,omitempty"`
PreviousResponseID string `json:"previous_response_id,omitempty"`
Generate *bool `json:"generate,omitempty"`
Tools json.RawMessage `json:"tools,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
Reasoning json.RawMessage `json:"reasoning,omitempty"`
Metadata json.RawMessage `json:"metadata,omitempty"`
Text json.RawMessage `json:"text,omitempty"`
Truncation string `json:"truncation,omitempty"`
}
// WebSocketErrorEvent represents a server-sent error event over WebSocket.
// Status and Error are tagged omitempty.
// NOTE(review): Status is presumably an HTTP-style status code — confirm.
type WebSocketErrorEvent struct {
Type WebSocketEventType `json:"type"`
Status int `json:"status,omitempty"`
Error *WebSocketErrorBody `json:"error,omitempty"`
}
// WebSocketErrorBody is the error detail within a WebSocketErrorEvent.
// All fields are tagged omitempty.
type WebSocketErrorBody struct {
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
Param string `json:"param,omitempty"`
}
// WebSocketConfig provides optional tuning for WebSocket gateway features.
// WebSocket is always enabled. These fields allow overriding the high defaults.
// Note that MaxConnections is serialized as "max_connections_per_user".
type WebSocketConfig struct {
	MaxConnections       int           `json:"max_connections_per_user"`
	TranscriptBufferSize int           `json:"transcript_buffer_size"`
	Pool                 *WSPoolConfig `json:"pool,omitempty"`
}

// WSPoolConfig configures the upstream WebSocket connection pool.
type WSPoolConfig struct {
	MaxIdlePerKey                int `json:"max_idle_per_key"`
	MaxTotalConnections          int `json:"max_total_connections"`
	IdleTimeoutSeconds           int `json:"idle_timeout_seconds"`
	MaxConnectionLifetimeSeconds int `json:"max_connection_lifetime_seconds"`
}

// Default pool configuration values (set high for production workloads)
const (
	DefaultWSMaxIdlePerKey                = 50
	DefaultWSMaxTotalConnections          = 1000
	DefaultWSIdleTimeoutSeconds           = 600
	DefaultWSMaxConnectionLifetimeSeconds = 7200
	DefaultWSMaxConnections               = 100
	DefaultWSTranscriptBufferSize         = 100
)

// CheckAndSetDefaults replaces every non-positive numeric field of the config
// with its default, allocates Pool when it is nil, and then applies the pool
// defaults as well. Positive values are left untouched.
func (c *WebSocketConfig) CheckAndSetDefaults() {
	if c.Pool == nil {
		c.Pool = new(WSPoolConfig)
	}
	c.Pool.CheckAndSetDefaults()

	if c.TranscriptBufferSize < 1 {
		c.TranscriptBufferSize = DefaultWSTranscriptBufferSize
	}
	if c.MaxConnections < 1 {
		c.MaxConnections = DefaultWSMaxConnections
	}
}

// CheckAndSetDefaults replaces every non-positive field of the pool config
// with its default. Positive values are left untouched.
func (c *WSPoolConfig) CheckAndSetDefaults() {
	settings := [...]struct {
		field    *int
		fallback int
	}{
		{&c.MaxIdlePerKey, DefaultWSMaxIdlePerKey},
		{&c.MaxTotalConnections, DefaultWSMaxTotalConnections},
		{&c.IdleTimeoutSeconds, DefaultWSIdleTimeoutSeconds},
		{&c.MaxConnectionLifetimeSeconds, DefaultWSMaxConnectionLifetimeSeconds},
	}
	for _, s := range settings {
		if *s.field <= 0 {
			*s.field = s.fallback
		}
	}
}

View File

@@ -0,0 +1,59 @@
package schemas
import (
"testing"
)
// TestWebSocketConfig_CheckAndSetDefaults verifies that a zero-value
// WebSocketConfig receives every default, including an allocated Pool with all
// of its defaults applied.
func TestWebSocketConfig_CheckAndSetDefaults(t *testing.T) {
	cfg := new(WebSocketConfig)
	cfg.CheckAndSetDefaults()

	if got, want := cfg.MaxConnections, DefaultWSMaxConnections; got != want {
		t.Errorf("expected MaxConnections=%d, got %d", want, got)
	}
	if got, want := cfg.TranscriptBufferSize, DefaultWSTranscriptBufferSize; got != want {
		t.Errorf("expected TranscriptBufferSize=%d, got %d", want, got)
	}

	// The remaining checks dereference Pool, so stop here if it is missing.
	pool := cfg.Pool
	if pool == nil {
		t.Fatal("expected Pool to be initialized")
	}
	if got, want := pool.MaxIdlePerKey, DefaultWSMaxIdlePerKey; got != want {
		t.Errorf("expected Pool.MaxIdlePerKey=%d, got %d", want, got)
	}
	if got, want := pool.MaxTotalConnections, DefaultWSMaxTotalConnections; got != want {
		t.Errorf("expected Pool.MaxTotalConnections=%d, got %d", want, got)
	}
	if got, want := pool.IdleTimeoutSeconds, DefaultWSIdleTimeoutSeconds; got != want {
		t.Errorf("expected Pool.IdleTimeoutSeconds=%d, got %d", want, got)
	}
	if got, want := pool.MaxConnectionLifetimeSeconds, DefaultWSMaxConnectionLifetimeSeconds; got != want {
		t.Errorf("expected Pool.MaxConnectionLifetimeSeconds=%d, got %d", want, got)
	}
}
// TestWebSocketConfig_PreservesExistingValues verifies that CheckAndSetDefaults
// leaves every explicitly configured positive value untouched, for both the
// top-level fields and all four pool fields.
func TestWebSocketConfig_PreservesExistingValues(t *testing.T) {
	config := &WebSocketConfig{
		MaxConnections:       20,
		TranscriptBufferSize: 123,
		Pool: &WSPoolConfig{
			MaxIdlePerKey:                5,
			MaxTotalConnections:          50,
			IdleTimeoutSeconds:           60,
			MaxConnectionLifetimeSeconds: 1800,
		},
	}
	config.CheckAndSetDefaults()
	if config.MaxConnections != 20 {
		t.Errorf("expected MaxConnections=20, got %d", config.MaxConnections)
	}
	if config.TranscriptBufferSize != 123 {
		t.Errorf("expected TranscriptBufferSize=123, got %d", config.TranscriptBufferSize)
	}
	if config.Pool.MaxIdlePerKey != 5 {
		t.Errorf("expected Pool.MaxIdlePerKey=5, got %d", config.Pool.MaxIdlePerKey)
	}
	if config.Pool.MaxTotalConnections != 50 {
		t.Errorf("expected Pool.MaxTotalConnections=50, got %d", config.Pool.MaxTotalConnections)
	}
	// These two fields are set above but were previously never asserted, so a
	// regression that overwrote them would have gone unnoticed.
	if config.Pool.IdleTimeoutSeconds != 60 {
		t.Errorf("expected Pool.IdleTimeoutSeconds=60, got %d", config.Pool.IdleTimeoutSeconds)
	}
	if config.Pool.MaxConnectionLifetimeSeconds != 1800 {
		t.Errorf("expected Pool.MaxConnectionLifetimeSeconds=1800, got %d", config.Pool.MaxConnectionLifetimeSeconds)
	}
}