first commit
This commit is contained in:
294
core/schemas/account.go
Normal file
294
core/schemas/account.go
Normal file
@@ -0,0 +1,294 @@
|
||||
// Package schemas defines the core schemas and types used by the Bifrost system.
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type KeyStatusType string
|
||||
|
||||
const (
|
||||
KeyStatusSuccess KeyStatusType = "success"
|
||||
KeyStatusListModelsFailed KeyStatusType = "list_models_failed"
|
||||
)
|
||||
|
||||
// WhiteList is a list of values that are allowed to be used.
|
||||
// Semantics:
|
||||
// - "*" (alone) means all values are allowed.
|
||||
// - Empty list means nothing is allowed.
|
||||
// - Non-empty list (without "*") means only the listed values are allowed.
|
||||
//
|
||||
// This type is used generically for any field that needs whitelist behavior
|
||||
// (e.g., allowed models, allowed tools).
|
||||
type WhiteList []string
|
||||
|
||||
// Contains reports whether value is in the whitelist.
|
||||
// Returns true if value is in the list.
|
||||
func (wl WhiteList) Contains(value string) bool {
|
||||
return slices.ContainsFunc(wl, func(s string) bool {
|
||||
return strings.EqualFold(s, value)
|
||||
})
|
||||
}
|
||||
|
||||
// IsAllowed reports whether value is in the whitelist.
|
||||
// Returns true if value is in the list.
|
||||
func (wl WhiteList) IsAllowed(value string) bool {
|
||||
return wl.IsUnrestricted() || wl.Contains(value)
|
||||
}
|
||||
|
||||
// IsEmpty reports whether the whitelist has no entries.
|
||||
func (wl WhiteList) IsEmpty() bool {
|
||||
return len(wl) == 0
|
||||
}
|
||||
|
||||
// IsUnrestricted reports whether the whitelist contains only "*",
|
||||
// meaning all values are allowed.
|
||||
func (wl WhiteList) IsUnrestricted() bool {
|
||||
return len(wl) == 1 && wl[0] == "*"
|
||||
}
|
||||
|
||||
// IsRestricted reports whether the whitelist contains entries other than "*",
|
||||
// meaning only the listed values are allowed.
|
||||
func (wl WhiteList) IsRestricted() bool {
|
||||
return !wl.IsUnrestricted()
|
||||
}
|
||||
|
||||
// Validate checks that the whitelist is well-formed.
|
||||
// Returns an error if "*" is present alongside other values, or if there are duplicate entries.
|
||||
func (wl WhiteList) Validate() error {
|
||||
if wl.Contains("*") && len(wl) > 1 {
|
||||
return fmt.Errorf("wildcard '*' cannot be used with other values in the whitelist")
|
||||
}
|
||||
seen := make(map[string]struct{}, len(wl))
|
||||
for _, v := range wl {
|
||||
normalized := strings.ToLower(v)
|
||||
if _, ok := seen[normalized]; ok {
|
||||
return fmt.Errorf("duplicate value '%s' in whitelist", v)
|
||||
}
|
||||
seen[normalized] = struct{}{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BlackList is a list of values that are denied.
|
||||
// Semantics:
|
||||
// - "*" (alone) means all values are blocked.
|
||||
// - Empty list means nothing is blocked.
|
||||
// - Non-empty list (without "*") means only the listed values are blocked.
|
||||
type BlackList []string
|
||||
|
||||
func (bl BlackList) Contains(value string) bool {
|
||||
return slices.ContainsFunc(bl, func(s string) bool {
|
||||
return strings.EqualFold(s, value)
|
||||
})
|
||||
}
|
||||
|
||||
// IsBlocked reports whether value is blocked.
|
||||
func (bl BlackList) IsBlocked(value string) bool {
|
||||
return bl.IsBlockAll() || bl.Contains(value)
|
||||
}
|
||||
|
||||
// IsEmpty reports whether the blacklist has no entries (nothing is blocked).
|
||||
func (bl BlackList) IsEmpty() bool {
|
||||
return len(bl) == 0
|
||||
}
|
||||
|
||||
// IsBlockAll reports whether the blacklist contains "*", meaning all values are blocked.
|
||||
func (bl BlackList) IsBlockAll() bool {
|
||||
return len(bl) == 1 && bl[0] == "*"
|
||||
}
|
||||
|
||||
// Validate checks that the blacklist is well-formed.
|
||||
func (bl BlackList) Validate() error {
|
||||
if bl.Contains("*") && len(bl) > 1 {
|
||||
return fmt.Errorf("wildcard '*' cannot be used with other values in the blacklist")
|
||||
}
|
||||
seen := make(map[string]struct{}, len(bl))
|
||||
for _, v := range bl {
|
||||
normalized := strings.ToLower(v)
|
||||
if _, ok := seen[normalized]; ok {
|
||||
return fmt.Errorf("duplicate value '%s' in blacklist", v)
|
||||
}
|
||||
seen[normalized] = struct{}{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Key represents an API key and its associated configuration for a provider.
|
||||
// It contains the key value, supported models, and a weight for load balancing.
|
||||
type Key struct {
|
||||
ID string `json:"id"` // The unique identifier for the key (used by bifrost to identify the key)
|
||||
Name string `json:"name"` // The name of the key (used by users to identify the key, not used by bifrost)
|
||||
Value EnvVar `json:"value"` // The actual API key value
|
||||
Models WhiteList `json:"models"` // List of models this key can access
|
||||
BlacklistedModels BlackList `json:"blacklisted_models"` // List of models this key cannot access
|
||||
Weight float64 `json:"weight"` // Weight for load balancing between multiple keys
|
||||
Aliases KeyAliases `json:"aliases,omitempty"` // Mapping of model identifiers to inference profiles
|
||||
AzureKeyConfig *AzureKeyConfig `json:"azure_key_config,omitempty"` // Azure-specific key configuration
|
||||
VertexKeyConfig *VertexKeyConfig `json:"vertex_key_config,omitempty"` // Vertex-specific key configuration
|
||||
BedrockKeyConfig *BedrockKeyConfig `json:"bedrock_key_config,omitempty"` // AWS Bedrock-specific key configuration
|
||||
VLLMKeyConfig *VLLMKeyConfig `json:"vllm_key_config,omitempty"` // vLLM-specific key configuration
|
||||
ReplicateKeyConfig *ReplicateKeyConfig `json:"replicate_key_config,omitempty"` // Replicate-specific key configuration
|
||||
OllamaKeyConfig *OllamaKeyConfig `json:"ollama_key_config,omitempty"` // Ollama-specific key configuration
|
||||
SGLKeyConfig *SGLKeyConfig `json:"sgl_key_config,omitempty"` // SGLang-specific key configuration
|
||||
Enabled *bool `json:"enabled,omitempty"` // Whether the key is active (default:true)
|
||||
UseForBatchAPI *bool `json:"use_for_batch_api,omitempty"` // Whether this key can be used for batch API operations (default:false for new keys, migrated keys default to true)
|
||||
ConfigHash string `json:"config_hash,omitempty"` // Hash of config.json version, used for change detection
|
||||
Status KeyStatusType `json:"status,omitempty"` // Status of key
|
||||
Description string `json:"description,omitempty"` // Description of key
|
||||
}
|
||||
|
||||
type KeyAliases map[string]string
|
||||
|
||||
func (ka KeyAliases) Validate() error {
|
||||
seen := make(map[string]struct{}, len(ka))
|
||||
for from, to := range ka {
|
||||
if strings.TrimSpace(from) == "" {
|
||||
return fmt.Errorf("alias source cannot be empty")
|
||||
}
|
||||
if strings.TrimSpace(to) == "" {
|
||||
return fmt.Errorf("alias target for %q cannot be empty", from)
|
||||
}
|
||||
if strings.TrimSpace(from) != from {
|
||||
return fmt.Errorf("alias source %q cannot have leading or trailing whitespace", from)
|
||||
}
|
||||
if strings.TrimSpace(to) != to {
|
||||
return fmt.Errorf("alias target for %q cannot have leading or trailing whitespace", from)
|
||||
}
|
||||
normalized := strings.ToLower(from)
|
||||
if _, ok := seen[normalized]; ok {
|
||||
return fmt.Errorf("duplicate alias source %q (case-insensitive)", from)
|
||||
}
|
||||
seen[normalized] = struct{}{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ka KeyAliases) Resolve(model string) string {
|
||||
if ka == nil {
|
||||
return model
|
||||
}
|
||||
if alias, ok := ka[model]; ok {
|
||||
return alias
|
||||
}
|
||||
// Fall back to case-insensitive lookup for consistency with WhiteList.Contains
|
||||
for k, v := range ka {
|
||||
if strings.EqualFold(k, model) {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
type AzureAuthType string
|
||||
|
||||
const (
|
||||
AzureAuthTypeClientSecret AzureAuthType = "client_secret"
|
||||
AzureAuthTypeManagedIdentity AzureAuthType = "managed_identity"
|
||||
)
|
||||
|
||||
// AzureKeyConfig represents the Azure-specific configuration.
|
||||
// It contains Azure-specific settings required for service access and deployment management.
|
||||
type AzureKeyConfig struct {
|
||||
Endpoint EnvVar `json:"endpoint"` // Azure service endpoint URL
|
||||
APIVersion *EnvVar `json:"api_version,omitempty"` // Azure API version to use; defaults to "2024-10-21"
|
||||
|
||||
ClientID *EnvVar `json:"client_id,omitempty"` // Azure client ID for authentication
|
||||
ClientSecret *EnvVar `json:"client_secret,omitempty"` // Azure client secret for authentication
|
||||
TenantID *EnvVar `json:"tenant_id,omitempty"` // Azure tenant ID for authentication
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
}
|
||||
|
||||
// VertexKeyConfig represents the Vertex-specific configuration.
|
||||
// It contains Vertex-specific settings required for authentication and service access.
|
||||
type VertexKeyConfig struct {
|
||||
ProjectID EnvVar `json:"project_id"`
|
||||
ProjectNumber EnvVar `json:"project_number"`
|
||||
Region EnvVar `json:"region"`
|
||||
AuthCredentials EnvVar `json:"auth_credentials"`
|
||||
}
|
||||
|
||||
// NOTE: To use Vertex IAM role authentication, set AuthCredentials to empty string.
|
||||
|
||||
// S3BucketConfig represents a single S3 bucket configuration for batch operations.
|
||||
type S3BucketConfig struct {
|
||||
BucketName string `json:"bucket_name"` // S3 bucket name
|
||||
Prefix string `json:"prefix,omitempty"` // S3 key prefix for batch files
|
||||
IsDefault bool `json:"is_default,omitempty"` // Whether this is the default bucket for batch operations
|
||||
}
|
||||
|
||||
// BatchS3Config holds S3 bucket configurations for Bedrock batch operations.
|
||||
// Supports multiple buckets to allow flexible batch job routing.
|
||||
type BatchS3Config struct {
|
||||
Buckets []S3BucketConfig `json:"buckets,omitempty"` // List of S3 bucket configurations
|
||||
}
|
||||
|
||||
// BedrockKeyConfig represents the AWS Bedrock-specific configuration.
|
||||
// It contains AWS-specific settings required for authentication and service access.
|
||||
type BedrockKeyConfig struct {
|
||||
AccessKey EnvVar `json:"access_key,omitempty"` // AWS access key for authentication
|
||||
SecretKey EnvVar `json:"secret_key,omitempty"` // AWS secret access key for authentication
|
||||
SessionToken *EnvVar `json:"session_token,omitempty"` // AWS session token for temporary credentials
|
||||
Region *EnvVar `json:"region,omitempty"` // AWS region for service access
|
||||
ARN *EnvVar `json:"arn,omitempty"` // Amazon Resource Name for resource identification
|
||||
// IAM role for STS AssumeRole
|
||||
RoleARN *EnvVar `json:"role_arn,omitempty"`
|
||||
ExternalID *EnvVar `json:"external_id,omitempty"`
|
||||
RoleSessionName *EnvVar `json:"session_name,omitempty"`
|
||||
|
||||
BatchS3Config *BatchS3Config `json:"batch_s3_config,omitempty"` // S3 bucket configuration for batch operations
|
||||
}
|
||||
|
||||
// NOTE: To use Bedrock IAM role authentication, set both AccessKey and SecretKey to empty strings.
|
||||
// To use Bedrock API Key authentication, set Value in Key struct instead.
|
||||
|
||||
// VLLMKeyConfig represents the vLLM-specific key configuration.
|
||||
// It allows each key to target a different vLLM server URL and model name,
|
||||
// enabling per-key routing and round-robin load balancing across multiple vLLM instances.
|
||||
type VLLMKeyConfig struct {
|
||||
URL EnvVar `json:"url"` // VLLM server base URL (required, supports env. prefix)
|
||||
ModelName string `json:"model_name"` // Exact model name served on this VLLM instance (used for key selection)
|
||||
}
|
||||
|
||||
// ReplicateKeyConfig represents the Replicate-specific key configuration.
|
||||
// It contains Replicate-specific settings required for authentication and service access.
|
||||
type ReplicateKeyConfig struct {
|
||||
UseDeploymentsEndpoint bool `json:"use_deployments_endpoint"` // Whether to use the deployments endpoint instead of the models endpoint
|
||||
}
|
||||
|
||||
// OllamaKeyConfig represents the Ollama-specific key configuration.
|
||||
// It allows each key to target a different Ollama server URL,
|
||||
// enabling per-key routing and round-robin load balancing across multiple Ollama instances.
|
||||
type OllamaKeyConfig struct {
|
||||
URL EnvVar `json:"url"` // Ollama server base URL (required, supports env. prefix)
|
||||
}
|
||||
|
||||
// SGLKeyConfig represents the SGLang-specific key configuration.
|
||||
// It allows each key to target a different SGLang server URL,
|
||||
// enabling per-key routing and round-robin load balancing across multiple SGLang instances.
|
||||
type SGLKeyConfig struct {
|
||||
URL EnvVar `json:"url"` // SGLang server base URL (required, supports env. prefix)
|
||||
}
|
||||
|
||||
// Account defines the interface for managing provider accounts and their configurations.
|
||||
// It provides methods to access provider-specific settings, API keys, and configurations.
|
||||
type Account interface {
|
||||
// GetConfiguredProviders returns a list of providers that are configured
|
||||
// in the account. This is used to determine which providers are available for use.
|
||||
GetConfiguredProviders() ([]ModelProvider, error)
|
||||
|
||||
// GetKeysForProvider returns the API keys configured for a specific provider.
|
||||
// The keys include their values, supported models, and weights for load balancing.
|
||||
// The context can carry data from any source that sets values before the Bifrost request,
|
||||
// including but not limited to plugin pre-hooks, application logic, or any in app middleware sharing the context.
|
||||
// This enables dynamic key selection based on any context values present during the request.
|
||||
GetKeysForProvider(ctx context.Context, providerKey ModelProvider) ([]Key, error)
|
||||
|
||||
// GetConfigForProvider returns the configuration for a specific provider.
|
||||
// This includes network settings, authentication details, and other provider-specific
|
||||
// configurations.
|
||||
GetConfigForProvider(providerKey ModelProvider) (*ProviderConfig, error)
|
||||
}
|
||||
34
core/schemas/async.go
Normal file
34
core/schemas/async.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package schemas
|
||||
|
||||
import "time"
|
||||
|
||||
// AsyncJobStatus represents the status of an async job
|
||||
type AsyncJobStatus string
|
||||
|
||||
const (
|
||||
AsyncJobStatusPending AsyncJobStatus = "pending"
|
||||
AsyncJobStatusProcessing AsyncJobStatus = "processing"
|
||||
AsyncJobStatusCompleted AsyncJobStatus = "completed"
|
||||
AsyncJobStatusFailed AsyncJobStatus = "failed"
|
||||
)
|
||||
|
||||
const (
|
||||
// AsyncHeaderResultTTL is the header containing the result TTL for async job retrieval.
|
||||
AsyncHeaderResultTTL = "x-bf-async-job-result-ttl"
|
||||
// AsyncHeaderCreate is the header that triggers async job creation on integration routes.
|
||||
AsyncHeaderCreate = "x-bf-async"
|
||||
// AsyncHeaderGetID is the header containing the job ID for async job retrieval on integration routes.
|
||||
AsyncHeaderGetID = "x-bf-async-id"
|
||||
)
|
||||
|
||||
// AsyncJobResponse is the JSON response returned when creating or polling an async job
|
||||
type AsyncJobResponse struct {
|
||||
ID string `json:"id"`
|
||||
Status AsyncJobStatus `json:"status"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
CompletedAt *time.Time `json:"completed_at,omitempty"`
|
||||
StatusCode int `json:"status_code,omitempty"`
|
||||
Result interface{} `json:"result,omitempty"`
|
||||
Error *BifrostError `json:"error,omitempty"`
|
||||
}
|
||||
329
core/schemas/batch.go
Normal file
329
core/schemas/batch.go
Normal file
@@ -0,0 +1,329 @@
|
||||
// Package schemas defines the core schemas and types used by the Bifrost system.
|
||||
package schemas
|
||||
|
||||
// BatchStatus represents the status of a batch job.
|
||||
type BatchStatus string
|
||||
|
||||
const (
|
||||
BatchStatusValidating BatchStatus = "validating"
|
||||
BatchStatusFailed BatchStatus = "failed"
|
||||
BatchStatusInProgress BatchStatus = "in_progress"
|
||||
BatchStatusFinalizing BatchStatus = "finalizing"
|
||||
BatchStatusCompleted BatchStatus = "completed"
|
||||
BatchStatusExpired BatchStatus = "expired"
|
||||
BatchStatusCancelling BatchStatus = "cancelling"
|
||||
BatchStatusCancelled BatchStatus = "cancelled"
|
||||
BatchStatusEnded BatchStatus = "ended" // Anthropic-specific
|
||||
BatchStatusDeleted BatchStatus = "deleted" // Gemini-specific
|
||||
)
|
||||
|
||||
// BatchEndpoint represents supported batch API endpoints.
|
||||
type BatchEndpoint string
|
||||
|
||||
const (
|
||||
BatchEndpointChatCompletions BatchEndpoint = "/v1/chat/completions"
|
||||
BatchEndpointEmbeddings BatchEndpoint = "/v1/embeddings"
|
||||
BatchEndpointCompletions BatchEndpoint = "/v1/completions"
|
||||
BatchEndpointResponses BatchEndpoint = "/v1/responses"
|
||||
BatchEndpointMessages BatchEndpoint = "/v1/messages" // Anthropic
|
||||
)
|
||||
|
||||
// BatchRequestItem represents a single request in a batch (for inline requests).
|
||||
type BatchRequestItem struct {
|
||||
CustomID string `json:"custom_id"` // User-provided unique ID for this request
|
||||
Method string `json:"method,omitempty"` // HTTP method (typically "POST")
|
||||
URL string `json:"url,omitempty"` // Endpoint URL (e.g., "/v1/chat/completions")
|
||||
Body map[string]interface{} `json:"body,omitempty"` // Request body parameters
|
||||
Params map[string]interface{} `json:"params,omitempty"` // Alternative to Body for Anthropic
|
||||
}
|
||||
|
||||
// BatchRequestCounts tracks the counts of requests in different states.
|
||||
type BatchRequestCounts struct {
|
||||
Total int `json:"total"`
|
||||
Completed int `json:"completed"`
|
||||
Failed int `json:"failed"`
|
||||
Succeeded int `json:"succeeded,omitempty"` // Anthropic-specific
|
||||
Expired int `json:"expired,omitempty"` // Anthropic-specific
|
||||
Canceled int `json:"canceled,omitempty"` // Anthropic-specific
|
||||
Pending int `json:"pending,omitempty"` // Anthropic-specific
|
||||
}
|
||||
|
||||
// BatchErrors represents errors encountered during batch processing.
|
||||
type BatchErrors struct {
|
||||
Object string `json:"object,omitempty"`
|
||||
Data []BatchError `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// BatchError represents a single error in batch processing.
|
||||
type BatchError struct {
|
||||
Code string `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Param string `json:"param,omitempty"`
|
||||
Line *int `json:"line,omitempty"`
|
||||
}
|
||||
|
||||
// BifrostBatchCreateRequest represents a request to create a batch job.
|
||||
type BifrostBatchCreateRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model *string `json:"model,omitempty"` // Model hint for routing (optional for file-based) it may or may not present depending on the provider and usage of integration vs direct API
|
||||
RawRequestBody []byte `json:"-"` // Raw request body (not serialized)
|
||||
|
||||
// OpenAI-style: file-based batching
|
||||
InputFileID string `json:"input_file_id,omitempty"` // ID of uploaded JSONL file
|
||||
|
||||
// Anthropic-style: inline requests
|
||||
Requests []BatchRequestItem `json:"requests,omitempty"` // Inline request items
|
||||
|
||||
// Common fields
|
||||
Endpoint BatchEndpoint `json:"endpoint,omitempty"` // Target endpoint for batch requests
|
||||
CompletionWindow string `json:"completion_window,omitempty"` // Time window (e.g., "24h")
|
||||
Metadata map[string]string `json:"metadata,omitempty"` // User-provided metadata
|
||||
OutputExpiresAfter *BatchExpiresAfter `json:"output_expires_after,omitempty"` // Expiration for batch output (OpenAI only)
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// BatchExpiresAfter represents an expiration configuration for batch output.
|
||||
type BatchExpiresAfter struct {
|
||||
Anchor string `json:"anchor"` // e.g., "created_at"
|
||||
Seconds int `json:"seconds"` // 3600-2592000 (1 hour to 30 days)
|
||||
}
|
||||
|
||||
// GetRawRequestBody returns the raw request body.
|
||||
func (request *BifrostBatchCreateRequest) GetRawRequestBody() []byte {
|
||||
return request.RawRequestBody
|
||||
}
|
||||
|
||||
// BifrostBatchCreateResponse represents the response from creating a batch job.
|
||||
type BifrostBatchCreateResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object,omitempty"` // "batch" for OpenAI
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
InputFileID string `json:"input_file_id,omitempty"`
|
||||
CompletionWindow string `json:"completion_window,omitempty"`
|
||||
Status BatchStatus `json:"status"`
|
||||
RequestCounts BatchRequestCounts `json:"request_counts,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
CreatedAt int64 `json:"created_at,omitempty"`
|
||||
ExpiresAt *int64 `json:"expires_at,omitempty"`
|
||||
|
||||
// Output file references (OpenAI)
|
||||
OutputFileID *string `json:"output_file_id,omitempty"`
|
||||
ErrorFileID *string `json:"error_file_id,omitempty"`
|
||||
|
||||
// Anthropic-specific
|
||||
ProcessingStatus *string `json:"processing_status,omitempty"`
|
||||
ResultsURL *string `json:"results_url,omitempty"`
|
||||
|
||||
// Gemini-specific (operation response)
|
||||
OperationName *string `json:"operation_name,omitempty"`
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// BifrostBatchListRequest represents a request to list batch jobs.
|
||||
type BifrostBatchListRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model *string `json:"model"`
|
||||
|
||||
// Pagination
|
||||
Limit int `json:"limit,omitempty"` // Max results to return
|
||||
After *string `json:"after,omitempty"` // Cursor for pagination (OpenAI)
|
||||
BeforeID *string `json:"before_id,omitempty"` // Pagination cursor (Anthropic)
|
||||
AfterID *string `json:"after_id,omitempty"` // Pagination cursor (Anthropic)
|
||||
PageToken *string `json:"page_token,omitempty"` // For Gemini pagination
|
||||
PageSize int `json:"page_size,omitempty"` // For Gemini pagination
|
||||
NextCursor *string `json:"next_cursor,omitempty"` // For Gemini pagination
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// BifrostBatchListResponse represents the response from listing batch jobs.
|
||||
type BifrostBatchListResponse struct {
|
||||
Object string `json:"object,omitempty"` // "list"
|
||||
Data []BifrostBatchRetrieveResponse `json:"data"`
|
||||
FirstID *string `json:"first_id,omitempty"`
|
||||
LastID *string `json:"last_id,omitempty"`
|
||||
HasMore bool `json:"has_more,omitempty"`
|
||||
|
||||
// Anthropic pagination
|
||||
NextCursor *string `json:"next_cursor,omitempty"` // For cursor-based pagination
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// BifrostBatchRetrieveRequest represents a request to retrieve a batch job.
|
||||
type BifrostBatchRetrieveRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model *string `json:"model"`
|
||||
BatchID string `json:"batch_id"` // ID of the batch to retrieve
|
||||
|
||||
RawRequestBody []byte `json:"-"` // Raw request body (not serialized)
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// GetRawRequestBody returns the raw request body.
|
||||
func (request *BifrostBatchRetrieveRequest) GetRawRequestBody() []byte {
|
||||
return request.RawRequestBody
|
||||
}
|
||||
|
||||
// BifrostBatchRetrieveResponse represents the response from retrieving a batch job.
|
||||
type BifrostBatchRetrieveResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object,omitempty"`
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
InputFileID string `json:"input_file_id,omitempty"`
|
||||
CompletionWindow string `json:"completion_window,omitempty"`
|
||||
Status BatchStatus `json:"status"`
|
||||
RequestCounts BatchRequestCounts `json:"request_counts,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
CreatedAt int64 `json:"created_at,omitempty"`
|
||||
ExpiresAt *int64 `json:"expires_at,omitempty"`
|
||||
InProgressAt *int64 `json:"in_progress_at,omitempty"`
|
||||
FinalizingAt *int64 `json:"finalizing_at,omitempty"`
|
||||
CompletedAt *int64 `json:"completed_at,omitempty"`
|
||||
FailedAt *int64 `json:"failed_at,omitempty"`
|
||||
ExpiredAt *int64 `json:"expired_at,omitempty"`
|
||||
CancellingAt *int64 `json:"cancelling_at,omitempty"`
|
||||
CancelledAt *int64 `json:"cancelled_at,omitempty"`
|
||||
|
||||
// Output references
|
||||
OutputFileID *string `json:"output_file_id,omitempty"`
|
||||
ErrorFileID *string `json:"error_file_id,omitempty"`
|
||||
Errors *BatchErrors `json:"errors,omitempty"`
|
||||
|
||||
// Anthropic-specific
|
||||
ProcessingStatus *string `json:"processing_status,omitempty"`
|
||||
ResultsURL *string `json:"results_url,omitempty"`
|
||||
ArchivedAt *int64 `json:"archived_at,omitempty"`
|
||||
|
||||
// Gemini-specific
|
||||
OperationName *string `json:"operation_name,omitempty"`
|
||||
Done *bool `json:"done,omitempty"`
|
||||
Progress *int `json:"progress,omitempty"` // Percentage progress
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// BifrostBatchCancelRequest represents a request to cancel a batch job.
|
||||
type BifrostBatchCancelRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model *string `json:"model"`
|
||||
BatchID string `json:"batch_id"` // ID of the batch to cancel
|
||||
|
||||
RawRequestBody []byte `json:"-"` // Raw request body (not serialized)
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// GetRawRequestBody returns the raw request body.
|
||||
func (request *BifrostBatchCancelRequest) GetRawRequestBody() []byte {
|
||||
return request.RawRequestBody
|
||||
}
|
||||
|
||||
// BifrostBatchCancelResponse represents the response from cancelling a batch job.
|
||||
type BifrostBatchCancelResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object,omitempty"`
|
||||
Status BatchStatus `json:"status"`
|
||||
RequestCounts BatchRequestCounts `json:"request_counts,omitempty"`
|
||||
CancellingAt *int64 `json:"cancelling_at,omitempty"`
|
||||
CancelledAt *int64 `json:"cancelled_at,omitempty"`
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// BifrostBatchDeleteRequest represents a request to delete a batch job.
|
||||
type BifrostBatchDeleteRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model *string `json:"model"`
|
||||
BatchID string `json:"batch_id"` // ID of the batch to delete
|
||||
|
||||
RawRequestBody []byte `json:"-"` // Raw request body (not serialized)
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// GetRawRequestBody returns the raw request body.
|
||||
func (request *BifrostBatchDeleteRequest) GetRawRequestBody() []byte {
|
||||
return request.RawRequestBody
|
||||
}
|
||||
|
||||
// BifrostBatchDeleteResponse represents the response from deleting a batch job.
|
||||
type BifrostBatchDeleteResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object,omitempty"`
|
||||
Status BatchStatus `json:"status"`
|
||||
RequestCounts BatchRequestCounts `json:"request_counts,omitempty"`
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// BifrostBatchResultsRequest represents a request to retrieve batch results.
|
||||
type BifrostBatchResultsRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model *string `json:"model"`
|
||||
BatchID string `json:"batch_id"` // ID of the batch to get results for
|
||||
|
||||
RawRequestBody []byte `json:"-"` // Raw request body (not serialized)
|
||||
|
||||
// For OpenAI, results are retrieved via output_file_id (file download)
|
||||
// For Anthropic, results are streamed from a dedicated endpoint
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// GetRawRequestBody returns the raw request body.
|
||||
func (request *BifrostBatchResultsRequest) GetRawRequestBody() []byte {
|
||||
return request.RawRequestBody
|
||||
}
|
||||
|
||||
// BatchResultItem represents a single result from a batch request.
|
||||
type BatchResultItem struct {
|
||||
CustomID string `json:"custom_id"`
|
||||
|
||||
// Result data (varies by request type)
|
||||
Response *BatchResultResponse `json:"response,omitempty"` // OpenAI format
|
||||
Result *BatchResultData `json:"result,omitempty"` // Anthropic format
|
||||
|
||||
// Error if the individual request failed
|
||||
Error *BatchResultError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// BatchResultResponse represents OpenAI-style result response.
|
||||
type BatchResultResponse struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
Body map[string]interface{} `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
// BatchResultData represents Anthropic-style result data.
|
||||
type BatchResultData struct {
|
||||
Type string `json:"type"` // "succeeded", "errored", "expired", "canceled"
|
||||
Message map[string]interface{} `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// BatchResultError represents an error for a single batch request.
|
||||
type BatchResultError struct {
|
||||
Code string `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// BifrostBatchResultsResponse represents the response from retrieving batch results.
|
||||
type BifrostBatchResultsResponse struct {
|
||||
BatchID string `json:"batch_id"`
|
||||
Results []BatchResultItem `json:"results"`
|
||||
|
||||
// For streaming results (Anthropic)
|
||||
HasMore bool `json:"has_more,omitempty"`
|
||||
NextCursor *string `json:"next_cursor,omitempty"`
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
1287
core/schemas/bifrost.go
Normal file
1287
core/schemas/bifrost.go
Normal file
File diff suppressed because it is too large
Load Diff
1612
core/schemas/chatcompletions.go
Normal file
1612
core/schemas/chatcompletions.go
Normal file
File diff suppressed because it is too large
Load Diff
255
core/schemas/containers.go
Normal file
255
core/schemas/containers.go
Normal file
@@ -0,0 +1,255 @@
|
||||
// Package schemas defines the core schemas and types used by the Bifrost system.
|
||||
package schemas
|
||||
|
||||
// ContainerStatus represents the status of a container.
|
||||
type ContainerStatus string
|
||||
|
||||
const (
|
||||
ContainerStatusRunning ContainerStatus = "running"
|
||||
)
|
||||
|
||||
// ContainerExpiresAfter represents the expiration configuration for a container.
|
||||
type ContainerExpiresAfter struct {
|
||||
Anchor string `json:"anchor"` // The anchor point for expiration (e.g., "last_active_at")
|
||||
Minutes int `json:"minutes"` // Number of minutes after anchor point
|
||||
}
|
||||
|
||||
// ContainerObject represents a container object returned by the API.
|
||||
type ContainerObject struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object,omitempty"` // "container"
|
||||
Name string `json:"name"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Status ContainerStatus `json:"status,omitempty"`
|
||||
ExpiresAfter *ContainerExpiresAfter `json:"expires_after,omitempty"`
|
||||
LastActiveAt *int64 `json:"last_active_at,omitempty"`
|
||||
MemoryLimit string `json:"memory_limit,omitempty"` // e.g., "1g", "4g"
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// BifrostContainerCreateRequest represents a request to create a container.
|
||||
type BifrostContainerCreateRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
|
||||
// Required fields
|
||||
Name string `json:"name"` // Name of the container
|
||||
|
||||
// Optional fields
|
||||
ExpiresAfter *ContainerExpiresAfter `json:"expires_after,omitempty"` // Expiration configuration
|
||||
FileIDs []string `json:"file_ids,omitempty"` // IDs of existing files to copy into this container
|
||||
MemoryLimit string `json:"memory_limit,omitempty"` // Memory limit (e.g., "1g", "4g")
|
||||
Metadata map[string]string `json:"metadata,omitempty"` // User-provided metadata
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// BifrostContainerCreateResponse represents the response from creating a container.
|
||||
type BifrostContainerCreateResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object,omitempty"` // "container"
|
||||
Name string `json:"name"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Status ContainerStatus `json:"status,omitempty"`
|
||||
ExpiresAfter *ContainerExpiresAfter `json:"expires_after,omitempty"`
|
||||
LastActiveAt *int64 `json:"last_active_at,omitempty"`
|
||||
MemoryLimit string `json:"memory_limit,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// BifrostContainerListRequest represents a request to list containers.
|
||||
type BifrostContainerListRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
|
||||
// Pagination
|
||||
Limit int `json:"limit,omitempty"` // Max results to return (1-100, default 20)
|
||||
After *string `json:"after,omitempty"` // Cursor for pagination
|
||||
Order *string `json:"order,omitempty"` // Sort order (asc/desc), default desc
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// BifrostContainerListResponse represents the response from listing containers.
|
||||
type BifrostContainerListResponse struct {
|
||||
Object string `json:"object,omitempty"` // "list"
|
||||
Data []ContainerObject `json:"data"`
|
||||
FirstID *string `json:"first_id,omitempty"`
|
||||
LastID *string `json:"last_id,omitempty"`
|
||||
HasMore bool `json:"has_more,omitempty"`
|
||||
After *string `json:"after,omitempty"` // Encoded cursor for next page (includes key index for multi-key pagination)
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// BifrostContainerRetrieveRequest represents a request to retrieve a container.
|
||||
type BifrostContainerRetrieveRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
ContainerID string `json:"container_id"` // ID of the container to retrieve
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// BifrostContainerRetrieveResponse represents the response from retrieving a container.
|
||||
type BifrostContainerRetrieveResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object,omitempty"` // "container"
|
||||
Name string `json:"name"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Status ContainerStatus `json:"status,omitempty"`
|
||||
ExpiresAfter *ContainerExpiresAfter `json:"expires_after,omitempty"`
|
||||
LastActiveAt *int64 `json:"last_active_at,omitempty"`
|
||||
MemoryLimit string `json:"memory_limit,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// BifrostContainerDeleteRequest represents a request to delete a container.
|
||||
type BifrostContainerDeleteRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
ContainerID string `json:"container_id"` // ID of the container to delete
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// BifrostContainerDeleteResponse represents the response from deleting a container.
|
||||
type BifrostContainerDeleteResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object,omitempty"` // "container.deleted"
|
||||
Deleted bool `json:"deleted"`
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// CONTAINER FILES API
|
||||
// =============================================================================
|
||||
|
||||
// ContainerFileObject represents a file within a container.
|
||||
type ContainerFileObject struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object,omitempty"` // "container.file"
|
||||
Bytes int64 `json:"bytes"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
ContainerID string `json:"container_id"`
|
||||
Path string `json:"path"`
|
||||
Source string `json:"source"` // "user" typically
|
||||
}
|
||||
|
||||
// BifrostContainerFileCreateRequest represents a request to create a file in a container.
|
||||
type BifrostContainerFileCreateRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
ContainerID string `json:"container_id"` // ID of the container
|
||||
|
||||
// One of these must be provided
|
||||
File []byte `json:"-"` // File content (for multipart upload)
|
||||
FileID *string `json:"file_id,omitempty"` // Reference to existing file
|
||||
Path *string `json:"file_path,omitempty"` // Path for the file in the container
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// BifrostContainerFileCreateResponse represents the response from creating a container file.
|
||||
type BifrostContainerFileCreateResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object,omitempty"` // "container.file"
|
||||
Bytes int64 `json:"bytes"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
ContainerID string `json:"container_id"`
|
||||
Path string `json:"path"`
|
||||
Source string `json:"source"`
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// BifrostContainerFileListRequest represents a request to list files in a container.
|
||||
type BifrostContainerFileListRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
ContainerID string `json:"container_id"` // ID of the container
|
||||
|
||||
// Pagination
|
||||
Limit int `json:"limit,omitempty"` // Max results to return (1-100, default 20)
|
||||
After *string `json:"after,omitempty"` // Cursor for pagination
|
||||
Order *string `json:"order,omitempty"` // Sort order (asc/desc), default desc
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// BifrostContainerFileListResponse represents the response from listing container files.
|
||||
type BifrostContainerFileListResponse struct {
|
||||
Object string `json:"object,omitempty"` // "list"
|
||||
Data []ContainerFileObject `json:"data"`
|
||||
FirstID *string `json:"first_id,omitempty"`
|
||||
LastID *string `json:"last_id,omitempty"`
|
||||
HasMore bool `json:"has_more,omitempty"`
|
||||
After *string `json:"after,omitempty"` // Encoded cursor for next page (includes key index for multi-key pagination)
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// BifrostContainerFileRetrieveRequest represents a request to retrieve a container file.
|
||||
type BifrostContainerFileRetrieveRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
ContainerID string `json:"container_id"` // ID of the container
|
||||
FileID string `json:"file_id"` // ID of the file to retrieve
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// BifrostContainerFileRetrieveResponse represents the response from retrieving a container file.
|
||||
type BifrostContainerFileRetrieveResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object,omitempty"` // "container.file"
|
||||
Bytes int64 `json:"bytes"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
ContainerID string `json:"container_id"`
|
||||
Path string `json:"path"`
|
||||
Source string `json:"source"`
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// BifrostContainerFileContentRequest represents a request to retrieve the content of a container file.
|
||||
type BifrostContainerFileContentRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
ContainerID string `json:"container_id"` // ID of the container
|
||||
FileID string `json:"file_id"` // ID of the file
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// BifrostContainerFileContentResponse represents the response from retrieving container file content.
|
||||
type BifrostContainerFileContentResponse struct {
|
||||
Content []byte `json:"content"` // Raw file content
|
||||
ContentType string `json:"content_type"` // MIME type of the content
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// BifrostContainerFileDeleteRequest represents a request to delete a container file.
|
||||
type BifrostContainerFileDeleteRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
ContainerID string `json:"container_id"` // ID of the container
|
||||
FileID string `json:"file_id"` // ID of the file to delete
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// BifrostContainerFileDeleteResponse represents the response from deleting a container file.
|
||||
type BifrostContainerFileDeleteResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object,omitempty"` // "container.file.deleted"
|
||||
Deleted bool `json:"deleted"`
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
501
core/schemas/context.go
Normal file
501
core/schemas/context.go
Normal file
@@ -0,0 +1,501 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var NoDeadline time.Time
|
||||
|
||||
var reservedKeys = []any{
|
||||
BifrostContextKeyVirtualKey,
|
||||
BifrostContextKeyAPIKeyName,
|
||||
BifrostContextKeyAPIKeyID,
|
||||
BifrostContextKeyRequestID,
|
||||
BifrostContextKeyFallbackRequestID,
|
||||
BifrostContextKeyDirectKey,
|
||||
BifrostContextKeySelectedKeyID,
|
||||
BifrostContextKeySelectedKeyName,
|
||||
BifrostContextKeyNumberOfRetries,
|
||||
BifrostContextKeyFallbackIndex,
|
||||
BifrostContextKeySkipKeySelection,
|
||||
BifrostContextKeyURLPath,
|
||||
BifrostContextKeyDeferTraceCompletion,
|
||||
BifrostContextKeyAttemptTrail,
|
||||
}
|
||||
|
||||
// pluginLogStore holds plugin log entries accumulated during request processing.
|
||||
// It is shared between the root BifrostContext and all scoped contexts derived from it.
|
||||
// Uses a flat slice (not map) to minimize heap allocations.
|
||||
type pluginLogStore struct {
|
||||
mu sync.Mutex
|
||||
logs []PluginLogEntry
|
||||
}
|
||||
|
||||
// pluginLogStorePool pools pluginLogStore instances to reduce per-request allocations.
|
||||
var pluginLogStorePool = sync.Pool{
|
||||
New: func() any {
|
||||
return &pluginLogStore{logs: make([]PluginLogEntry, 0, 8)}
|
||||
},
|
||||
}
|
||||
|
||||
// pluginScopePool pools BifrostContext instances used as scoped plugin contexts.
|
||||
var pluginScopePool = sync.Pool{
|
||||
New: func() any {
|
||||
return &BifrostContext{}
|
||||
},
|
||||
}
|
||||
|
||||
// BifrostContext is a custom context.Context implementation that tracks user-set values.
|
||||
// It supports deadlines, can be derived from other contexts, and provides layered
|
||||
// value inheritance when derived from another BifrostContext.
|
||||
type BifrostContext struct {
|
||||
parent context.Context
|
||||
deadline time.Time
|
||||
hasDeadline bool
|
||||
done chan struct{}
|
||||
doneOnce sync.Once
|
||||
err error
|
||||
errMu sync.RWMutex
|
||||
userValues map[any]any
|
||||
valuesMu sync.RWMutex
|
||||
blockRestrictedWrites atomic.Bool
|
||||
|
||||
// Plugin scoping fields
|
||||
pluginScope *string // Non-nil when this is a scoped plugin context
|
||||
pluginLogs atomic.Pointer[pluginLogStore] // Shared log store; lazily initialized on root, shared by scoped contexts
|
||||
valueDelegate *BifrostContext // For scoped contexts: delegate Value/SetValue to this root context
|
||||
}
|
||||
|
||||
// NewBifrostContext creates a new BifrostContext with the given parent context and deadline.
|
||||
// If the deadline is zero, no deadline is set on this context (though the parent may have one).
|
||||
// The context will be cancelled when the deadline expires or when the parent context is cancelled.
|
||||
func NewBifrostContext(parent context.Context, deadline time.Time) *BifrostContext {
|
||||
if parent == nil {
|
||||
parent = context.Background()
|
||||
}
|
||||
ctx := &BifrostContext{
|
||||
parent: parent,
|
||||
deadline: deadline,
|
||||
hasDeadline: !deadline.IsZero(),
|
||||
done: make(chan struct{}),
|
||||
userValues: make(map[any]any),
|
||||
blockRestrictedWrites: atomic.Bool{},
|
||||
}
|
||||
ctx.blockRestrictedWrites.Store(false)
|
||||
// Only start goroutine if there's something to watch:
|
||||
// - If we have a deadline, we need the timer
|
||||
// - If parent can be cancelled (Done() != nil) AND is not a non-cancelling context
|
||||
// - If parent has a deadline, we need a timer (parent may not properly cancel via Done())
|
||||
_, parentHasDeadline := parent.Deadline()
|
||||
parentCanCancel := parent.Done() != nil && !isNonCancellingContext(parent)
|
||||
if ctx.hasDeadline || parentCanCancel || parentHasDeadline {
|
||||
go ctx.watchCancellation()
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
// NewBifrostContextWithValue creates a new BifrostContext with the given value set.
|
||||
func NewBifrostContextWithValue(parent context.Context, deadline time.Time, key any, value any) *BifrostContext {
|
||||
ctx := NewBifrostContext(parent, deadline)
|
||||
ctx.SetValue(key, value)
|
||||
return ctx
|
||||
}
|
||||
|
||||
// NewBifrostContextWithTimeout creates a new BifrostContext with a timeout duration.
|
||||
// This is a convenience wrapper around NewBifrostContext.
|
||||
// Returns the context and a cancel function that should be called to release resources.
|
||||
func NewBifrostContextWithTimeout(parent context.Context, timeout time.Duration) (*BifrostContext, context.CancelFunc) {
|
||||
ctx := NewBifrostContext(parent, time.Now().Add(timeout))
|
||||
return ctx, func() { ctx.Cancel() }
|
||||
}
|
||||
|
||||
// NewBifrostContextWithCancel creates a new BifrostContext with a cancel function.
|
||||
// This is a convenience wrapper around NewBifrostContext.
|
||||
// Returns the context and a cancel function that should be called to release resources.
|
||||
func NewBifrostContextWithCancel(parent context.Context) (*BifrostContext, context.CancelFunc) {
|
||||
ctx := NewBifrostContext(parent, NoDeadline)
|
||||
return ctx, func() { ctx.Cancel() }
|
||||
}
|
||||
|
||||
// WithValue returns a new context with the given value set.
|
||||
func (bc *BifrostContext) WithValue(key any, value any) *BifrostContext {
|
||||
bc.SetValue(key, value)
|
||||
return bc
|
||||
}
|
||||
|
||||
// BlockRestrictedWrites returns true if restricted writes are blocked.
|
||||
func (bc *BifrostContext) BlockRestrictedWrites() {
|
||||
bc.blockRestrictedWrites.Store(true)
|
||||
}
|
||||
|
||||
// UnblockRestrictedWrites unblocks restricted writes.
|
||||
func (bc *BifrostContext) UnblockRestrictedWrites() {
|
||||
bc.blockRestrictedWrites.Store(false)
|
||||
}
|
||||
|
||||
// Cancel cancels the context, closing the Done channel and setting the error to context.Canceled.
|
||||
func (bc *BifrostContext) Cancel() {
|
||||
bc.cancel(context.Canceled)
|
||||
}
|
||||
|
||||
// watchCancellation monitors for deadline expiration and parent cancellation.
|
||||
func (bc *BifrostContext) watchCancellation() {
|
||||
var timer <-chan time.Time
|
||||
|
||||
// Use effective deadline (considers both own and parent deadlines)
|
||||
// This handles cases where parent has a deadline but doesn't properly
|
||||
// cancel via Done() (e.g., fasthttp.RequestCtx)
|
||||
if effectiveDeadline, hasDeadline := bc.Deadline(); hasDeadline {
|
||||
duration := time.Until(effectiveDeadline)
|
||||
if duration <= 0 {
|
||||
// Deadline already passed
|
||||
bc.cancel(context.DeadlineExceeded)
|
||||
return
|
||||
}
|
||||
t := time.NewTimer(duration)
|
||||
defer t.Stop()
|
||||
timer = t.C
|
||||
}
|
||||
|
||||
// Don't watch parent.Done() for contexts known to never close it
|
||||
// (e.g., fasthttp.RequestCtx pools contexts and never cancels them)
|
||||
if isNonCancellingContext(bc.parent) {
|
||||
select {
|
||||
case <-timer:
|
||||
bc.cancel(context.DeadlineExceeded)
|
||||
case <-bc.done:
|
||||
// Already cancelled
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-bc.parent.Done():
|
||||
bc.cancel(bc.parent.Err())
|
||||
case <-timer:
|
||||
bc.cancel(context.DeadlineExceeded)
|
||||
case <-bc.done:
|
||||
// Already cancelled
|
||||
}
|
||||
}
|
||||
|
||||
// cancel closes the done channel and sets the error.
|
||||
func (bc *BifrostContext) cancel(err error) {
|
||||
bc.doneOnce.Do(func() {
|
||||
bc.errMu.Lock()
|
||||
bc.err = err
|
||||
bc.errMu.Unlock()
|
||||
close(bc.done)
|
||||
})
|
||||
}
|
||||
|
||||
// Deadline returns the deadline for this context.
|
||||
// For scoped contexts, delegates to the root context.
|
||||
// If both this context and the parent have deadlines, the earlier one is returned.
|
||||
func (bc *BifrostContext) Deadline() (time.Time, bool) {
|
||||
if bc.valueDelegate != nil {
|
||||
return bc.valueDelegate.Deadline()
|
||||
}
|
||||
parentDeadline, parentHasDeadline := bc.parent.Deadline()
|
||||
|
||||
if !bc.hasDeadline && !parentHasDeadline {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
if !bc.hasDeadline {
|
||||
return parentDeadline, true
|
||||
}
|
||||
|
||||
if !parentHasDeadline {
|
||||
return bc.deadline, true
|
||||
}
|
||||
|
||||
// Both have deadlines, return the earlier one
|
||||
if bc.deadline.Before(parentDeadline) {
|
||||
return bc.deadline, true
|
||||
}
|
||||
return parentDeadline, true
|
||||
}
|
||||
|
||||
// Done returns a channel that is closed when the context is cancelled.
|
||||
func (bc *BifrostContext) Done() <-chan struct{} {
|
||||
return bc.done
|
||||
}
|
||||
|
||||
// Err returns the error explaining why the context was cancelled.
|
||||
// For scoped contexts, delegates to the root context.
|
||||
// Returns nil if the context has not been cancelled.
|
||||
func (bc *BifrostContext) Err() error {
|
||||
if bc.valueDelegate != nil {
|
||||
return bc.valueDelegate.Err()
|
||||
}
|
||||
bc.errMu.RLock()
|
||||
defer bc.errMu.RUnlock()
|
||||
return bc.err
|
||||
}
|
||||
|
||||
// Value returns the value associated with the key.
|
||||
// For scoped contexts, delegates to the root context via valueDelegate.
|
||||
// Otherwise checks the internal userValues map, then delegates to the parent context.
|
||||
func (bc *BifrostContext) Value(key any) any {
|
||||
if bc.valueDelegate != nil {
|
||||
return bc.valueDelegate.Value(key)
|
||||
}
|
||||
bc.valuesMu.RLock()
|
||||
if val, ok := bc.userValues[key]; ok {
|
||||
bc.valuesMu.RUnlock()
|
||||
return val
|
||||
}
|
||||
bc.valuesMu.RUnlock()
|
||||
|
||||
if bc.parent == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return bc.parent.Value(key)
|
||||
}
|
||||
|
||||
// SetValue sets a value in the internal userValues map.
|
||||
// For scoped contexts, delegates to the root context via valueDelegate.
|
||||
// This is thread-safe and can be called concurrently.
|
||||
func (bc *BifrostContext) SetValue(key, value any) {
|
||||
if bc.valueDelegate != nil {
|
||||
bc.valueDelegate.SetValue(key, value)
|
||||
return
|
||||
}
|
||||
// Check if the key is a reserved key
|
||||
if bc.blockRestrictedWrites.Load() && slices.Contains(reservedKeys, key) {
|
||||
// we silently drop writes for these reserved keys
|
||||
return
|
||||
}
|
||||
bc.valuesMu.Lock()
|
||||
defer bc.valuesMu.Unlock()
|
||||
if bc.userValues == nil {
|
||||
bc.userValues = make(map[any]any)
|
||||
}
|
||||
bc.userValues[key] = value
|
||||
}
|
||||
|
||||
// ClearValue clears a value from the internal userValues map.
|
||||
// For scoped contexts, delegates to the root context via valueDelegate.
|
||||
func (bc *BifrostContext) ClearValue(key any) {
|
||||
if bc.valueDelegate != nil {
|
||||
bc.valueDelegate.ClearValue(key)
|
||||
return
|
||||
}
|
||||
// Check if the key is a reserved key
|
||||
if bc.blockRestrictedWrites.Load() && slices.Contains(reservedKeys, key) {
|
||||
// we silently drop writes for these reserved keys
|
||||
return
|
||||
}
|
||||
bc.valuesMu.Lock()
|
||||
defer bc.valuesMu.Unlock()
|
||||
if bc.userValues != nil {
|
||||
bc.userValues[key] = nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetAndSetValue gets a value from the internal userValues map and sets it.
|
||||
// For scoped contexts, delegates to the root context via valueDelegate.
|
||||
func (bc *BifrostContext) GetAndSetValue(key any, value any) any {
|
||||
if bc.valueDelegate != nil {
|
||||
return bc.valueDelegate.GetAndSetValue(key, value)
|
||||
}
|
||||
bc.valuesMu.Lock()
|
||||
defer bc.valuesMu.Unlock()
|
||||
// Check if the key is a reserved key
|
||||
if bc.blockRestrictedWrites.Load() && slices.Contains(reservedKeys, key) {
|
||||
// we silently drop writes for these reserved keys
|
||||
return bc.userValues[key]
|
||||
}
|
||||
if bc.userValues == nil {
|
||||
bc.userValues = make(map[any]any)
|
||||
}
|
||||
oldValue := bc.userValues[key]
|
||||
bc.userValues[key] = value
|
||||
return oldValue
|
||||
}
|
||||
|
||||
// GetUserValues returns a copy of all user-set values in this context.
|
||||
// If the parent is also a PluginContext, the values are merged with parent values
|
||||
// (this context's values take precedence over parent values).
|
||||
func (bc *BifrostContext) GetUserValues() map[any]any {
|
||||
result := make(map[any]any)
|
||||
|
||||
// First, get parent's user values if parent is a PluginContext
|
||||
if parentCtx, ok := bc.parent.(*BifrostContext); ok {
|
||||
for k, v := range parentCtx.GetUserValues() {
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Then overlay with our own values (our values take precedence)
|
||||
bc.valuesMu.RLock()
|
||||
for k, v := range bc.userValues {
|
||||
result[k] = v
|
||||
}
|
||||
bc.valuesMu.RUnlock()
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GetParentCtxWithUserValues returns a copy of the parent context with all user-set values merged in.
|
||||
func (bc *BifrostContext) GetParentCtxWithUserValues() context.Context {
|
||||
parentCtx := bc.parent
|
||||
bc.valuesMu.RLock()
|
||||
for k, v := range bc.userValues {
|
||||
parentCtx = context.WithValue(parentCtx, k, v)
|
||||
}
|
||||
bc.valuesMu.RUnlock()
|
||||
return parentCtx
|
||||
}
|
||||
|
||||
// AppendRoutingEngineLog appends a routing engine log entry to the context.
|
||||
// Parameters:
|
||||
// - ctx: The Bifrost context
|
||||
// - engineName: Name of the routing engine (e.g., "governance", "routing-rule")
|
||||
// - message: Human-readable log message describing the decision/action
|
||||
func (bc *BifrostContext) AppendRoutingEngineLog(engineName string, message string) {
|
||||
entry := RoutingEngineLogEntry{
|
||||
Engine: engineName,
|
||||
Message: message,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
AppendToContextList(bc, BifrostContextKeyRoutingEngineLogs, entry)
|
||||
}
|
||||
|
||||
// GetRoutingEngineLogs retrieves all routing engine logs from the context.
|
||||
// Parameters:
|
||||
// - ctx: The Bifrost context
|
||||
//
|
||||
// Returns:
|
||||
// - []RoutingEngineLogEntry: Slice of routing engine log entries (nil if none)
|
||||
func (bc *BifrostContext) GetRoutingEngineLogs() []RoutingEngineLogEntry {
|
||||
if val := bc.Value(BifrostContextKeyRoutingEngineLogs); val != nil {
|
||||
if logs, ok := val.([]RoutingEngineLogEntry); ok {
|
||||
return logs
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AppendToContextList appends a value to the context list value.
|
||||
// Parameters:
|
||||
// - ctx: The Bifrost context
|
||||
// - key: The key to append the value to
|
||||
// - value: The value to append
|
||||
func AppendToContextList[T any](ctx *BifrostContext, key BifrostContextKey, value T) {
|
||||
if ctx == nil {
|
||||
return
|
||||
}
|
||||
existingValues, ok := ctx.Value(key).([]T)
|
||||
if !ok {
|
||||
existingValues = []T{}
|
||||
}
|
||||
ctx.SetValue(key, append(existingValues, value))
|
||||
}
|
||||
|
||||
// WithPluginScope returns a lightweight scoped BifrostContext from the pool.
|
||||
// The scoped context shares the root's pluginLogs store and delegates all
|
||||
// Value/SetValue operations to the root context.
|
||||
// Call ReleasePluginScope() when done to return the scoped context to the pool.
|
||||
func (bc *BifrostContext) WithPluginScope(name *string) *BifrostContext {
|
||||
// Lazily initialize the plugin log store on the root context (CAS to avoid race)
|
||||
if bc.pluginLogs.Load() == nil {
|
||||
newStore := pluginLogStorePool.Get().(*pluginLogStore)
|
||||
if !bc.pluginLogs.CompareAndSwap(nil, newStore) {
|
||||
// Another goroutine initialized first — return unused store to pool
|
||||
pluginLogStorePool.Put(newStore)
|
||||
}
|
||||
}
|
||||
|
||||
scoped := pluginScopePool.Get().(*BifrostContext)
|
||||
scoped.parent = bc.parent
|
||||
scoped.done = bc.done
|
||||
scoped.pluginScope = name
|
||||
scoped.pluginLogs.Store(bc.pluginLogs.Load())
|
||||
scoped.valueDelegate = bc
|
||||
return scoped
|
||||
}
|
||||
|
||||
// ReleasePluginScope returns a scoped context to the pool.
|
||||
// Safe no-op if called on a non-scoped context.
|
||||
// Do not use the scoped context after calling this method.
|
||||
func (bc *BifrostContext) ReleasePluginScope() {
|
||||
if bc.valueDelegate == nil {
|
||||
return // not a scoped context
|
||||
}
|
||||
bc.parent = nil
|
||||
bc.done = nil
|
||||
bc.pluginScope = nil
|
||||
bc.pluginLogs.Store(nil)
|
||||
bc.valueDelegate = nil
|
||||
pluginScopePool.Put(bc)
|
||||
}
|
||||
|
||||
// Log appends a structured log entry for the current plugin scope.
|
||||
// No-op if the context is not scoped to a plugin or has no log store.
|
||||
func (bc *BifrostContext) Log(level LogLevel, msg string) {
|
||||
store := bc.pluginLogs.Load()
|
||||
if bc.pluginScope == nil || store == nil {
|
||||
return
|
||||
}
|
||||
store.mu.Lock()
|
||||
store.logs = append(store.logs, PluginLogEntry{
|
||||
PluginName: *bc.pluginScope,
|
||||
Level: level,
|
||||
Message: msg,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
store.mu.Unlock()
|
||||
}
|
||||
|
||||
// GetPluginLogs returns a deep copy of all accumulated plugin log entries.
|
||||
// Thread-safe. Returns nil if no logs have been recorded.
|
||||
func (bc *BifrostContext) GetPluginLogs() []PluginLogEntry {
|
||||
store := bc.pluginLogs.Load()
|
||||
if store == nil {
|
||||
return nil
|
||||
}
|
||||
store.mu.Lock()
|
||||
defer store.mu.Unlock()
|
||||
if len(store.logs) == 0 {
|
||||
return nil
|
||||
}
|
||||
copied := make([]PluginLogEntry, len(store.logs))
|
||||
copy(copied, store.logs)
|
||||
return copied
|
||||
}
|
||||
|
||||
// DrainPluginLogs transfers ownership of the plugin log slice to the caller.
|
||||
// The internal log store is returned to the pool after draining.
|
||||
// Returns nil if no logs have been recorded.
|
||||
// This should be called once on the root context after all plugin hooks have completed.
|
||||
func (bc *BifrostContext) DrainPluginLogs() []PluginLogEntry {
|
||||
if bc.valueDelegate != nil {
|
||||
return nil // scoped contexts must not drain the shared log store
|
||||
}
|
||||
store := bc.pluginLogs.Load()
|
||||
if store == nil {
|
||||
return nil
|
||||
}
|
||||
bc.pluginLogs.Store(nil)
|
||||
|
||||
store.mu.Lock()
|
||||
logs := store.logs
|
||||
// Reset with fresh pre-allocated slice before returning to pool
|
||||
store.logs = make([]PluginLogEntry, 0, 8)
|
||||
store.mu.Unlock()
|
||||
|
||||
// Return the store to the pool for reuse
|
||||
pluginLogStorePool.Put(store)
|
||||
|
||||
if len(logs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return logs
|
||||
}
|
||||
12
core/schemas/context_native.go
Normal file
12
core/schemas/context_native.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package schemas
|
||||
|
||||
import "github.com/valyala/fasthttp"
|
||||
|
||||
// isNonCancellingContext returns true if the context is known to have
|
||||
// a Done() channel that never closes (e.g., fasthttp.RequestCtx).
|
||||
func isNonCancellingContext(parent any) bool {
|
||||
_, ok := parent.(*fasthttp.RequestCtx)
|
||||
return ok
|
||||
}
|
||||
331
core/schemas/context_test.go
Normal file
331
core/schemas/context_test.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewBifrostContext_NoGoroutineLeakWithBackgroundAndNoDeadline(t *testing.T) {
|
||||
// Get baseline goroutine count
|
||||
runtime.GC()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
baseline := runtime.NumGoroutine()
|
||||
|
||||
// Create multiple contexts with context.Background() and no deadline
|
||||
// Previously this would leak goroutines
|
||||
contexts := make([]*BifrostContext, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
contexts[i] = NewBifrostContext(context.Background(), NoDeadline)
|
||||
}
|
||||
|
||||
// Give time for any goroutines to start
|
||||
runtime.Gosched()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Check goroutine count - should not have increased significantly
|
||||
// (allow some slack for runtime/test goroutines)
|
||||
afterCreate := runtime.NumGoroutine()
|
||||
|
||||
// With the fix, no goroutines should be spawned since there's nothing to watch
|
||||
// Allow a small margin for test framework goroutines
|
||||
if afterCreate > baseline+10 {
|
||||
t.Errorf("Goroutine leak detected: baseline=%d, after creating 100 contexts=%d", baseline, afterCreate)
|
||||
}
|
||||
|
||||
// Verify the contexts still work correctly
|
||||
for i, ctx := range contexts {
|
||||
// Should not be cancelled
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Errorf("Context %d should not be done", i)
|
||||
default:
|
||||
// Expected
|
||||
}
|
||||
|
||||
// Should return nil error
|
||||
if ctx.Err() != nil {
|
||||
t.Errorf("Context %d Err() should be nil, got %v", i, ctx.Err())
|
||||
}
|
||||
|
||||
// Should have no deadline
|
||||
if _, ok := ctx.Deadline(); ok {
|
||||
t.Errorf("Context %d should not have deadline", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Explicitly cancel all contexts
|
||||
for _, ctx := range contexts {
|
||||
ctx.Cancel()
|
||||
}
|
||||
|
||||
// Verify all are cancelled
|
||||
for i, ctx := range contexts {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Expected
|
||||
default:
|
||||
t.Errorf("Context %d should be done after Cancel()", i)
|
||||
}
|
||||
|
||||
if ctx.Err() != context.Canceled {
|
||||
t.Errorf("Context %d Err() should be context.Canceled, got %v", i, ctx.Err())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBifrostContext_GoroutineStartsWithDeadline(t *testing.T) {
|
||||
// Get baseline goroutine count
|
||||
runtime.GC()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
baseline := runtime.NumGoroutine()
|
||||
|
||||
// Create context with a deadline - should spawn goroutine
|
||||
deadline := time.Now().Add(1 * time.Hour)
|
||||
ctx := NewBifrostContext(context.Background(), deadline)
|
||||
|
||||
// Give time for goroutine to start
|
||||
runtime.Gosched()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
afterCreate := runtime.NumGoroutine()
|
||||
|
||||
// Should have at least one more goroutine for the deadline watcher
|
||||
if afterCreate <= baseline {
|
||||
t.Errorf("Expected goroutine to be spawned for deadline context: baseline=%d, after=%d", baseline, afterCreate)
|
||||
}
|
||||
|
||||
// Clean up
|
||||
ctx.Cancel()
|
||||
}
|
||||
|
||||
func TestNewBifrostContext_GoroutineStartsWithCancellableParent(t *testing.T) {
|
||||
// Get baseline goroutine count
|
||||
runtime.GC()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
baseline := runtime.NumGoroutine()
|
||||
|
||||
// Create a cancellable parent
|
||||
parent, parentCancel := context.WithCancel(context.Background())
|
||||
defer parentCancel()
|
||||
|
||||
// Create BifrostContext with cancellable parent but no deadline
|
||||
// Should spawn goroutine to watch parent
|
||||
ctx := NewBifrostContext(parent, NoDeadline)
|
||||
|
||||
// Give time for goroutine to start
|
||||
runtime.Gosched()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
afterCreate := runtime.NumGoroutine()
|
||||
|
||||
// Should have goroutine for watching parent cancellation
|
||||
if afterCreate <= baseline {
|
||||
t.Errorf("Expected goroutine to be spawned for cancellable parent: baseline=%d, after=%d", baseline, afterCreate)
|
||||
}
|
||||
|
||||
// Verify parent cancellation propagates
|
||||
parentCancel()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Expected
|
||||
default:
|
||||
t.Error("Context should be cancelled when parent is cancelled")
|
||||
}
|
||||
|
||||
if ctx.Err() != context.Canceled {
|
||||
t.Errorf("Context Err() should be context.Canceled, got %v", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBifrostContext_DeadlineExpires(t *testing.T) {
|
||||
// Create context with short deadline
|
||||
deadline := time.Now().Add(50 * time.Millisecond)
|
||||
ctx := NewBifrostContext(context.Background(), deadline)
|
||||
|
||||
// Should not be done yet
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Error("Context should not be done before deadline")
|
||||
default:
|
||||
// Expected
|
||||
}
|
||||
|
||||
// Wait for deadline
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Should be done now
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Expected
|
||||
default:
|
||||
t.Error("Context should be done after deadline")
|
||||
}
|
||||
|
||||
if ctx.Err() != context.DeadlineExceeded {
|
||||
t.Errorf("Context Err() should be context.DeadlineExceeded, got %v", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBifrostContext_SetAndGetValue(t *testing.T) {
|
||||
ctx := NewBifrostContext(context.Background(), NoDeadline)
|
||||
|
||||
// Set a value
|
||||
ctx.SetValue("key1", "value1")
|
||||
|
||||
// Get the value
|
||||
if v := ctx.Value("key1"); v != "value1" {
|
||||
t.Errorf("Expected value1, got %v", v)
|
||||
}
|
||||
|
||||
// Get non-existent key
|
||||
if v := ctx.Value("nonexistent"); v != nil {
|
||||
t.Errorf("Expected nil for non-existent key, got %v", v)
|
||||
}
|
||||
|
||||
// Clean up
|
||||
ctx.Cancel()
|
||||
}
|
||||
|
||||
func TestNewBifrostContext_NilParent(t *testing.T) {
|
||||
// Should not panic with nil parent
|
||||
// Note: passing nil is allowed by NewBifrostContext which converts it to context.Background()
|
||||
var nilCtx context.Context //nolint:staticcheck // testing nil parent handling
|
||||
ctx := NewBifrostContext(nilCtx, NoDeadline)
|
||||
|
||||
// Should work normally
|
||||
if ctx.Err() != nil {
|
||||
t.Errorf("New context should have nil error, got %v", ctx.Err())
|
||||
}
|
||||
|
||||
ctx.Cancel()
|
||||
|
||||
if ctx.Err() != context.Canceled {
|
||||
t.Errorf("Cancelled context should have Canceled error, got %v", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
// Plugin logging tests
|
||||
|
||||
func TestPluginLog_NoScopeIsNoop(t *testing.T) {
|
||||
ctx := NewBifrostContext(context.Background(), NoDeadline)
|
||||
ctx.Log(LogLevelInfo, "should be ignored")
|
||||
logs := ctx.GetPluginLogs()
|
||||
if logs != nil {
|
||||
t.Errorf("expected nil logs without plugin scope, got %v", logs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPluginLog_SinglePlugin(t *testing.T) {
|
||||
ctx := NewBifrostContext(context.Background(), NoDeadline)
|
||||
name := "test-plugin"
|
||||
scoped := ctx.WithPluginScope(&name)
|
||||
scoped.Log(LogLevelInfo, "hello")
|
||||
scoped.Log(LogLevelError, "oops")
|
||||
scoped.ReleasePluginScope()
|
||||
|
||||
logs := ctx.GetPluginLogs()
|
||||
if len(logs) != 2 {
|
||||
t.Fatalf("expected 2 logs, got %d", len(logs))
|
||||
}
|
||||
if logs[0].PluginName != "test-plugin" || logs[0].Level != LogLevelInfo || logs[0].Message != "hello" {
|
||||
t.Errorf("unexpected first log: %+v", logs[0])
|
||||
}
|
||||
if logs[1].Level != LogLevelError || logs[1].Message != "oops" {
|
||||
t.Errorf("unexpected second log: %+v", logs[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPluginLog_MultiplePlugins(t *testing.T) {
|
||||
ctx := NewBifrostContext(context.Background(), NoDeadline)
|
||||
|
||||
name1 := "plugin-a"
|
||||
s1 := ctx.WithPluginScope(&name1)
|
||||
s1.Log(LogLevelDebug, "a-msg")
|
||||
s1.ReleasePluginScope()
|
||||
|
||||
name2 := "plugin-b"
|
||||
s2 := ctx.WithPluginScope(&name2)
|
||||
s2.Log(LogLevelWarn, "b-msg")
|
||||
s2.ReleasePluginScope()
|
||||
|
||||
logs := ctx.GetPluginLogs()
|
||||
if len(logs) != 2 {
|
||||
t.Fatalf("expected 2 logs, got %d", len(logs))
|
||||
}
|
||||
if logs[0].PluginName != "plugin-a" {
|
||||
t.Errorf("expected plugin-a, got %s", logs[0].PluginName)
|
||||
}
|
||||
if logs[1].PluginName != "plugin-b" {
|
||||
t.Errorf("expected plugin-b, got %s", logs[1].PluginName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPluginLog_DrainTransfersOwnership(t *testing.T) {
|
||||
ctx := NewBifrostContext(context.Background(), NoDeadline)
|
||||
name := "drain-test"
|
||||
scoped := ctx.WithPluginScope(&name)
|
||||
scoped.Log(LogLevelInfo, "msg1")
|
||||
scoped.ReleasePluginScope()
|
||||
|
||||
drained := ctx.DrainPluginLogs()
|
||||
if len(drained) != 1 {
|
||||
t.Fatalf("expected 1 drained log, got %d", len(drained))
|
||||
}
|
||||
|
||||
// After drain, GetPluginLogs should return nil
|
||||
after := ctx.GetPluginLogs()
|
||||
if after != nil {
|
||||
t.Errorf("expected nil after drain, got %v", after)
|
||||
}
|
||||
|
||||
// Second drain should return nil
|
||||
second := ctx.DrainPluginLogs()
|
||||
if second != nil {
|
||||
t.Errorf("expected nil on second drain, got %v", second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPluginLog_ScopedContextValueDelegation(t *testing.T) {
|
||||
ctx := NewBifrostContext(context.Background(), NoDeadline)
|
||||
ctx.SetValue(BifrostContextKeyTraceID, "trace-123")
|
||||
|
||||
name := "delegate-test"
|
||||
scoped := ctx.WithPluginScope(&name)
|
||||
|
||||
// Scoped should read from root
|
||||
val := scoped.Value(BifrostContextKeyTraceID)
|
||||
if val != "trace-123" {
|
||||
t.Errorf("expected trace-123, got %v", val)
|
||||
}
|
||||
|
||||
// Scoped should write to root
|
||||
type testContextKey string
|
||||
const customKey testContextKey = "custom-key"
|
||||
scoped.SetValue(customKey, "custom-val")
|
||||
if ctx.Value(customKey) != "custom-val" {
|
||||
t.Errorf("SetValue on scoped did not delegate to root")
|
||||
}
|
||||
|
||||
scoped.ReleasePluginScope()
|
||||
}
|
||||
|
||||
func TestPluginLog_PoolReuse(t *testing.T) {
|
||||
ctx := NewBifrostContext(context.Background(), NoDeadline)
|
||||
|
||||
// Create and release multiple scoped contexts to exercise the pool
|
||||
for i := 0; i < 100; i++ {
|
||||
name := "pool-test"
|
||||
scoped := ctx.WithPluginScope(&name)
|
||||
scoped.Log(LogLevelInfo, "pooled")
|
||||
scoped.ReleasePluginScope()
|
||||
}
|
||||
|
||||
logs := ctx.DrainPluginLogs()
|
||||
if len(logs) != 100 {
|
||||
t.Errorf("expected 100 logs from pool reuse, got %d", len(logs))
|
||||
}
|
||||
}
|
||||
10
core/schemas/context_wasm.go
Normal file
10
core/schemas/context_wasm.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build tinygo || wasm
|
||||
|
||||
package schemas
|
||||
|
||||
// isNonCancellingContext returns true if the context is known to have
|
||||
// a Done() channel that never closes. In wasm builds, fasthttp is not
|
||||
// available, so this always returns false.
|
||||
func isNonCancellingContext(parent any) bool {
|
||||
return false
|
||||
}
|
||||
14
core/schemas/count_tokens.go
Normal file
14
core/schemas/count_tokens.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package schemas
|
||||
|
||||
// BifrostCountTokensResponse captures token counts for a provided input.
|
||||
type BifrostCountTokensResponse struct {
|
||||
Object string `json:"object,omitempty"`
|
||||
Model string `json:"model"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
InputTokensDetails *ResponsesResponseInputTokens `json:"input_tokens_details,omitempty"`
|
||||
Tokens []int `json:"tokens"`
|
||||
TokenStrings []string `json:"token_strings,omitempty"`
|
||||
OutputTokens *int `json:"output_tokens,omitempty"`
|
||||
TotalTokens *int `json:"total_tokens"`
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
187
core/schemas/embedding.go
Normal file
187
core/schemas/embedding.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type BifrostEmbeddingRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Input *EmbeddingInput `json:"input,omitempty"`
|
||||
Params *EmbeddingParameters `json:"params,omitempty"`
|
||||
Fallbacks []Fallback `json:"fallbacks,omitempty"`
|
||||
RawRequestBody []byte `json:"-"` // set bifrost-use-raw-request-body to true in ctx to use the raw request body. Bifrost will directly send this to the downstream provider.
|
||||
}
|
||||
|
||||
func (r *BifrostEmbeddingRequest) GetRawRequestBody() []byte {
|
||||
return r.RawRequestBody
|
||||
}
|
||||
|
||||
type BifrostEmbeddingResponse struct {
|
||||
Data []EmbeddingData `json:"data"` // Maps to "data" field in provider responses (e.g., OpenAI embedding format)
|
||||
Model string `json:"model"`
|
||||
Object string `json:"object"` // "list"
|
||||
Usage *BifrostLLMUsage `json:"usage"`
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// EmbeddingInput represents the input for an embedding request.
|
||||
type EmbeddingInput struct {
|
||||
Text *string
|
||||
Texts []string
|
||||
Embedding []int
|
||||
Embeddings [][]int
|
||||
}
|
||||
|
||||
func (e *EmbeddingInput) MarshalJSON() ([]byte, error) {
|
||||
// enforce one-of
|
||||
set := 0
|
||||
if e.Text != nil {
|
||||
set++
|
||||
}
|
||||
if e.Texts != nil {
|
||||
set++
|
||||
}
|
||||
if e.Embedding != nil {
|
||||
set++
|
||||
}
|
||||
if e.Embeddings != nil {
|
||||
set++
|
||||
}
|
||||
if set == 0 {
|
||||
return nil, fmt.Errorf("embedding input is empty")
|
||||
}
|
||||
if set > 1 {
|
||||
return nil, fmt.Errorf("embedding input must set exactly one of: text, texts, embedding, embeddings")
|
||||
}
|
||||
|
||||
if e.Text != nil {
|
||||
return MarshalSorted(*e.Text)
|
||||
}
|
||||
if e.Texts != nil {
|
||||
return MarshalSorted(e.Texts)
|
||||
}
|
||||
if e.Embedding != nil {
|
||||
return MarshalSorted(e.Embedding)
|
||||
}
|
||||
if e.Embeddings != nil {
|
||||
return MarshalSorted(e.Embeddings)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid embedding input")
|
||||
}
|
||||
|
||||
func (e *EmbeddingInput) UnmarshalJSON(data []byte) error {
|
||||
e.Text = nil
|
||||
e.Texts = nil
|
||||
e.Embedding = nil
|
||||
e.Embeddings = nil
|
||||
// Try string
|
||||
var s string
|
||||
if err := Unmarshal(data, &s); err == nil {
|
||||
e.Text = &s
|
||||
return nil
|
||||
}
|
||||
// Try []string
|
||||
var ss []string
|
||||
if err := Unmarshal(data, &ss); err == nil {
|
||||
e.Texts = ss
|
||||
return nil
|
||||
}
|
||||
// Try []int
|
||||
var i []int
|
||||
if err := Unmarshal(data, &i); err == nil {
|
||||
e.Embedding = i
|
||||
return nil
|
||||
}
|
||||
// Try [][]int
|
||||
var i2 [][]int
|
||||
if err := Unmarshal(data, &i2); err == nil {
|
||||
e.Embeddings = i2
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("unsupported embedding input shape")
|
||||
}
|
||||
|
||||
type EmbeddingParameters struct {
|
||||
EncodingFormat *string `json:"encoding_format,omitempty"` // Format for embedding output (e.g., "float", "base64")
|
||||
Dimensions *int `json:"dimensions,omitempty"` // Number of dimensions for embedding output
|
||||
|
||||
// Dynamic parameters that can be provider-specific, they are directly
|
||||
// added to the request as is.
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
type EmbeddingData struct {
|
||||
Index int `json:"index"`
|
||||
Object string `json:"object"` // "embedding"
|
||||
Embedding EmbeddingStruct `json:"embedding"` // can be string, []float64, [][]float64, []int8, or []int32
|
||||
}
|
||||
|
||||
type EmbeddingStruct struct {
|
||||
// Embedding responses preserve provider precision in normalized API output.
|
||||
EmbeddingStr *string
|
||||
EmbeddingArray []float64
|
||||
Embedding2DArray [][]float64
|
||||
EmbeddingInt8Array []int8 // for int8 / binary formats
|
||||
EmbeddingInt32Array []int32 // for uint8 / ubinary formats
|
||||
}
|
||||
|
||||
func (be EmbeddingStruct) MarshalJSON() ([]byte, error) {
|
||||
if be.EmbeddingStr != nil {
|
||||
return MarshalSorted(be.EmbeddingStr)
|
||||
}
|
||||
if be.EmbeddingArray != nil {
|
||||
return MarshalSorted(be.EmbeddingArray)
|
||||
}
|
||||
if be.Embedding2DArray != nil {
|
||||
return MarshalSorted(be.Embedding2DArray)
|
||||
}
|
||||
if be.EmbeddingInt8Array != nil {
|
||||
return Marshal(be.EmbeddingInt8Array)
|
||||
}
|
||||
if be.EmbeddingInt32Array != nil {
|
||||
return Marshal(be.EmbeddingInt32Array)
|
||||
}
|
||||
return nil, fmt.Errorf("no embedding found")
|
||||
}
|
||||
|
||||
func (be *EmbeddingStruct) UnmarshalJSON(data []byte) error {
|
||||
// First, try to unmarshal as a direct string
|
||||
var stringContent string
|
||||
if err := Unmarshal(data, &stringContent); err == nil {
|
||||
be.EmbeddingStr = &stringContent
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as a direct array of float64
|
||||
var arrayContent []float64
|
||||
if err := Unmarshal(data, &arrayContent); err == nil {
|
||||
be.EmbeddingArray = arrayContent
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as a direct 2D array of float64
|
||||
var arrayContent2D [][]float64
|
||||
if err := Unmarshal(data, &arrayContent2D); err == nil {
|
||||
be.Embedding2DArray = arrayContent2D
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as a direct array of int8
|
||||
var int8Content []int8
|
||||
if err := Unmarshal(data, &int8Content); err == nil {
|
||||
be.EmbeddingInt8Array = int8Content
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as a direct array of int32
|
||||
var int32Content []int32
|
||||
if err := Unmarshal(data, &int32Content); err == nil {
|
||||
be.EmbeddingInt32Array = int32Content
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("embedding field is neither a string, []float64, [][]float64, []int8, nor []int32")
|
||||
}
|
||||
353
core/schemas/envvar.go
Normal file
353
core/schemas/envvar.go
Normal file
@@ -0,0 +1,353 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
)
|
||||
|
||||
// EnvVar is a wrapper around a value that can be sourced from an environment variable.
|
||||
type EnvVar struct {
|
||||
Val string `json:"value"`
|
||||
EnvVar string `json:"env_var"`
|
||||
FromEnv bool `json:"from_env"`
|
||||
}
|
||||
|
||||
// NewEnvVar creates a new EnvValue from a string.
|
||||
func NewEnvVar(value string) *EnvVar {
|
||||
// Cleanup string if required
|
||||
// Use strconv.Unquote to properly handle JSON string escape sequences
|
||||
// This converts "\"{\\\"key\\\":\\\"value\\\"}\"" to "{\"key\":\"value\"}"
|
||||
val := value
|
||||
if unquoted, err := strconv.Unquote(value); err == nil {
|
||||
val = unquoted
|
||||
}
|
||||
// Here we will need to check if the incoming data is a valid JSON object
|
||||
// If it's a valid JSON object and follows the EnvVar schema, then we will unmarshal it into an EnvVar object
|
||||
if sonic.Valid([]byte(value)) {
|
||||
valueNode, _ := sonic.Get([]byte(val), "value")
|
||||
envNode, _ := sonic.Get([]byte(val), "env_var")
|
||||
if valueNode.Exists() && envNode.Exists() {
|
||||
// Use a type alias to avoid infinite recursion (alias doesn't inherit methods)
|
||||
type envVarAlias EnvVar
|
||||
var envVar envVarAlias
|
||||
if err := sonic.Unmarshal([]byte(value), &envVar); err == nil {
|
||||
e := &EnvVar{
|
||||
Val: envVar.Val,
|
||||
FromEnv: envVar.FromEnv,
|
||||
EnvVar: envVar.EnvVar,
|
||||
}
|
||||
// Here we will check if the Val starts with env and is same as the EnvVar
|
||||
if strings.HasPrefix(e.Val, "env.") && e.Val == e.EnvVar {
|
||||
e.Val = ""
|
||||
// Load the environment variable value
|
||||
envValue, ok := os.LookupEnv(strings.TrimPrefix(e.EnvVar, "env."))
|
||||
if ok {
|
||||
e.Val = envValue
|
||||
}
|
||||
e.FromEnv = true
|
||||
}
|
||||
return e
|
||||
}
|
||||
}
|
||||
}
|
||||
if envKey, ok := strings.CutPrefix(val, "env."); ok {
|
||||
if envValue, ok := os.LookupEnv(envKey); ok {
|
||||
return &EnvVar{
|
||||
Val: envValue,
|
||||
FromEnv: true,
|
||||
EnvVar: val,
|
||||
}
|
||||
}
|
||||
return &EnvVar{
|
||||
Val: "",
|
||||
FromEnv: true,
|
||||
EnvVar: val,
|
||||
}
|
||||
}
|
||||
return &EnvVar{
|
||||
Val: val,
|
||||
FromEnv: false,
|
||||
EnvVar: "",
|
||||
}
|
||||
}
|
||||
|
||||
// IsRedacted returns true if the value is redacted.
|
||||
func (e *EnvVar) IsRedacted() bool {
|
||||
if e.Val == "" && !e.FromEnv {
|
||||
return false
|
||||
}
|
||||
// Check if it's an environment variable reference
|
||||
if e.FromEnv {
|
||||
return true
|
||||
}
|
||||
if len(e.Val) <= 8 {
|
||||
return strings.Count(e.Val, "*") == len(e.Val)
|
||||
}
|
||||
// Check for exact redaction pattern: 4 chars + 24 asterisks + 4 chars
|
||||
if len(e.Val) == 32 {
|
||||
middle := e.Val[4:28]
|
||||
if middle == strings.Repeat("*", 24) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Check if its string <redacted>
|
||||
if e.Val == "<redacted>" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Equals checks if two SecretKeys are equal.
|
||||
func (e *EnvVar) Equals(other *EnvVar) bool {
|
||||
if e == nil && other == nil {
|
||||
return true
|
||||
}
|
||||
if e == nil || other == nil {
|
||||
return false
|
||||
}
|
||||
return e.Val == other.Val &&
|
||||
e.EnvVar == other.EnvVar &&
|
||||
e.FromEnv == other.FromEnv
|
||||
}
|
||||
|
||||
// Redacted returns a new SecretKey with the value redacted.
|
||||
func (e *EnvVar) Redacted() *EnvVar {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
if e.Val == "" {
|
||||
return &EnvVar{
|
||||
Val: "",
|
||||
FromEnv: e.FromEnv,
|
||||
EnvVar: e.EnvVar,
|
||||
}
|
||||
}
|
||||
// If key is 8 characters or less, just return all asterisks
|
||||
if len(e.Val) <= 8 {
|
||||
return &EnvVar{
|
||||
Val: strings.Repeat("*", len(e.Val)),
|
||||
FromEnv: e.FromEnv,
|
||||
EnvVar: e.EnvVar,
|
||||
}
|
||||
}
|
||||
// Show first 4 and last 4 characters, replace middle with asterisks
|
||||
prefix := e.Val[:4]
|
||||
suffix := e.Val[len(e.Val)-4:]
|
||||
middle := strings.Repeat("*", 24)
|
||||
|
||||
return &EnvVar{
|
||||
Val: prefix + middle + suffix,
|
||||
FromEnv: e.FromEnv,
|
||||
EnvVar: e.EnvVar,
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalJSON serializes the EnvVar to JSON.
|
||||
// SECURITY: When the value was sourced from an environment variable, the resolved
|
||||
// value is automatically redacted before being serialized. This ensures that secrets
|
||||
// injected via env vars are never leaked through any JSON API response, regardless
|
||||
// of whether the surrounding code remembered to call Redacted() explicitly.
|
||||
//
|
||||
// Plain (non-env) values are still emitted as-is — callers that want to mask those
|
||||
// must continue using Redacted() at the field level (this matches the existing
|
||||
// per-provider redaction logic).
|
||||
//
|
||||
// This does NOT affect:
|
||||
// - GORM persistence (uses the Value() driver method, not JSON)
|
||||
// - Encryption (operates on the Val field directly)
|
||||
// - Internal LLM request paths (use GetValue() directly)
|
||||
func (e EnvVar) MarshalJSON() ([]byte, error) {
|
||||
type envVarAlias EnvVar
|
||||
out := envVarAlias(e)
|
||||
if e.FromEnv {
|
||||
// Redact the resolved value but keep the env var reference and from_env flag
|
||||
// so the UI still knows which env var backs this field.
|
||||
redacted := e.Redacted()
|
||||
if redacted != nil {
|
||||
out = envVarAlias(*redacted)
|
||||
}
|
||||
}
|
||||
return sonic.Marshal(out)
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals the value from JSON.
|
||||
func (e *EnvVar) UnmarshalJSON(data []byte) error {
|
||||
// This is always going to be value
|
||||
// Here we will first considering this as value
|
||||
// if it has env. then we will process it and set the FromEnv to true
|
||||
// if it doesn't have env. then we will set the FromEnv to false
|
||||
// if it has env. then we will process it and set the FromEnv to true
|
||||
val := string(data)
|
||||
// Cleanup string if required
|
||||
// Use strconv.Unquote to properly handle JSON string escape sequences
|
||||
// This converts "\"{\\\"key\\\":\\\"value\\\"}\"" to "{\"key\":\"value\"}"
|
||||
if unquoted, err := strconv.Unquote(val); err == nil {
|
||||
val = unquoted
|
||||
}
|
||||
// Here we will need to check if the incoming data is a valid JSON object
|
||||
// If it's a valid JSON object and follows the EnvVar schema, then we will unmarshal it into an EnvVar object
|
||||
if sonic.Valid(data) {
|
||||
valueNode, _ := sonic.Get(data, "value")
|
||||
envNode, _ := sonic.Get(data, "env_var")
|
||||
if valueNode.Exists() && envNode.Exists() {
|
||||
// Use a type alias to avoid infinite recursion (alias doesn't inherit methods)
|
||||
type envVarAlias EnvVar
|
||||
var envVar envVarAlias
|
||||
if err := sonic.Unmarshal(data, &envVar); err == nil {
|
||||
e.Val = envVar.Val
|
||||
e.FromEnv = envVar.FromEnv
|
||||
e.EnvVar = envVar.EnvVar
|
||||
// Here we will check if the Val starts with env and is same as the EnvVar
|
||||
if strings.HasPrefix(e.Val, "env.") && e.Val == e.EnvVar {
|
||||
e.Val = ""
|
||||
// Load the environment variable value
|
||||
envValue, ok := os.LookupEnv(strings.TrimPrefix(e.EnvVar, "env."))
|
||||
if ok {
|
||||
e.Val = envValue
|
||||
}
|
||||
e.FromEnv = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Else the value is JSON, so we will treat this as a normal value
|
||||
}
|
||||
}
|
||||
if envKey, ok := strings.CutPrefix(val, "env."); ok {
|
||||
if envValue, ok := os.LookupEnv(envKey); ok {
|
||||
e.Val = envValue
|
||||
e.FromEnv = true
|
||||
e.EnvVar = val
|
||||
return nil
|
||||
}
|
||||
e.Val = ""
|
||||
e.FromEnv = true
|
||||
e.EnvVar = val
|
||||
return nil
|
||||
}
|
||||
e.Val = val
|
||||
e.FromEnv = false
|
||||
e.EnvVar = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
// String returns the value as a string.
|
||||
func (e *EnvVar) String() string {
|
||||
return e.Val
|
||||
}
|
||||
|
||||
// Scan scans the value from the database.
|
||||
func (e *EnvVar) Scan(value any) error {
|
||||
if value == nil {
|
||||
e.Val = ""
|
||||
e.FromEnv = false
|
||||
e.EnvVar = ""
|
||||
return nil
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
return e.Scan(string(v))
|
||||
case string:
|
||||
// Cleanup string if required
|
||||
// The string may have "\"env.TEST\"", "env.TEST" or "env.TEST\"", we need to clean it up to "env.TEST"
|
||||
val := strings.Trim(v, "\"")
|
||||
if envKey, ok := strings.CutPrefix(val, "env."); ok {
|
||||
if envValue, ok := os.LookupEnv(envKey); ok {
|
||||
e.Val = envValue
|
||||
e.FromEnv = true
|
||||
e.EnvVar = val
|
||||
return nil
|
||||
}
|
||||
e.Val = ""
|
||||
e.FromEnv = true
|
||||
e.EnvVar = val
|
||||
return nil
|
||||
}
|
||||
e.Val = val
|
||||
e.FromEnv = false
|
||||
e.EnvVar = ""
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to scan value: %v", value)
|
||||
}
|
||||
|
||||
// Value implements driver.Valuer for database storage.
|
||||
// It stores the original env reference (e.g., "env.API_KEY") if FromEnv is true,
|
||||
// otherwise stores the raw value.
|
||||
func (e EnvVar) Value() (driver.Value, error) {
|
||||
if e.FromEnv {
|
||||
return e.EnvVar, nil
|
||||
}
|
||||
return e.Val, nil
|
||||
}
|
||||
|
||||
// IsFromEnv returns true if the value is sourced from an environment variable.
|
||||
func (e *EnvVar) IsFromEnv() bool {
|
||||
return e.FromEnv
|
||||
}
|
||||
|
||||
// IsSet returns true if the EnvVar has a resolved value or an environment variable reference.
|
||||
// This should be used instead of GetValue() != "" when checking whether a field was configured,
|
||||
// because env var references may have an empty Val before resolution (e.g., when the env var
|
||||
// is not available in the current environment).
|
||||
func (e *EnvVar) IsSet() bool {
|
||||
if e == nil {
|
||||
return false
|
||||
}
|
||||
return e.Val != "" || e.EnvVar != ""
|
||||
}
|
||||
|
||||
// GetValue returns the value.
|
||||
func (e *EnvVar) GetValue() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return e.Val
|
||||
}
|
||||
|
||||
// GetValuePtr returns a pointer to the value.
|
||||
func (e *EnvVar) GetValuePtr() *string {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
return &e.Val
|
||||
}
|
||||
|
||||
// CoerceInt coerces value to int
|
||||
func (e *EnvVar) CoerceInt(defaultValue int) int {
|
||||
if e == nil {
|
||||
return defaultValue
|
||||
}
|
||||
val, err := strconv.Atoi(e.GetValue())
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// CoerceBool coerces value to bool
|
||||
func (e *EnvVar) CoerceBool(defaultValue bool) bool {
|
||||
if e == nil {
|
||||
return defaultValue
|
||||
}
|
||||
val, err := strconv.ParseBool(e.GetValue())
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// IsDefined returns true if the EnvVar has a source (static value or env key)
|
||||
func (e *EnvVar) IsDefined() bool {
|
||||
if e == nil {
|
||||
return false
|
||||
}
|
||||
if e.IsFromEnv() {
|
||||
return e.EnvVar != ""
|
||||
}
|
||||
return e.Val != ""
|
||||
}
|
||||
609
core/schemas/envvar_test.go
Normal file
609
core/schemas/envvar_test.go
Normal file
@@ -0,0 +1,609 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEnvVar_UnmarshalJSON_DoubleEscapedJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "service account credentials with escaped JSON",
|
||||
input: `"{\"type\":\"service_account\",\"project_id\":\"test-project\"}"`,
|
||||
expected: `{"type":"service_account","project_id":"test-project"}`,
|
||||
},
|
||||
{
|
||||
name: "nested JSON object with multiple levels of escaping",
|
||||
input: `"{\"key\":\"value\",\"nested\":{\"inner\":\"data\"}}"`,
|
||||
expected: `{"key":"value","nested":{"inner":"data"}}`,
|
||||
},
|
||||
{
|
||||
name: "JSON with escaped newlines in private key",
|
||||
input: `"{\"private_key\":\"-----BEGIN PRIVATE KEY-----\\nMIIE...\\n-----END PRIVATE KEY-----\\n\"}"`,
|
||||
expected: `{"private_key":"-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----\n"}`,
|
||||
},
|
||||
{
|
||||
name: "simple string value",
|
||||
input: `"sk-test-api-key-12345"`,
|
||||
expected: "sk-test-api-key-12345",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: `""`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "string with special characters",
|
||||
input: `"hello\"world"`,
|
||||
expected: `hello"world`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var envVar EnvVar
|
||||
err := envVar.UnmarshalJSON([]byte(tt.input))
|
||||
if err != nil {
|
||||
t.Fatalf("UnmarshalJSON failed: %v", err)
|
||||
}
|
||||
if envVar.Val != tt.expected {
|
||||
t.Errorf("Expected Val=%q, got Val=%q", tt.expected, envVar.Val)
|
||||
}
|
||||
if envVar.FromEnv {
|
||||
t.Errorf("Expected FromEnv=false, got FromEnv=true")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvVar_UnmarshalJSON_EnvVarReference(t *testing.T) {
|
||||
// Set up test environment variable
|
||||
os.Setenv("TEST_API_KEY", "actual-api-key-value")
|
||||
defer os.Unsetenv("TEST_API_KEY")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedVal string
|
||||
expectedEnvVar string
|
||||
expectedFromEnv bool
|
||||
}{
|
||||
{
|
||||
name: "env var reference with value present",
|
||||
input: `"env.TEST_API_KEY"`,
|
||||
expectedVal: "actual-api-key-value",
|
||||
expectedEnvVar: "env.TEST_API_KEY",
|
||||
expectedFromEnv: true,
|
||||
},
|
||||
{
|
||||
name: "env var reference with missing value",
|
||||
input: `"env.NONEXISTENT_VAR"`,
|
||||
expectedVal: "",
|
||||
expectedEnvVar: "env.NONEXISTENT_VAR",
|
||||
expectedFromEnv: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var envVar EnvVar
|
||||
err := envVar.UnmarshalJSON([]byte(tt.input))
|
||||
if err != nil {
|
||||
t.Fatalf("UnmarshalJSON failed: %v", err)
|
||||
}
|
||||
if envVar.Val != tt.expectedVal {
|
||||
t.Errorf("Expected Val=%q, got Val=%q", tt.expectedVal, envVar.Val)
|
||||
}
|
||||
if envVar.EnvVar != tt.expectedEnvVar {
|
||||
t.Errorf("Expected EnvVar=%q, got EnvVar=%q", tt.expectedEnvVar, envVar.EnvVar)
|
||||
}
|
||||
if envVar.FromEnv != tt.expectedFromEnv {
|
||||
t.Errorf("Expected FromEnv=%v, got FromEnv=%v", tt.expectedFromEnv, envVar.FromEnv)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvVar_UnmarshalJSON_FullStructure(t *testing.T) {
|
||||
// Test when the input is already an EnvVar JSON object
|
||||
input := `{"value":"my-api-key","env_var":"env.MY_KEY","from_env":true}`
|
||||
|
||||
var envVar EnvVar
|
||||
err := envVar.UnmarshalJSON([]byte(input))
|
||||
if err != nil {
|
||||
t.Fatalf("UnmarshalJSON failed: %v", err)
|
||||
}
|
||||
if envVar.Val != "my-api-key" {
|
||||
t.Errorf("Expected Val=%q, got Val=%q", "my-api-key", envVar.Val)
|
||||
}
|
||||
if envVar.EnvVar != "env.MY_KEY" {
|
||||
t.Errorf("Expected EnvVar=%q, got EnvVar=%q", "env.MY_KEY", envVar.EnvVar)
|
||||
}
|
||||
if !envVar.FromEnv {
|
||||
t.Errorf("Expected FromEnv=true, got FromEnv=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEnvVar_DoubleEscapedJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "service account credentials with escaped JSON",
|
||||
input: `"{\"type\":\"service_account\",\"project_id\":\"test-project\"}"`,
|
||||
expected: `{"type":"service_account","project_id":"test-project"}`,
|
||||
},
|
||||
{
|
||||
name: "JSON with escaped newlines",
|
||||
input: `"{\"private_key\":\"-----BEGIN-----\\nDATA\\n-----END-----\\n\"}"`,
|
||||
expected: `{"private_key":"-----BEGIN-----\nDATA\n-----END-----\n"}`,
|
||||
},
|
||||
{
|
||||
name: "simple string without quotes",
|
||||
input: "sk-test-api-key",
|
||||
expected: "sk-test-api-key",
|
||||
},
|
||||
{
|
||||
name: "simple string with outer quotes",
|
||||
input: `"sk-test-api-key"`,
|
||||
expected: "sk-test-api-key",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
envVar := NewEnvVar(tt.input)
|
||||
if envVar.Val != tt.expected {
|
||||
t.Errorf("Expected Val=%q, got Val=%q", tt.expected, envVar.Val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEnvVar_EnvVarReference(t *testing.T) {
|
||||
// Set up test environment variable
|
||||
os.Setenv("TEST_NEW_ENVVAR_KEY", "resolved-value")
|
||||
defer os.Unsetenv("TEST_NEW_ENVVAR_KEY")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedVal string
|
||||
expectedEnvVar string
|
||||
expectedFromEnv bool
|
||||
}{
|
||||
{
|
||||
name: "env var reference with value present",
|
||||
input: "env.TEST_NEW_ENVVAR_KEY",
|
||||
expectedVal: "resolved-value",
|
||||
expectedEnvVar: "env.TEST_NEW_ENVVAR_KEY",
|
||||
expectedFromEnv: true,
|
||||
},
|
||||
{
|
||||
name: "env var reference with quotes",
|
||||
input: `"env.TEST_NEW_ENVVAR_KEY"`,
|
||||
expectedVal: "resolved-value",
|
||||
expectedEnvVar: "env.TEST_NEW_ENVVAR_KEY",
|
||||
expectedFromEnv: true,
|
||||
},
|
||||
{
|
||||
name: "env var reference missing",
|
||||
input: "env.MISSING_VAR",
|
||||
expectedVal: "",
|
||||
expectedEnvVar: "env.MISSING_VAR",
|
||||
expectedFromEnv: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
envVar := NewEnvVar(tt.input)
|
||||
if envVar.Val != tt.expectedVal {
|
||||
t.Errorf("Expected Val=%q, got Val=%q", tt.expectedVal, envVar.Val)
|
||||
}
|
||||
if envVar.EnvVar != tt.expectedEnvVar {
|
||||
t.Errorf("Expected EnvVar=%q, got EnvVar=%q", tt.expectedEnvVar, envVar.EnvVar)
|
||||
}
|
||||
if envVar.FromEnv != tt.expectedFromEnv {
|
||||
t.Errorf("Expected FromEnv=%v, got FromEnv=%v", tt.expectedFromEnv, envVar.FromEnv)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEnvVar_RealWorldVertexCredentials tests the actual use case that triggered
|
||||
// the double-escaping bug: Vertex AI service account credentials
|
||||
func TestEnvVar_RealWorldVertexCredentials(t *testing.T) {
|
||||
// This simulates what happens when parsing config.json with embedded service account JSON
|
||||
type VertexKeyConfig struct {
|
||||
ProjectID EnvVar `json:"project_id"`
|
||||
Region EnvVar `json:"region"`
|
||||
AuthCredentials EnvVar `json:"auth_credentials"`
|
||||
}
|
||||
|
||||
jsonInput := `{
|
||||
"project_id": "my-project",
|
||||
"region": "us-central1",
|
||||
"auth_credentials": "{\"type\":\"service_account\",\"project_id\":\"my-project\",\"private_key_id\":\"abc123\",\"private_key\":\"-----BEGIN PRIVATE KEY-----\\nMIIE...\\n-----END PRIVATE KEY-----\\n\",\"client_email\":\"test@my-project.iam.gserviceaccount.com\"}"
|
||||
}`
|
||||
|
||||
var config VertexKeyConfig
|
||||
err := json.Unmarshal([]byte(jsonInput), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
// Verify auth_credentials is properly unescaped
|
||||
expectedAuthCreds := `{"type":"service_account","project_id":"my-project","private_key_id":"abc123","private_key":"-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----\n","client_email":"test@my-project.iam.gserviceaccount.com"}`
|
||||
if config.AuthCredentials.Val != expectedAuthCreds {
|
||||
t.Errorf("AuthCredentials not properly unescaped.\nExpected: %s\nGot: %s", expectedAuthCreds, config.AuthCredentials.Val)
|
||||
}
|
||||
|
||||
// Verify simple string fields work correctly
|
||||
if config.ProjectID.Val != "my-project" {
|
||||
t.Errorf("Expected ProjectID=%q, got %q", "my-project", config.ProjectID.Val)
|
||||
}
|
||||
if config.Region.Val != "us-central1" {
|
||||
t.Errorf("Expected Region=%q, got %q", "us-central1", config.Region.Val)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEnvVar_MixedConfigParsing tests parsing a config with both env var references
|
||||
// and embedded JSON credentials
|
||||
func TestEnvVar_MixedConfigParsing(t *testing.T) {
|
||||
os.Setenv("TEST_PROJECT_ID", "env-project-id")
|
||||
defer os.Unsetenv("TEST_PROJECT_ID")
|
||||
|
||||
type Config struct {
|
||||
ProjectID EnvVar `json:"project_id"`
|
||||
Credentials EnvVar `json:"credentials"`
|
||||
}
|
||||
|
||||
jsonInput := `{
|
||||
"project_id": "env.TEST_PROJECT_ID",
|
||||
"credentials": "{\"type\":\"service_account\",\"key\":\"value\"}"
|
||||
}`
|
||||
|
||||
var config Config
|
||||
err := json.Unmarshal([]byte(jsonInput), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
// Verify env var reference is resolved
|
||||
if config.ProjectID.Val != "env-project-id" {
|
||||
t.Errorf("Expected ProjectID=%q, got %q", "env-project-id", config.ProjectID.Val)
|
||||
}
|
||||
if !config.ProjectID.FromEnv {
|
||||
t.Errorf("Expected ProjectID.FromEnv=true")
|
||||
}
|
||||
|
||||
// Verify JSON credentials are properly unescaped
|
||||
expectedCreds := `{"type":"service_account","key":"value"}`
|
||||
if config.Credentials.Val != expectedCreds {
|
||||
t.Errorf("Expected Credentials=%q, got %q", expectedCreds, config.Credentials.Val)
|
||||
}
|
||||
if config.Credentials.FromEnv {
|
||||
t.Errorf("Expected Credentials.FromEnv=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvVar_Equals(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a *EnvVar
|
||||
b *EnvVar
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "both nil",
|
||||
a: nil,
|
||||
b: nil,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "first nil",
|
||||
a: nil,
|
||||
b: &EnvVar{Val: "test"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "second nil",
|
||||
a: &EnvVar{Val: "test"},
|
||||
b: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "equal values",
|
||||
a: &EnvVar{Val: "test", EnvVar: "env.TEST", FromEnv: true},
|
||||
b: &EnvVar{Val: "test", EnvVar: "env.TEST", FromEnv: true},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different values",
|
||||
a: &EnvVar{Val: "test1"},
|
||||
b: &EnvVar{Val: "test2"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.a.Equals(tt.b)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected Equals=%v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvVar_Redacted(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input EnvVar
|
||||
expectedVal string
|
||||
}{
|
||||
{
|
||||
name: "empty value",
|
||||
input: EnvVar{Val: ""},
|
||||
expectedVal: "",
|
||||
},
|
||||
{
|
||||
name: "short value (8 chars)",
|
||||
input: EnvVar{Val: "12345678"},
|
||||
expectedVal: "********",
|
||||
},
|
||||
{
|
||||
name: "long value",
|
||||
input: EnvVar{Val: "sk-1234567890abcdefghijklmnop"},
|
||||
expectedVal: "sk-1************************mnop",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.input.Redacted()
|
||||
if result.Val != tt.expectedVal {
|
||||
t.Errorf("Expected Redacted Val=%q, got %q", tt.expectedVal, result.Val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvVar_IsRedacted(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input EnvVar
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "empty not from env",
|
||||
input: EnvVar{Val: "", FromEnv: false},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "from env",
|
||||
input: EnvVar{Val: "test", FromEnv: true},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "short all asterisks",
|
||||
input: EnvVar{Val: "****"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "redacted pattern 32 chars",
|
||||
input: EnvVar{Val: "sk-1************************mnop"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "normal value",
|
||||
input: EnvVar{Val: "sk-test-key"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.input.IsRedacted()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected IsRedacted=%v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEnvVar_IsSet verifies the semantic difference between GetValue() != "" and IsSet().
|
||||
// IsSet() must return true when the EnvVar references an env var (regardless of whether
|
||||
// that env var has been resolved to a non-empty Val). This is the property that the
|
||||
// BeforeSave hooks rely on so env var references survive persistence.
|
||||
func TestEnvVar_IsSet(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input *EnvVar
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil envvar",
|
||||
input: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "completely empty",
|
||||
input: &EnvVar{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "only Val set (plain value)",
|
||||
input: &EnvVar{Val: "abc"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "only EnvVar reference set (env not resolved on this server)",
|
||||
input: &EnvVar{EnvVar: "env.MISSING", FromEnv: true},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Val and EnvVar both set (env was resolved)",
|
||||
input: &EnvVar{Val: "resolved-secret", EnvVar: "env.X", FromEnv: true},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "FromEnv true but no reference and no value",
|
||||
input: &EnvVar{FromEnv: true},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.input.IsSet(); got != tt.expected {
|
||||
t.Errorf("IsSet() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEnvVar_MarshalJSON_AutoRedactsEnvBackedValues verifies that any EnvVar marshaled
|
||||
// to JSON with FromEnv=true is automatically masked, regardless of whether the
|
||||
// surrounding code remembered to call Redacted() explicitly. This is the defense-in-depth
|
||||
// guarantee that prevents env-resolved secrets from leaking through unredacted fields.
|
||||
func TestEnvVar_MarshalJSON_AutoRedactsEnvBackedValues(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input EnvVar
|
||||
wantValue string
|
||||
wantEnvVar string
|
||||
wantFromEnv bool
|
||||
}{
|
||||
{
|
||||
name: "env-backed long secret is redacted",
|
||||
input: EnvVar{Val: "sk-1234567890abcdefghijklmnop", EnvVar: "env.OPENAI_API_KEY", FromEnv: true},
|
||||
wantValue: "sk-1************************mnop",
|
||||
wantEnvVar: "env.OPENAI_API_KEY",
|
||||
wantFromEnv: true,
|
||||
},
|
||||
{
|
||||
name: "env-backed short secret is fully masked",
|
||||
input: EnvVar{Val: "12345678", EnvVar: "env.SHORT", FromEnv: true},
|
||||
wantValue: "********",
|
||||
wantEnvVar: "env.SHORT",
|
||||
wantFromEnv: true,
|
||||
},
|
||||
{
|
||||
name: "env-backed unresolved on this server keeps empty value",
|
||||
input: EnvVar{Val: "", EnvVar: "env.MISSING", FromEnv: true},
|
||||
wantValue: "",
|
||||
wantEnvVar: "env.MISSING",
|
||||
wantFromEnv: true,
|
||||
},
|
||||
{
|
||||
name: "plain value (not from env) is NOT redacted",
|
||||
input: EnvVar{Val: "2024-10-21", EnvVar: "", FromEnv: false},
|
||||
wantValue: "2024-10-21",
|
||||
wantEnvVar: "",
|
||||
wantFromEnv: false,
|
||||
},
|
||||
{
|
||||
name: "empty plain value passes through",
|
||||
input: EnvVar{Val: "", EnvVar: "", FromEnv: false},
|
||||
wantValue: "",
|
||||
wantEnvVar: "",
|
||||
wantFromEnv: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.input)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
var got struct {
|
||||
Value string `json:"value"`
|
||||
EnvVar string `json:"env_var"`
|
||||
FromEnv bool `json:"from_env"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatalf("Unmarshal of marshaled output failed: %v", err)
|
||||
}
|
||||
if got.Value != tt.wantValue {
|
||||
t.Errorf("value: got %q, want %q", got.Value, tt.wantValue)
|
||||
}
|
||||
if got.EnvVar != tt.wantEnvVar {
|
||||
t.Errorf("env_var: got %q, want %q", got.EnvVar, tt.wantEnvVar)
|
||||
}
|
||||
if got.FromEnv != tt.wantFromEnv {
|
||||
t.Errorf("from_env: got %v, want %v", got.FromEnv, tt.wantFromEnv)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEnvVar_MarshalJSON_DoesNotMutateOriginal ensures the auto-redaction in MarshalJSON
|
||||
// does not mutate the receiver. The inference path calls GetValue() to build the actual
|
||||
// HTTP request to the LLM provider, so the original Val must remain intact.
|
||||
func TestEnvVar_MarshalJSON_DoesNotMutateOriginal(t *testing.T) {
|
||||
original := EnvVar{Val: "real-secret-value", EnvVar: "env.SECRET", FromEnv: true}
|
||||
if _, err := json.Marshal(original); err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
if original.Val != "real-secret-value" {
|
||||
t.Errorf("MarshalJSON mutated Val: got %q, want %q", original.Val, "real-secret-value")
|
||||
}
|
||||
if original.GetValue() != "real-secret-value" {
|
||||
t.Errorf("GetValue() returns mutated value: got %q", original.GetValue())
|
||||
}
|
||||
}
|
||||
|
||||
// TestEnvVar_MarshalJSON_RoundTripIsRedacted verifies that a marshaled-then-unmarshaled
|
||||
// env-backed EnvVar is recognized as redacted. The merge logic in provider_keys.go relies
|
||||
// on this so it can detect "the UI sent back the same redacted value, don't overwrite".
|
||||
func TestEnvVar_MarshalJSON_RoundTripIsRedacted(t *testing.T) {
|
||||
original := EnvVar{Val: "sk-1234567890abcdefghijklmnop", EnvVar: "env.KEY", FromEnv: true}
|
||||
data, err := json.Marshal(original)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
var roundTripped EnvVar
|
||||
if err := json.Unmarshal(data, &roundTripped); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if !roundTripped.IsRedacted() {
|
||||
t.Errorf("Round-tripped env-backed value should be IsRedacted, got Val=%q", roundTripped.Val)
|
||||
}
|
||||
if roundTripped.EnvVar != "env.KEY" {
|
||||
t.Errorf("env_var reference lost in round-trip: got %q, want %q", roundTripped.EnvVar, "env.KEY")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEnvVar_MarshalJSON_DoesNotAffectGetValue is a critical safety net: marshaling an
|
||||
// EnvVar to JSON must NOT change what GetValue() returns. The inference path uses
|
||||
// GetValue() to build outgoing LLM requests; if marshaling were to mutate the value,
|
||||
// every request after a UI fetch would silently start using the redacted mask as the
|
||||
// API key.
|
||||
func TestEnvVar_MarshalJSON_DoesNotAffectGetValue(t *testing.T) {
|
||||
os.Setenv("MY_REAL_API_KEY", "sk-real-secret-1234567890abcdef")
|
||||
defer os.Unsetenv("MY_REAL_API_KEY")
|
||||
|
||||
ev := NewEnvVar("env.MY_REAL_API_KEY")
|
||||
if ev.GetValue() != "sk-real-secret-1234567890abcdef" {
|
||||
t.Fatalf("setup: GetValue() = %q, want resolved env value", ev.GetValue())
|
||||
}
|
||||
|
||||
// Marshaling would redact in the JSON output, but must not touch the in-memory Val.
|
||||
if _, err := json.Marshal(ev); err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
if ev.GetValue() != "sk-real-secret-1234567890abcdef" {
|
||||
t.Errorf("GetValue() returns mutated value after MarshalJSON: got %q", ev.GetValue())
|
||||
}
|
||||
}
|
||||
253
core/schemas/files.go
Normal file
253
core/schemas/files.go
Normal file
@@ -0,0 +1,253 @@
|
||||
// Package schemas defines the core schemas and types used by the Bifrost system.
|
||||
package schemas
|
||||
|
||||
// FilePurpose represents the purpose of an uploaded file.
|
||||
type FilePurpose string
|
||||
|
||||
const (
|
||||
FilePurposeBatch FilePurpose = "batch"
|
||||
FilePurposeAssistants FilePurpose = "assistants"
|
||||
FilePurposeFineTune FilePurpose = "fine-tune"
|
||||
FilePurposeVision FilePurpose = "vision"
|
||||
FilePurposeBatchOutput FilePurpose = "batch_output"
|
||||
FilePurposeUserData FilePurpose = "user_data"
|
||||
FilePurposeResponses FilePurpose = "responses"
|
||||
FilePurposeEvals FilePurpose = "evals"
|
||||
)
|
||||
|
||||
// FileStatus represents the status of a file.
|
||||
type FileStatus string
|
||||
|
||||
const (
|
||||
FileStatusUploaded FileStatus = "uploaded"
|
||||
FileStatusProcessed FileStatus = "processed"
|
||||
FileStatusProcessing FileStatus = "processing"
|
||||
FileStatusError FileStatus = "error"
|
||||
FileStatusDeleted FileStatus = "deleted"
|
||||
)
|
||||
|
||||
// FileStorageBackend represents the storage backend type.
|
||||
type FileStorageBackend string
|
||||
|
||||
const (
|
||||
FileStorageAPI FileStorageBackend = "api" // OpenAI/Azure REST API
|
||||
FileStorageS3 FileStorageBackend = "s3" // AWS S3
|
||||
FileStorageGCS FileStorageBackend = "gcs" // Google Cloud Storage
|
||||
FileStorageMemory FileStorageBackend = "memory" // In-memory (for Anthropic virtual files)
|
||||
)
|
||||
|
||||
// FileObject represents a file object returned by the API.
|
||||
type FileObject struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object,omitempty"` // "file"
|
||||
Bytes int64 `json:"bytes"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at,omitempty"`
|
||||
Filename string `json:"filename"`
|
||||
Purpose FilePurpose `json:"purpose"`
|
||||
Status FileStatus `json:"status,omitempty"`
|
||||
StatusDetails *string `json:"status_details,omitempty"`
|
||||
ExpiresAt *int64 `json:"expires_at,omitempty"`
|
||||
}
|
||||
|
||||
// BifrostFileUploadRequest represents a request to upload a file.
|
||||
type BifrostFileUploadRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model *string `json:"model"`
|
||||
|
||||
// File content
|
||||
File []byte `json:"-"` // Raw file content (not serialized)
|
||||
Filename string `json:"filename"` // Original filename
|
||||
Purpose FilePurpose `json:"purpose"` // Purpose of the file (e.g., "batch")
|
||||
ContentType *string `json:"content_type,omitempty"` // MIME type of the file
|
||||
|
||||
// Storage configuration (for S3/GCS backends)
|
||||
StorageConfig *FileStorageConfig `json:"storage_config,omitempty"`
|
||||
|
||||
// Expiration configuration (OpenAI only)
|
||||
ExpiresAfter *FileExpiresAfter `json:"expires_after,omitempty"`
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// S3StorageConfig represents AWS S3 storage configuration.
|
||||
type S3StorageConfig struct {
|
||||
Bucket string `json:"bucket,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
Prefix string `json:"prefix,omitempty"`
|
||||
}
|
||||
|
||||
// GCSStorageConfig represents Google Cloud Storage configuration.
|
||||
type GCSStorageConfig struct {
|
||||
Bucket string `json:"bucket,omitempty"`
|
||||
Project string `json:"project,omitempty"`
|
||||
Prefix string `json:"prefix,omitempty"`
|
||||
}
|
||||
|
||||
// FileExpiresAfter represents an expiration configuration for uploaded files.
|
||||
type FileExpiresAfter struct {
|
||||
Anchor string `json:"anchor"` // e.g., "created_at"
|
||||
Seconds int `json:"seconds"` // 3600-2592000 (1 hour to 30 days)
|
||||
}
|
||||
|
||||
// FileStorageConfig represents storage configuration for cloud storage backends.
|
||||
type FileStorageConfig struct {
|
||||
S3 *S3StorageConfig `json:"s3,omitempty"`
|
||||
GCS *GCSStorageConfig `json:"gcs,omitempty"`
|
||||
}
|
||||
|
||||
// BifrostFileUploadResponse represents the response from uploading a file.
|
||||
type BifrostFileUploadResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object,omitempty"` // "file"
|
||||
Bytes int64 `json:"bytes"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Filename string `json:"filename"`
|
||||
Purpose FilePurpose `json:"purpose"`
|
||||
Status FileStatus `json:"status,omitempty"`
|
||||
StatusDetails *string `json:"status_details,omitempty"`
|
||||
ExpiresAt *int64 `json:"expires_at,omitempty"`
|
||||
|
||||
// Storage backend info
|
||||
StorageBackend FileStorageBackend `json:"storage_backend,omitempty"`
|
||||
StorageURI string `json:"storage_uri,omitempty"` // S3/GCS URI if applicable
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// BifrostFileListRequest represents a request to list files.
|
||||
type BifrostFileListRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model *string `json:"model"`
|
||||
|
||||
RawRequestBody []byte `json:"-"` // Raw request body (not serialized)
|
||||
|
||||
// Filters
|
||||
Purpose FilePurpose `json:"purpose,omitempty"` // Filter by purpose
|
||||
|
||||
// Pagination
|
||||
Limit int `json:"limit,omitempty"` // Max results to return
|
||||
After *string `json:"after,omitempty"` // Cursor for pagination
|
||||
Order *string `json:"order,omitempty"` // Sort order (asc/desc)
|
||||
|
||||
// Storage configuration (for S3/GCS backends)
|
||||
StorageConfig *FileStorageConfig `json:"storage_config,omitempty"`
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// GetRawRequestBody returns the raw request body.
|
||||
func (request *BifrostFileListRequest) GetRawRequestBody() []byte {
|
||||
return request.RawRequestBody
|
||||
}
|
||||
|
||||
// BifrostFileListResponse represents the response from listing files.
|
||||
type BifrostFileListResponse struct {
|
||||
Object string `json:"object,omitempty"` // "list"
|
||||
Data []FileObject `json:"data"`
|
||||
HasMore bool `json:"has_more,omitempty"`
|
||||
After *string `json:"after,omitempty"` // Continuation token for pagination
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// BifrostFileRetrieveRequest represents a request to retrieve file metadata.
|
||||
type BifrostFileRetrieveRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model *string `json:"model"`
|
||||
|
||||
RawRequestBody []byte `json:"-"` // Raw request body (not serialized)
|
||||
|
||||
FileID string `json:"file_id"` // ID of the file to retrieve
|
||||
|
||||
// Storage configuration (for S3/GCS backends)
|
||||
StorageConfig *FileStorageConfig `json:"storage_config,omitempty"`
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// GetRawRequestBody returns the raw request body.
|
||||
func (request *BifrostFileRetrieveRequest) GetRawRequestBody() []byte {
|
||||
return request.RawRequestBody
|
||||
}
|
||||
|
||||
// BifrostFileRetrieveResponse represents the response from retrieving file metadata.
|
||||
type BifrostFileRetrieveResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object,omitempty"` // "file"
|
||||
Bytes int64 `json:"bytes"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at,omitempty"`
|
||||
Filename string `json:"filename"`
|
||||
Purpose FilePurpose `json:"purpose"`
|
||||
Status FileStatus `json:"status,omitempty"`
|
||||
StatusDetails *string `json:"status_details,omitempty"`
|
||||
ExpiresAt *int64 `json:"expires_at,omitempty"`
|
||||
|
||||
// Storage backend info
|
||||
StorageBackend FileStorageBackend `json:"storage_backend,omitempty"`
|
||||
StorageURI string `json:"storage_uri,omitempty"`
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// BifrostFileDeleteRequest represents a request to delete a file.
|
||||
type BifrostFileDeleteRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model *string `json:"model"`
|
||||
FileID string `json:"file_id"` // ID of the file to delete
|
||||
|
||||
RawRequestBody []byte `json:"-"` // Raw request body (not serialized)
|
||||
|
||||
// Storage configuration (for S3/GCS backends)
|
||||
StorageConfig *FileStorageConfig `json:"storage_config,omitempty"`
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// GetRawRequestBody returns the raw request body.
|
||||
func (request *BifrostFileDeleteRequest) GetRawRequestBody() []byte {
|
||||
return request.RawRequestBody
|
||||
}
|
||||
|
||||
// BifrostFileDeleteResponse represents the response from deleting a file.
|
||||
type BifrostFileDeleteResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object,omitempty"` // "file"
|
||||
Deleted bool `json:"deleted"`
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// BifrostFileContentRequest represents a request to download file content.
|
||||
type BifrostFileContentRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model *string `json:"model"`
|
||||
FileID string `json:"file_id"` // ID of the file to download
|
||||
|
||||
RawRequestBody []byte `json:"-"` // Raw request body (not serialized)
|
||||
|
||||
// Storage configuration (for S3/GCS backends)
|
||||
StorageConfig *FileStorageConfig `json:"storage_config,omitempty"`
|
||||
|
||||
// Extra parameters for provider-specific features
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// GetRawRequestBody returns the raw request body.
|
||||
func (request *BifrostFileContentRequest) GetRawRequestBody() []byte {
|
||||
return request.RawRequestBody
|
||||
}
|
||||
|
||||
// BifrostFileContentResponse represents the response from downloading file content.
|
||||
type BifrostFileContentResponse struct {
|
||||
FileID string `json:"file_id"`
|
||||
Content []byte `json:"-"` // Raw file content (not serialized)
|
||||
ContentType string `json:"content_type,omitempty"` // MIME type
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
338
core/schemas/images.go
Normal file
338
core/schemas/images.go
Normal file
@@ -0,0 +1,338 @@
|
||||
package schemas
|
||||
|
||||
type ImageEventType string
|
||||
|
||||
const (
|
||||
ImageGenerationEventTypePartial ImageEventType = "image_generation.partial_image"
|
||||
ImageGenerationEventTypeCompleted ImageEventType = "image_generation.completed"
|
||||
ImageGenerationEventTypeError ImageEventType = "error"
|
||||
ImageEditEventTypePartial ImageEventType = "image_edit.partial_image"
|
||||
ImageEditEventTypeCompleted ImageEventType = "image_edit.completed"
|
||||
ImageEditEventTypeError ImageEventType = "error"
|
||||
)
|
||||
|
||||
// BifrostImageGenerationRequest represents an image generation request in bifrost format
|
||||
type BifrostImageGenerationRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Input *ImageGenerationInput `json:"input"`
|
||||
Params *ImageGenerationParameters `json:"params,omitempty"`
|
||||
Fallbacks []Fallback `json:"fallbacks,omitempty"`
|
||||
RawRequestBody []byte `json:"-"`
|
||||
}
|
||||
|
||||
// GetRawRequestBody implements utils.RequestBodyGetter.
|
||||
func (b *BifrostImageGenerationRequest) GetRawRequestBody() []byte {
|
||||
return b.RawRequestBody
|
||||
}
|
||||
|
||||
type ImageGenerationInput struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type ImageGenerationParameters struct {
|
||||
N *int `json:"n,omitempty"` // Number of images (1-10)
|
||||
Background *string `json:"background,omitempty"` // "transparent", "opaque", "auto"
|
||||
Moderation *string `json:"moderation,omitempty"` // "low", "auto"
|
||||
PartialImages *int `json:"partial_images,omitempty"` // 0-3
|
||||
Size *string `json:"size,omitempty"` // "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792", "1536x1024", "1024x1536", "auto"
|
||||
Quality *string `json:"quality,omitempty"` // "auto", "high", "medium", "low", "hd", "standard"
|
||||
OutputCompression *int `json:"output_compression,omitempty"` // compression level (0-100%)
|
||||
OutputFormat *string `json:"output_format,omitempty"` // "png", "webp", "jpeg"
|
||||
Style *string `json:"style,omitempty"` // "natural", "vivid"
|
||||
ResponseFormat *string `json:"response_format,omitempty"` // "url", "b64_json"
|
||||
Seed *int `json:"seed,omitempty"` // seed for image generation
|
||||
NegativePrompt *string `json:"negative_prompt,omitempty"` // negative prompt for image generation
|
||||
NumInferenceSteps *int `json:"num_inference_steps,omitempty"` // number of inference steps
|
||||
User *string `json:"user,omitempty"`
|
||||
InputImages []string `json:"input_images,omitempty"` // input images for image generation, base64 encoded or URL
|
||||
AspectRatio *string `json:"aspect_ratio,omitempty"` // aspect ratio of the image
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// BifrostImageGenerationResponse represents the image generation response in bifrost format
|
||||
type BifrostImageGenerationResponse struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Created int64 `json:"created,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Data []ImageData `json:"data"`
|
||||
|
||||
*ImageGenerationResponseParameters
|
||||
|
||||
Usage *ImageUsage `json:"usage,omitempty"`
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields,omitempty"`
|
||||
}
|
||||
|
||||
// BackfillParams populates response fields from the original request that are needed
|
||||
// for cost calculation but may not be returned by the provider.
|
||||
// - NumInputImages on ImageUsage (count of input images from the request)
|
||||
// - Size on ImageGenerationResponseParameters (from request params if not in response)
|
||||
// - Quality (low, medium, high, auto) only
|
||||
func (r *BifrostImageGenerationResponse) BackfillParams(req *BifrostRequest) {
|
||||
if r == nil || req == nil {
|
||||
return
|
||||
}
|
||||
numInputImages, size, quality := getNumInputImagesSizeAndQualityFromRequest(req)
|
||||
|
||||
// Backfill Model from whichever inner request carries it. Some provider APIs
|
||||
// (notably OpenAI /v1/images/*) omit model in the response body.
|
||||
if r.Model == "" {
|
||||
switch {
|
||||
case req.ImageGenerationRequest != nil:
|
||||
r.Model = req.ImageGenerationRequest.Model
|
||||
case req.ImageEditRequest != nil:
|
||||
r.Model = req.ImageEditRequest.Model
|
||||
case req.ImageVariationRequest != nil:
|
||||
r.Model = req.ImageVariationRequest.Model
|
||||
}
|
||||
}
|
||||
|
||||
// Backfill NumInputImages
|
||||
if numInputImages > 0 {
|
||||
if r.Usage == nil {
|
||||
r.Usage = &ImageUsage{}
|
||||
}
|
||||
r.Usage.NumInputImages = numInputImages
|
||||
}
|
||||
|
||||
// Backfill Size if not already present from provider response
|
||||
if size != "" && (r.ImageGenerationResponseParameters == nil || r.ImageGenerationResponseParameters.Size == "") {
|
||||
if r.ImageGenerationResponseParameters == nil {
|
||||
r.ImageGenerationResponseParameters = &ImageGenerationResponseParameters{}
|
||||
}
|
||||
r.ImageGenerationResponseParameters.Size = size
|
||||
}
|
||||
|
||||
// Backfill Quality if not already present from provider response
|
||||
if quality != "" && (r.ImageGenerationResponseParameters == nil || r.ImageGenerationResponseParameters.Quality == "") {
|
||||
if r.ImageGenerationResponseParameters == nil {
|
||||
r.ImageGenerationResponseParameters = &ImageGenerationResponseParameters{}
|
||||
}
|
||||
r.ImageGenerationResponseParameters.Quality = quality
|
||||
}
|
||||
}
|
||||
|
||||
// getModelFromRequest extracts the model from any image-related request.
|
||||
func getModelFromRequest(req *BifrostRequest) string {
|
||||
if req == nil {
|
||||
return ""
|
||||
}
|
||||
switch {
|
||||
case req.ImageGenerationRequest != nil:
|
||||
return req.ImageGenerationRequest.Model
|
||||
case req.ImageEditRequest != nil:
|
||||
return req.ImageEditRequest.Model
|
||||
case req.ImageVariationRequest != nil:
|
||||
return req.ImageVariationRequest.Model
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getNumInputImagesSizeAndQualityFromRequest extracts request params for cost calculation.
|
||||
// Quality is only returned when it is one of low, medium, high, auto.
|
||||
func getNumInputImagesSizeAndQualityFromRequest(req *BifrostRequest) (numInputImages int, size string, quality string) {
|
||||
if req == nil {
|
||||
return 0, "", ""
|
||||
}
|
||||
|
||||
switch {
|
||||
case req.ImageGenerationRequest != nil:
|
||||
if req.ImageGenerationRequest.Params != nil {
|
||||
p := req.ImageGenerationRequest.Params
|
||||
numInputImages = len(p.InputImages)
|
||||
if p.Size != nil {
|
||||
size = *p.Size
|
||||
}
|
||||
if p.Quality != nil {
|
||||
quality = normalizeImageQuality(*p.Quality)
|
||||
}
|
||||
}
|
||||
case req.ImageEditRequest != nil:
|
||||
if req.ImageEditRequest.Input != nil {
|
||||
numInputImages = len(req.ImageEditRequest.Input.Images)
|
||||
}
|
||||
if req.ImageEditRequest.Params != nil {
|
||||
p := req.ImageEditRequest.Params
|
||||
if p.Size != nil {
|
||||
size = *p.Size
|
||||
}
|
||||
if p.Quality != nil {
|
||||
quality = normalizeImageQuality(*p.Quality)
|
||||
}
|
||||
}
|
||||
case req.ImageVariationRequest != nil:
|
||||
if req.ImageVariationRequest.Input != nil {
|
||||
numInputImages = 1
|
||||
}
|
||||
if req.ImageVariationRequest.Params != nil && req.ImageVariationRequest.Params.Size != nil {
|
||||
size = *req.ImageVariationRequest.Params.Size
|
||||
}
|
||||
}
|
||||
return numInputImages, size, quality
|
||||
}
|
||||
|
||||
// normalizeImageQuality returns the quality string only if it is supported by gpt-image-1.5 (low, medium, high, auto).
|
||||
// All other values (hd, standard, etc.) are discarded and return empty.
|
||||
func normalizeImageQuality(q string) string {
|
||||
switch q {
|
||||
case "low", "medium", "high", "auto":
|
||||
return q
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
type ImageGenerationResponseParameters struct {
|
||||
Background string `json:"background,omitempty"`
|
||||
OutputFormat string `json:"output_format,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
FinishReasons []*string `json:"finish_reasons,omitempty"`
|
||||
Seeds []int `json:"seeds,omitempty"`
|
||||
}
|
||||
|
||||
type ImageData struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type ImageUsage struct {
|
||||
InputTokens int `json:"input_tokens,omitempty"` // Always text tokens unless InputTokensDetails is not nil
|
||||
InputTokensDetails *ImageTokenDetails `json:"input_tokens_details,omitempty"`
|
||||
TotalTokens int `json:"total_tokens,omitempty"`
|
||||
OutputTokens int `json:"output_tokens,omitempty"` // Always image tokens unless OutputTokensDetails is not nil
|
||||
OutputTokensDetails *ImageTokenDetails `json:"output_tokens_details,omitempty"`
|
||||
NumInputImages int `json:"num_input_images,omitempty"` // Number of input images from the request (populated by Bifrost)
|
||||
}
|
||||
|
||||
type ImageTokenDetails struct {
|
||||
NImages int `json:"-"` // Number of images generated (used internally for bifrost)
|
||||
ImageTokens int `json:"image_tokens,omitempty"`
|
||||
TextTokens int `json:"text_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// Streaming Response
|
||||
type BifrostImageGenerationStreamResponse struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Type ImageEventType `json:"type,omitempty"`
|
||||
Index int `json:"-"` // Which image (0-N)
|
||||
ChunkIndex int `json:"-"` // Chunk order within image
|
||||
PartialImageIndex *int `json:"partial_image_index,omitempty"`
|
||||
SequenceNumber int `json:"sequence_number,omitempty"`
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
CreatedAt int64 `json:"created_at,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
Background string `json:"background,omitempty"`
|
||||
OutputFormat string `json:"output_format,omitempty"`
|
||||
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||
Usage *ImageUsage `json:"usage,omitempty"`
|
||||
Error *BifrostError `json:"error,omitempty"`
|
||||
RawRequest string `json:"-"`
|
||||
RawResponse string `json:"-"`
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields,omitempty"`
|
||||
}
|
||||
|
||||
// BackfillParams populates response fields from the original request that are needed
|
||||
// for cost calculation but may not be returned by the provider.
|
||||
// - NumInputImages on ImageUsage (count of input images from the request)
|
||||
// - Size on ImageGenerationResponseParameters (from request params if not in response)
|
||||
// - Quality (low, medium, high, auto) only
|
||||
func (r *BifrostImageGenerationStreamResponse) BackfillParams(req *BifrostRequest) {
|
||||
numInputImages, size, quality := getNumInputImagesSizeAndQualityFromRequest(req)
|
||||
|
||||
// Backfill NumInputImages
|
||||
if numInputImages > 0 {
|
||||
if r.Usage == nil {
|
||||
r.Usage = &ImageUsage{}
|
||||
}
|
||||
r.Usage.NumInputImages = numInputImages
|
||||
}
|
||||
|
||||
// Backfill Size if not already present from provider response
|
||||
if size != "" && r.Size == "" {
|
||||
r.Size = size
|
||||
}
|
||||
|
||||
// Backfill Quality if not already present (only low, medium, high, auto)
|
||||
if quality != "" && r.Quality == "" {
|
||||
r.Quality = quality
|
||||
}
|
||||
}
|
||||
|
||||
// BifrostImageEditRequest represents an image edit request in bifrost format
|
||||
type BifrostImageEditRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Input *ImageEditInput `json:"input"`
|
||||
Params *ImageEditParameters `json:"params,omitempty"`
|
||||
Fallbacks []Fallback `json:"fallbacks,omitempty"`
|
||||
RawRequestBody []byte `json:"-"`
|
||||
}
|
||||
|
||||
// GetRawRequestBody implements [utils.RequestBodyGetter].
|
||||
func (b *BifrostImageEditRequest) GetRawRequestBody() []byte {
|
||||
return b.RawRequestBody
|
||||
}
|
||||
|
||||
type ImageEditInput struct {
|
||||
Images []ImageInput `json:"images"`
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type ImageInput struct {
|
||||
Image []byte `json:"image"`
|
||||
}
|
||||
|
||||
type ImageEditParameters struct {
|
||||
Type *string `json:"type,omitempty"` // "inpainting", "outpainting", "background_removal", "remove_background", "erase_object", "recolor", "search_replace", "control_sketch", "control_structure", "style_guide", "style_transfer", "upscale_fast", "upscale_creative", "upscale_conservative"
|
||||
Background *string `json:"background,omitempty"` // "transparent", "opaque", "auto"
|
||||
InputFidelity *string `json:"input_fidelity,omitempty"` // "low", "high"
|
||||
Mask []byte `json:"mask,omitempty"`
|
||||
N *int `json:"n,omitempty"` // number of images to generate (1-10)
|
||||
OutputCompression *int `json:"output_compression,omitempty"` // compression level (0-100%)
|
||||
OutputFormat *string `json:"output_format,omitempty"` // "png", "webp", "jpeg"
|
||||
PartialImages *int `json:"partial_images,omitempty"` // 0-3
|
||||
Quality *string `json:"quality,omitempty"` // "auto", "high", "medium", "low", "standard"
|
||||
ResponseFormat *string `json:"response_format,omitempty"` // "url", "b64_json"
|
||||
Size *string `json:"size,omitempty"` // "256x256", "512x512", "1024x1024", "1536x1024", "1024x1536", "auto"
|
||||
User *string `json:"user,omitempty"`
|
||||
NegativePrompt *string `json:"negative_prompt,omitempty"` // negative prompt for image editing
|
||||
Seed *int `json:"seed,omitempty"` // seed for image editing
|
||||
NumInferenceSteps *int `json:"num_inference_steps,omitempty"` // number of inference steps
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// BifrostImageVariationRequest represents an image variation request in bifrost format
|
||||
type BifrostImageVariationRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Input *ImageVariationInput `json:"input"`
|
||||
Params *ImageVariationParameters `json:"params,omitempty"`
|
||||
Fallbacks []Fallback `json:"fallbacks,omitempty"`
|
||||
RawRequestBody []byte `json:"-"`
|
||||
}
|
||||
|
||||
// GetRawRequestBody implements [utils.RequestBodyGetter].
|
||||
func (b *BifrostImageVariationRequest) GetRawRequestBody() []byte {
|
||||
return b.RawRequestBody
|
||||
}
|
||||
|
||||
type ImageVariationInput struct {
|
||||
Image ImageInput `json:"image"`
|
||||
}
|
||||
|
||||
type ImageVariationParameters struct {
|
||||
N *int `json:"n,omitempty"` // Number of images (1-10)
|
||||
ResponseFormat *string `json:"response_format,omitempty"` // "url", "b64_json"
|
||||
Size *string `json:"size,omitempty"` // "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792", "1536x1024", "1024x1536", "auto"
|
||||
User *string `json:"user,omitempty"`
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// BifrostImageVariationResponse represents the image variation response in bifrost format
|
||||
// It uses the same structure as image generation response
|
||||
type BifrostImageVariationResponse = BifrostImageGenerationResponse
|
||||
135
core/schemas/json_native.go
Normal file
135
core/schemas/json_native.go
Normal file
@@ -0,0 +1,135 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
)
|
||||
|
||||
// Marshal encodes v to JSON bytes using the high-performance sonic library.
|
||||
func Marshal(v interface{}) ([]byte, error) {
|
||||
return sonic.Marshal(v)
|
||||
}
|
||||
|
||||
// MarshalString encodes v to a JSON string using sonic.
|
||||
func MarshalString(v interface{}) (string, error) {
|
||||
return sonic.MarshalString(v)
|
||||
}
|
||||
|
||||
// Unmarshal decodes JSON data into v using sonic.
|
||||
func Unmarshal(data []byte, v interface{}) error {
|
||||
return sonic.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
// Compact removes insignificant whitespace from JSON-encoded src
|
||||
// and appends the result to dst.
|
||||
func Compact(dst *bytes.Buffer, src []byte) error {
|
||||
return json.Compact(dst, src)
|
||||
}
|
||||
|
||||
// MarshalSorted encodes v to JSON with map keys sorted alphabetically.
|
||||
// Use this when deterministic output is needed (e.g., hashing, caching keys).
|
||||
// Uses sonic.ConfigStd which has SortMapKeys enabled.
|
||||
func MarshalSorted(v interface{}) ([]byte, error) {
|
||||
return sonic.ConfigStd.Marshal(v)
|
||||
}
|
||||
|
||||
// MarshalSortedIndent encodes v to indented JSON with map keys sorted alphabetically.
|
||||
func MarshalSortedIndent(v interface{}, prefix, indent string) ([]byte, error) {
|
||||
return sonic.ConfigStd.MarshalIndent(v, prefix, indent)
|
||||
}
|
||||
|
||||
// ConvertViaJSON converts src to type T via JSON round-trip using sorted marshaling.
|
||||
// Use as fallback when direct type assertion fails (e.g., map[string]interface{} from JSON).
|
||||
func ConvertViaJSON[T any](src interface{}) (T, error) {
|
||||
var zero T
|
||||
data, err := MarshalSorted(src)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
var result T
|
||||
if err := Unmarshal(data, &result); err != nil {
|
||||
return zero, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// MarshalDeeplySorted encodes v to JSON with all map keys sorted alphabetically,
|
||||
// including nested maps inside OrderedMap and other custom types with MarshalJSON.
|
||||
// This ensures fully deterministic output for hashing/caching purposes.
|
||||
//
|
||||
// Unlike MarshalSorted which relies on sonic's SortMapKeys (which doesn't affect
|
||||
// types with custom MarshalJSON like OrderedMap), this function first normalizes
|
||||
// the entire structure to plain maps, then marshals with sorted keys.
|
||||
func MarshalDeeplySorted(v interface{}) ([]byte, error) {
|
||||
normalized := normalizeForSortedMarshal(v)
|
||||
return sonic.ConfigStd.Marshal(normalized)
|
||||
}
|
||||
|
||||
// normalizeForSortedMarshal recursively converts OrderedMaps and structs to plain maps
|
||||
// so that sonic.ConfigStd.Marshal will sort all keys deterministically.
|
||||
func normalizeForSortedMarshal(v interface{}) interface{} {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch val := v.(type) {
|
||||
case *OrderedMap:
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
result := make(map[string]interface{}, val.Len())
|
||||
val.Range(func(k string, v interface{}) bool {
|
||||
result[k] = normalizeForSortedMarshal(v)
|
||||
return true
|
||||
})
|
||||
return result
|
||||
case OrderedMap:
|
||||
result := make(map[string]interface{}, val.Len())
|
||||
val.Range(func(k string, v interface{}) bool {
|
||||
result[k] = normalizeForSortedMarshal(v)
|
||||
return true
|
||||
})
|
||||
return result
|
||||
case map[string]interface{}:
|
||||
result := make(map[string]interface{}, len(val))
|
||||
for k, v := range val {
|
||||
result[k] = normalizeForSortedMarshal(v)
|
||||
}
|
||||
return result
|
||||
case []interface{}:
|
||||
result := make([]interface{}, len(val))
|
||||
for i, elem := range val {
|
||||
result[i] = normalizeForSortedMarshal(elem)
|
||||
}
|
||||
return result
|
||||
default:
|
||||
// Intentional round-trip: converts structs with custom MarshalJSON into plain
|
||||
// maps so sonic.ConfigStd can sort all keys. Cannot use sjson since input is a Go struct.
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() == reflect.Ptr {
|
||||
if rv.IsNil() {
|
||||
return nil
|
||||
}
|
||||
rv = rv.Elem()
|
||||
}
|
||||
if rv.Kind() == reflect.Struct {
|
||||
// Marshal struct to JSON, then unmarshal to map for normalization
|
||||
data, err := sonic.Marshal(v)
|
||||
if err != nil {
|
||||
return v
|
||||
}
|
||||
var m map[string]interface{}
|
||||
if err := sonic.Unmarshal(data, &m); err != nil {
|
||||
return v
|
||||
}
|
||||
// Recursively normalize the resulting map
|
||||
return normalizeForSortedMarshal(m)
|
||||
}
|
||||
return v
|
||||
}
|
||||
}
|
||||
92
core/schemas/json_native_test.go
Normal file
92
core/schemas/json_native_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test helper types for ConvertViaJSON tests
|
||||
type convertTestTarget struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
type convertTestSource struct {
|
||||
Name string `json:"name"`
|
||||
Value int `json:"value"`
|
||||
}
|
||||
|
||||
type convertTestDest struct {
|
||||
Name string `json:"name"`
|
||||
Value int `json:"value"`
|
||||
}
|
||||
|
||||
func TestConvertViaJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
validate func(t *testing.T)
|
||||
}{
|
||||
{
|
||||
name: "map_to_struct",
|
||||
input: map[string]interface{}{"type": "json_object", "name": "test_format"},
|
||||
validate: func(t *testing.T) {
|
||||
result, err := ConvertViaJSON[convertTestTarget](map[string]interface{}{"type": "json_object", "name": "test_format"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "json_object", result.Type)
|
||||
assert.Equal(t, "test_format", result.Name)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "struct_to_struct",
|
||||
validate: func(t *testing.T) {
|
||||
src := convertTestSource{Name: "hello", Value: 42}
|
||||
result, err := ConvertViaJSON[convertTestDest](src)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "hello", result.Name)
|
||||
assert.Equal(t, 42, result.Value)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "slice_conversion",
|
||||
validate: func(t *testing.T) {
|
||||
src := []interface{}{
|
||||
map[string]interface{}{"type": "image", "name": "a"},
|
||||
map[string]interface{}{"type": "video", "name": "b"},
|
||||
}
|
||||
result, err := ConvertViaJSON[[]convertTestTarget](src)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result, 2)
|
||||
assert.Equal(t, "image", result[0].Type)
|
||||
assert.Equal(t, "a", result[0].Name)
|
||||
assert.Equal(t, "video", result[1].Type)
|
||||
assert.Equal(t, "b", result[1].Name)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid_input_returns_error",
|
||||
validate: func(t *testing.T) {
|
||||
result, err := ConvertViaJSON[convertTestTarget](42)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, convertTestTarget{}, result)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil_input",
|
||||
validate: func(t *testing.T) {
|
||||
result, err := ConvertViaJSON[convertTestTarget](nil)
|
||||
// nil marshals to "null" which unmarshals to zero struct — no error
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, convertTestTarget{}, result)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.validate(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
92
core/schemas/json_wasm.go
Normal file
92
core/schemas/json_wasm.go
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build tinygo || wasm
|
||||
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// Marshal encodes v to JSON bytes using the standard library.
|
||||
func Marshal(v interface{}) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
}
|
||||
|
||||
// MarshalString encodes v to a JSON string using the standard library.
|
||||
func MarshalString(v interface{}) (string, error) {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// Unmarshal decodes JSON data into v using the standard library.
|
||||
func Unmarshal(data []byte, v interface{}) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
// Compact removes insignificant whitespace from JSON-encoded src
|
||||
// and appends the result to dst.
|
||||
func Compact(dst *bytes.Buffer, src []byte) error {
|
||||
return json.Compact(dst, src)
|
||||
}
|
||||
|
||||
// MarshalSorted encodes v to JSON with map keys sorted alphabetically.
|
||||
// Use this when deterministic output is needed (e.g., hashing, caching keys).
|
||||
// Recursively normalizes OrderedMap values to plain maps so json.Marshal sorts keys.
|
||||
func MarshalSorted(v interface{}) ([]byte, error) {
|
||||
normalized := normalizeForSortedMarshal(v)
|
||||
return json.Marshal(normalized)
|
||||
}
|
||||
|
||||
// MarshalDeeplySorted encodes v to JSON with all map keys sorted alphabetically,
|
||||
// including nested maps inside OrderedMap and other custom types with MarshalJSON.
|
||||
// This ensures fully deterministic output for hashing/caching purposes.
|
||||
// In WASM builds, this is equivalent to MarshalSorted since both normalize recursively.
|
||||
func MarshalDeeplySorted(v interface{}) ([]byte, error) {
|
||||
normalized := normalizeForSortedMarshal(v)
|
||||
return json.Marshal(normalized)
|
||||
}
|
||||
|
||||
// normalizeForSortedMarshal recursively converts OrderedMaps and structs to plain maps
|
||||
// so that json.Marshal will sort their keys (Go 1.12+ sorts map keys).
|
||||
func normalizeForSortedMarshal(v interface{}) interface{} {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch val := v.(type) {
|
||||
case *OrderedMap:
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
result := make(map[string]interface{}, val.Len())
|
||||
val.Range(func(k string, v interface{}) bool {
|
||||
result[k] = normalizeForSortedMarshal(v)
|
||||
return true
|
||||
})
|
||||
return result
|
||||
case OrderedMap:
|
||||
result := make(map[string]interface{}, val.Len())
|
||||
val.Range(func(k string, v interface{}) bool {
|
||||
result[k] = normalizeForSortedMarshal(v)
|
||||
return true
|
||||
})
|
||||
return result
|
||||
case map[string]interface{}:
|
||||
result := make(map[string]interface{}, len(val))
|
||||
for k, v := range val {
|
||||
result[k] = normalizeForSortedMarshal(v)
|
||||
}
|
||||
return result
|
||||
case []interface{}:
|
||||
result := make([]interface{}, len(val))
|
||||
for i, elem := range val {
|
||||
result[i] = normalizeForSortedMarshal(elem)
|
||||
}
|
||||
return result
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
207
core/schemas/jsonkeyorder.go
Normal file
207
core/schemas/jsonkeyorder.go
Normal file
@@ -0,0 +1,207 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// JSONKeyOrder is a lightweight helper that preserves JSON key ordering through
|
||||
// struct serialization round-trips. Embed it in any struct with `json:"-"` tag.
|
||||
//
|
||||
// LLMs are autoregressive sequence models that are sensitive to JSON key ordering
|
||||
// in tool schemas. This helper ensures that when Bifrost deserializes and
|
||||
// re-serializes JSON, the original key order from the client is preserved.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// type MyStruct struct {
|
||||
// keyOrder JSONKeyOrder `json:"-"`
|
||||
// Field1 string `json:"field1"`
|
||||
// Field2 string `json:"field2"`
|
||||
// }
|
||||
//
|
||||
// func (s *MyStruct) UnmarshalJSON(data []byte) error {
|
||||
// type Alias MyStruct
|
||||
// if err := Unmarshal(data, (*Alias)(s)); err != nil { return err }
|
||||
// s.keyOrder.Capture(data)
|
||||
// return nil
|
||||
// }
|
||||
//
|
||||
// func (s MyStruct) MarshalJSON() ([]byte, error) {
|
||||
// type Alias MyStruct
|
||||
// data, err := Marshal(Alias(s))
|
||||
// if err != nil { return nil, err }
|
||||
// return s.keyOrder.Apply(data)
|
||||
// }
|
||||
type JSONKeyOrder struct {
|
||||
keys []string
|
||||
}
|
||||
|
||||
// Capture extracts and stores the top-level key order from raw JSON data.
|
||||
// Call this at the end of UnmarshalJSON.
|
||||
func (o *JSONKeyOrder) Capture(data []byte) {
|
||||
o.keys = ExtractTopLevelKeyOrder(data)
|
||||
}
|
||||
|
||||
// Apply reorders the keys in serialized JSON to match the captured order.
|
||||
// If no order was captured (programmatic construction), returns data unchanged.
|
||||
// Call this at the end of MarshalJSON.
|
||||
func (o *JSONKeyOrder) Apply(data []byte) ([]byte, error) {
|
||||
if len(o.keys) == 0 {
|
||||
return data, nil
|
||||
}
|
||||
return ReorderJSONKeys(data, o.keys)
|
||||
}
|
||||
|
||||
// ExtractTopLevelKeyOrder parses a JSON object and returns its top-level keys in
|
||||
// document order. Useful for capturing key order before struct deserialization
|
||||
// loses it, so that re-serialization can preserve the original order.
|
||||
func ExtractTopLevelKeyOrder(data []byte) []string {
|
||||
trimmed := bytes.TrimSpace(data)
|
||||
if len(trimmed) == 0 || trimmed[0] != '{' {
|
||||
return nil
|
||||
}
|
||||
|
||||
dec := json.NewDecoder(bytes.NewReader(trimmed))
|
||||
// Read opening '{'
|
||||
if _, err := dec.Token(); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var keys []string
|
||||
for dec.More() {
|
||||
tok, err := dec.Token()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
key, ok := tok.(string)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
keys = append(keys, key)
|
||||
// Skip the value (handles nested objects/arrays)
|
||||
if err := skipJSONValue(dec); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// skipJSONValue reads and discards a single JSON value from a decoder.
|
||||
func skipJSONValue(dec *json.Decoder) error {
|
||||
tok, err := dec.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
delim, ok := tok.(json.Delim)
|
||||
if !ok {
|
||||
return nil // primitive value, already consumed
|
||||
}
|
||||
switch delim {
|
||||
case '{':
|
||||
for dec.More() {
|
||||
// skip key
|
||||
if _, err := dec.Token(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := skipJSONValue(dec); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err = dec.Token() // closing '}'
|
||||
return err
|
||||
case '[':
|
||||
for dec.More() {
|
||||
if err := skipJSONValue(dec); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err = dec.Token() // closing ']'
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReorderJSONKeys takes serialized JSON and a desired key order, and returns the
|
||||
// same JSON with top-level keys reordered. Keys present in `order` are emitted
|
||||
// first in that order; any remaining keys follow in their original order.
|
||||
// This is a general-purpose utility for preserving client-specified key order
|
||||
// through struct serialization/deserialization round-trips.
|
||||
func ReorderJSONKeys(data []byte, order []string) ([]byte, error) {
|
||||
// Parse into key → raw value pairs, preserving original values as-is
|
||||
trimmed := bytes.TrimSpace(data)
|
||||
if len(trimmed) < 2 || trimmed[0] != '{' {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Use encoding/json decoder to get raw key-value pairs while preserving order
|
||||
dec := json.NewDecoder(bytes.NewReader(trimmed))
|
||||
dec.UseNumber()
|
||||
if _, err := dec.Token(); err != nil { // '{'
|
||||
return data, nil
|
||||
}
|
||||
|
||||
type kvPair struct {
|
||||
key string
|
||||
val json.RawMessage
|
||||
}
|
||||
|
||||
var pairs []kvPair
|
||||
pairMap := make(map[string]json.RawMessage)
|
||||
for dec.More() {
|
||||
tok, err := dec.Token()
|
||||
if err != nil {
|
||||
return data, nil
|
||||
}
|
||||
key, ok := tok.(string)
|
||||
if !ok {
|
||||
return data, nil
|
||||
}
|
||||
var val json.RawMessage
|
||||
if err := dec.Decode(&val); err != nil {
|
||||
return data, nil
|
||||
}
|
||||
pairs = append(pairs, kvPair{key, val})
|
||||
pairMap[key] = val
|
||||
}
|
||||
|
||||
// Rebuild JSON: first keys from `order`, then remaining keys in original order
|
||||
var buf bytes.Buffer
|
||||
buf.WriteByte('{')
|
||||
first := true
|
||||
emitted := make(map[string]bool, len(order))
|
||||
|
||||
for _, key := range order {
|
||||
val, exists := pairMap[key]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
if !first {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
first = false
|
||||
keyBytes, _ := MarshalSorted(key)
|
||||
buf.Write(keyBytes)
|
||||
buf.WriteByte(':')
|
||||
buf.Write(val)
|
||||
emitted[key] = true
|
||||
}
|
||||
|
||||
// Remaining keys in their original document order
|
||||
for _, kv := range pairs {
|
||||
if emitted[kv.key] {
|
||||
continue
|
||||
}
|
||||
if !first {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
first = false
|
||||
keyBytes, _ := MarshalSorted(kv.key)
|
||||
buf.Write(keyBytes)
|
||||
buf.WriteByte(':')
|
||||
buf.Write(kv.val)
|
||||
}
|
||||
|
||||
buf.WriteByte('}')
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
17
core/schemas/kvstore.go
Normal file
17
core/schemas/kvstore.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package schemas
|
||||
|
||||
import "time"
|
||||
|
||||
// KVStore is a minimal interface for a key-value store used by Bifrost internals.
|
||||
// The concrete implementation (e.g. framework/kvstore.Store) is injected by the
|
||||
// caller and must satisfy this interface. Passing nil disables KV-backed features.
|
||||
type KVStore interface {
|
||||
Get(key string) (any, error)
|
||||
SetWithTTL(key string, value any, ttl time.Duration) error
|
||||
SetNXWithTTL(key string, value any, ttl time.Duration) (bool, error)
|
||||
Delete(key string) (bool, error)
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultSessionStickyTTL = time.Hour
|
||||
)
|
||||
79
core/schemas/logger.go
Normal file
79
core/schemas/logger.go
Normal file
@@ -0,0 +1,79 @@
|
||||
// Package schemas defines the core schemas and types used by the Bifrost system.
|
||||
package schemas
|
||||
|
||||
// LogLevel represents the severity level of a log message.
|
||||
// Internally it maps to zerolog.Level for interoperability.
|
||||
type LogLevel string
|
||||
|
||||
// LogLevel constants for different severity levels.
|
||||
const (
|
||||
LogLevelDebug LogLevel = "debug"
|
||||
LogLevelInfo LogLevel = "info"
|
||||
LogLevelWarn LogLevel = "warn"
|
||||
LogLevelError LogLevel = "error"
|
||||
)
|
||||
|
||||
// LoggerOutputType represents the output type of a logger.
|
||||
type LoggerOutputType string
|
||||
|
||||
// LoggerOutputType constants for different output types.
|
||||
const (
|
||||
LoggerOutputTypeJSON LoggerOutputType = "json"
|
||||
LoggerOutputTypePretty LoggerOutputType = "pretty"
|
||||
)
|
||||
|
||||
// Logger defines the interface for logging operations in the Bifrost system.
|
||||
// Implementations of this interface should provide methods for logging messages
|
||||
// at different severity levels.
|
||||
type Logger interface {
|
||||
// Debug logs a debug-level message.
|
||||
// This is used for detailed debugging information that is typically only needed
|
||||
// during development or troubleshooting.
|
||||
Debug(msg string, args ...any)
|
||||
|
||||
// Info logs an info-level message.
|
||||
// This is used for general informational messages about normal operation.
|
||||
Info(msg string, args ...any)
|
||||
|
||||
// Warn logs a warning-level message.
|
||||
// This is used for potentially harmful situations that don't prevent normal operation.
|
||||
Warn(msg string, args ...any)
|
||||
|
||||
// Error logs an error-level message.
|
||||
// This is used for serious problems that need attention and may prevent normal operation.
|
||||
Error(msg string, args ...any)
|
||||
|
||||
// Fatal logs a fatal-level message.
|
||||
// This is used for critical situations that require immediate attention and will terminate the program.
|
||||
Fatal(msg string, args ...any)
|
||||
|
||||
// SetLevel sets the log level for the logger.
|
||||
SetLevel(level LogLevel)
|
||||
|
||||
// SetOutputType sets the output type for the logger.
|
||||
SetOutputType(outputType LoggerOutputType)
|
||||
|
||||
// LogHTTPRequest returns a LogEventBuilder for structured HTTP access logging.
|
||||
// The level parameter controls the log severity, msg is sent when Send() is called.
|
||||
// Use the fluent builder to attach typed fields before calling Send().
|
||||
LogHTTPRequest(level LogLevel, msg string) LogEventBuilder
|
||||
}
|
||||
|
||||
// LogEventBuilder provides a fluent interface for building structured log entries.
|
||||
type LogEventBuilder interface {
|
||||
Str(key, val string) LogEventBuilder
|
||||
Int(key string, val int) LogEventBuilder
|
||||
Int64(key string, val int64) LogEventBuilder
|
||||
Send()
|
||||
}
|
||||
|
||||
// noopLogEventBuilder is a no-op builder for loggers that don't need structured logging.
|
||||
type noopLogEventBuilder struct{}
|
||||
|
||||
func (noopLogEventBuilder) Str(string, string) LogEventBuilder { return noopLogEventBuilder{} }
|
||||
func (noopLogEventBuilder) Int(string, int) LogEventBuilder { return noopLogEventBuilder{} }
|
||||
func (noopLogEventBuilder) Int64(string, int64) LogEventBuilder { return noopLogEventBuilder{} }
|
||||
func (noopLogEventBuilder) Send() {}
|
||||
|
||||
// NoopLogEvent is a shared singleton no-op LogEventBuilder.
|
||||
var NoopLogEvent LogEventBuilder = noopLogEventBuilder{}
|
||||
243
core/schemas/mcp.go
Normal file
243
core/schemas/mcp.go
Normal file
@@ -0,0 +1,243 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
// Package schemas defines the core schemas and types used by the Bifrost system.
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// OAuth-related errors
|
||||
var (
|
||||
ErrOAuth2ConfigNotFound = errors.New("oauth2 config not found")
|
||||
ErrOAuth2ProviderNotAvailable = errors.New("oauth2 provider not available")
|
||||
ErrOAuth2TokenExpired = errors.New("oauth2 token expired")
|
||||
ErrOAuth2TokenInvalid = errors.New("oauth2 token invalid")
|
||||
ErrOAuth2RefreshFailed = errors.New("oauth2 token refresh failed")
|
||||
ErrOAuth2NotPerUserSession = errors.New("state does not match a per-user oauth session")
|
||||
ErrOAuth2TokenNotFound = errors.New("per-user oauth token not found for this identity and mcp server")
|
||||
ErrPerUserOAuthPendingFlowExpired = errors.New("per-user oauth pending flow has expired")
|
||||
)
|
||||
|
||||
// MCPUserOAuthRequiredError is returned when a per-user OAuth MCP server requires
|
||||
// the user to authenticate before tool execution can proceed.
|
||||
type MCPUserOAuthRequiredError struct {
|
||||
MCPClientID string `json:"mcp_client_id"`
|
||||
MCPClientName string `json:"mcp_client_name"`
|
||||
AuthorizeURL string `json:"authorize_url"`
|
||||
SessionID string `json:"session_id"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func (e *MCPUserOAuthRequiredError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// MCPConfig represents the configuration for MCP integration in Bifrost.
|
||||
// It enables tool auto-discovery and execution from local and external MCP servers.
|
||||
type MCPConfig struct {
|
||||
ClientConfigs []*MCPClientConfig `json:"client_configs,omitempty"` // Per-client execution configurations
|
||||
ToolManagerConfig *MCPToolManagerConfig `json:"tool_manager_config,omitempty"` // MCP tool manager configuration
|
||||
ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Global default interval for syncing tools from MCP servers (0 = use default 10 min)
|
||||
|
||||
// Function to fetch a new request ID for each tool call result message in agent mode,
|
||||
// this is used to ensure that the tool call result messages are unique and can be tracked in plugins or by the user.
|
||||
// This id is attached to ctx.Value(schemas.BifrostContextKeyRequestID) in the agent mode.
|
||||
// If not provider, same request ID is used for all tool call result messages without any overrides.
|
||||
FetchNewRequestIDFunc func(ctx *BifrostContext) string `json:"-"`
|
||||
|
||||
// PluginPipelineProvider returns a plugin pipeline for running MCP plugin hooks.
|
||||
// Used when executeCode tool calls nested MCP tools to ensure plugins run for them.
|
||||
// The plugin pipeline should be released back to the pool using ReleasePluginPipeline.
|
||||
PluginPipelineProvider func() interface{} `json:"-"`
|
||||
|
||||
// ReleasePluginPipeline releases a plugin pipeline back to the pool.
|
||||
// This should be called after the plugin pipeline is no longer needed.
|
||||
ReleasePluginPipeline func(pipeline interface{}) `json:"-"`
|
||||
}
|
||||
|
||||
type MCPToolManagerConfig struct {
|
||||
ToolExecutionTimeout time.Duration `json:"tool_execution_timeout"`
|
||||
MaxAgentDepth int `json:"max_agent_depth"`
|
||||
CodeModeBindingLevel CodeModeBindingLevel `json:"code_mode_binding_level,omitempty"` // How tools are exposed in VFS: "server" or "tool"
|
||||
DisableAutoToolInject bool `json:"disable_auto_tool_inject,omitempty"` // When true, MCP tools are not injected into requests by default
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultMaxAgentDepth = 10
|
||||
DefaultToolExecutionTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// CodeModeBindingLevel defines how tools are exposed in the VFS for code execution
|
||||
type CodeModeBindingLevel string
|
||||
|
||||
const (
|
||||
CodeModeBindingLevelServer CodeModeBindingLevel = "server"
|
||||
CodeModeBindingLevelTool CodeModeBindingLevel = "tool"
|
||||
)
|
||||
|
||||
// MCPAuthType defines the authentication type for MCP connections
|
||||
type MCPAuthType string
|
||||
|
||||
const (
|
||||
MCPAuthTypeNone MCPAuthType = "none" // No authentication
|
||||
MCPAuthTypeHeaders MCPAuthType = "headers" // Header-based authentication (API keys, etc.)
|
||||
MCPAuthTypeOauth MCPAuthType = "oauth" // OAuth 2.0 authentication (server-level, admin authenticates once)
|
||||
MCPAuthTypePerUserOauth MCPAuthType = "per_user_oauth" // Per-user OAuth 2.0 authentication (each user authenticates individually)
|
||||
)
|
||||
|
||||
// MCPClientConfig defines tool filtering for an MCP client.
|
||||
type MCPClientConfig struct {
|
||||
ID string `json:"client_id"` // Client ID
|
||||
Name string `json:"name"` // Client name
|
||||
IsCodeModeClient bool `json:"is_code_mode_client"` // Whether the client is a code mode client
|
||||
ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, SSE, or InProcess)
|
||||
ConnectionString *EnvVar `json:"connection_string,omitempty"` // HTTP or SSE URL (required for HTTP or SSE connections)
|
||||
StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty"` // STDIO configuration (required for STDIO connections)
|
||||
AuthType MCPAuthType `json:"auth_type"` // Authentication type (none, headers, or oauth)
|
||||
OauthConfigID *string `json:"oauth_config_id,omitempty"` // OAuth config ID (references oauth_configs table)
|
||||
State string `json:"state,omitempty"` // Connection state (connected, disconnected, error)
|
||||
Headers map[string]EnvVar `json:"headers,omitempty"` // Headers to send with the request (for headers auth type)
|
||||
AllowedExtraHeaders WhiteList `json:"allowed_extra_headers,omitempty"` // Allowlist of request-level headers that callers may forward to this MCP server at execution time
|
||||
InProcessServer *server.MCPServer `json:"-"` // MCP server instance for in-process connections (Go package only)
|
||||
ToolsToExecute WhiteList `json:"tools_to_execute,omitempty"` // Include-only list.
|
||||
// ToolsToExecute semantics:
|
||||
// - ["*"] => all tools are included
|
||||
// - [] => no tools are included (deny-by-default)
|
||||
// - nil/omitted => treated as [] (no tools)
|
||||
// - ["tool1", "tool2"] => include only the specified tools
|
||||
ToolsToAutoExecute WhiteList `json:"tools_to_auto_execute,omitempty"` // Auto-execute list.
|
||||
// ToolsToAutoExecute semantics:
|
||||
// - ["*"] => all tools are auto-executed
|
||||
// - [] => no tools are auto-executed (deny-by-default)
|
||||
// - nil/omitted => treated as [] (no tools)
|
||||
// - ["tool1", "tool2"] => auto-execute only the specified tools
|
||||
// Note: If a tool is in ToolsToAutoExecute but not in ToolsToExecute, it will be skipped.
|
||||
IsPingAvailable *bool `json:"is_ping_available,omitempty"` // Whether the MCP server supports ping for health checks (nil/true = ping; false = listTools). Defaults to true.
|
||||
ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Per-client override for tool sync interval (0 = use global, negative = disabled)
|
||||
ToolPricing map[string]float64 `json:"tool_pricing,omitempty"` // Tool pricing for each tool (cost per execution)
|
||||
ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized)
|
||||
AllowOnAllVirtualKeys bool `json:"allow_on_all_virtual_keys"` // Whether to allow the MCP client to run on all virtual keys
|
||||
|
||||
// Discovered tools for per-user OAuth clients (persisted so they survive restart)
|
||||
DiscoveredTools map[string]ChatTool `json:"-"` // Discovered tool schemas keyed by prefixed name
|
||||
DiscoveredToolNameMapping map[string]string `json:"-"` // Mapping from sanitized tool names to original MCP names
|
||||
}
|
||||
|
||||
// NewMCPClientConfigFromMap creates a new MCP client config from a map[string]any.
|
||||
func NewMCPClientConfigFromMap(configMap map[string]any) *MCPClientConfig {
|
||||
var config MCPClientConfig
|
||||
data, err := MarshalSorted(configMap)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if err := Unmarshal(data, &config); err != nil {
|
||||
return nil
|
||||
}
|
||||
return &config
|
||||
}
|
||||
|
||||
// HttpHeaders returns the HTTP headers for the MCP client config.
|
||||
func (c *MCPClientConfig) HttpHeaders(ctx context.Context, oauth2Provider OAuth2Provider) (map[string]string, error) {
|
||||
headers := make(map[string]string)
|
||||
|
||||
switch c.AuthType {
|
||||
case MCPAuthTypeOauth:
|
||||
if c.OauthConfigID == nil {
|
||||
return nil, ErrOAuth2ConfigNotFound
|
||||
}
|
||||
if oauth2Provider == nil {
|
||||
return nil, ErrOAuth2ProviderNotAvailable
|
||||
}
|
||||
accessToken, err := oauth2Provider.GetAccessToken(ctx, *c.OauthConfigID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Validate token format - trim whitespace and check for invalid characters
|
||||
accessToken = strings.TrimSpace(accessToken)
|
||||
if accessToken == "" {
|
||||
return nil, errors.New("access token is empty")
|
||||
}
|
||||
if strings.ContainsAny(accessToken, "\n\r\t") {
|
||||
return nil, errors.New("access token contains invalid characters")
|
||||
}
|
||||
headers["Authorization"] = "Bearer " + accessToken
|
||||
case MCPAuthTypeHeaders:
|
||||
for key, value := range c.Headers {
|
||||
headers[key] = value.GetValue()
|
||||
}
|
||||
case MCPAuthTypePerUserOauth:
|
||||
// Per-user OAuth: headers are injected per-call in executeToolInternal, not at connection level
|
||||
return headers, nil
|
||||
case MCPAuthTypeNone:
|
||||
// No headers to add
|
||||
default:
|
||||
// Default to headers behavior for backward compatibility
|
||||
for key, value := range c.Headers {
|
||||
headers[key] = value.GetValue()
|
||||
}
|
||||
}
|
||||
|
||||
return headers, nil
|
||||
}
|
||||
|
||||
// MCPConnectionType defines the communication protocol for MCP connections
|
||||
type MCPConnectionType string
|
||||
|
||||
const (
|
||||
MCPConnectionTypeHTTP MCPConnectionType = "http" // HTTP-based connection
|
||||
MCPConnectionTypeSTDIO MCPConnectionType = "stdio" // STDIO-based connection
|
||||
MCPConnectionTypeSSE MCPConnectionType = "sse" // Server-Sent Events connection
|
||||
MCPConnectionTypeInProcess MCPConnectionType = "inprocess" // In-process (in-memory) connection
|
||||
)
|
||||
|
||||
// MCPStdioConfig defines how to launch a STDIO-based MCP server.
|
||||
type MCPStdioConfig struct {
|
||||
Command string `json:"command"` // Executable command to run
|
||||
Args []string `json:"args"` // Command line arguments
|
||||
Envs []string `json:"envs"` // Environment variables required
|
||||
}
|
||||
|
||||
type MCPConnectionState string
|
||||
|
||||
const (
|
||||
MCPConnectionStateConnected MCPConnectionState = "connected" // Client is connected and ready to use
|
||||
MCPConnectionStateDisconnected MCPConnectionState = "disconnected" // Client is not connected
|
||||
MCPConnectionStateError MCPConnectionState = "error" // Client is in an error state, and cannot be used
|
||||
MCPConnectionStatePendingTools MCPConnectionState = "pending_tools" // Connected but tools not yet populated
|
||||
)
|
||||
|
||||
// MCPClientState represents a connected MCP client with its configuration and tools.
|
||||
// It is used internally by the MCP manager to track the state of a connected MCP client.
|
||||
type MCPClientState struct {
|
||||
Name string // Unique name for this client
|
||||
Conn *client.Client // Active MCP client connection
|
||||
ExecutionConfig *MCPClientConfig // Tool filtering settings
|
||||
ToolMap map[string]ChatTool // Available tools mapped by name
|
||||
ToolNameMapping map[string]string // Maps sanitized_name -> original_mcp_name (e.g., "notion_search" -> "notion-search")
|
||||
ConnectionInfo *MCPClientConnectionInfo `json:"connection_info"` // Connection metadata for management
|
||||
CancelFunc context.CancelFunc `json:"-"` // Cancel function for SSE connections (not serialized)
|
||||
State MCPConnectionState // Connection state (connected, disconnected, error)
|
||||
}
|
||||
|
||||
// MCPClientConnectionInfo stores metadata about how a client is connected.
|
||||
type MCPClientConnectionInfo struct {
|
||||
Type MCPConnectionType `json:"type"` // Connection type (HTTP, STDIO, SSE, or InProcess)
|
||||
ConnectionURL *string `json:"connection_url,omitempty"` // HTTP/SSE endpoint URL (for HTTP/SSE connections)
|
||||
StdioCommandString *string `json:"stdio_command_string,omitempty"` // Command string for display (for STDIO connections)
|
||||
}
|
||||
|
||||
// MCPClient represents a connected MCP client with its configuration and tools,
|
||||
// and connection information, after it has been initialized.
|
||||
// It is returned by GetMCPClients() method in bifrost.
|
||||
type MCPClient struct {
|
||||
Config *MCPClientConfig `json:"config"` // Tool filtering settings
|
||||
Tools []ChatToolFunction `json:"tools"` // Available tools
|
||||
State MCPConnectionState `json:"state"` // Connection state
|
||||
}
|
||||
7
core/schemas/mcp_wasm.go
Normal file
7
core/schemas/mcp_wasm.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build tinygo || wasm
|
||||
|
||||
package schemas
|
||||
|
||||
// MCPConfig is a stub for WASM builds.
|
||||
// MCP functionality is not available in WASM plugins.
|
||||
type MCPConfig struct{}
|
||||
269
core/schemas/models.go
Normal file
269
core/schemas/models.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// DefaultPageSize is the default page size for listing models
|
||||
const DefaultPageSize = 1000
|
||||
|
||||
// MaxPaginationRequests is the maximum number of pagination requests to make
|
||||
const MaxPaginationRequests = 20
|
||||
|
||||
// Structure to collect results from goroutines
|
||||
type ListModelsByKeyResult struct {
|
||||
Response *BifrostListModelsResponse
|
||||
Err *BifrostError
|
||||
KeyID string
|
||||
}
|
||||
|
||||
// KeyStatus represents the status of model listing for a specific key
|
||||
type KeyStatus struct {
|
||||
KeyID string `json:"key_id"` // Empty for keyless providers
|
||||
Status KeyStatusType `json:"status"` // "success", "failed"
|
||||
Provider ModelProvider `json:"provider"` // Always populated
|
||||
Error *BifrostError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshaling for KeyStatus to prevent
|
||||
// circular reference: KeyStatus.Error → BifrostError.ExtraFields.KeyStatuses → KeyStatus.
|
||||
func (k KeyStatus) MarshalJSON() ([]byte, error) {
|
||||
type Alias KeyStatus
|
||||
alias := Alias(k)
|
||||
if alias.Error != nil {
|
||||
errCopy := *alias.Error
|
||||
errCopy.ExtraFields.KeyStatuses = nil
|
||||
alias.Error = &errCopy
|
||||
}
|
||||
return MarshalSorted(alias)
|
||||
}
|
||||
|
||||
type BifrostListModelsRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
|
||||
PageSize int `json:"page_size"`
|
||||
|
||||
// PageToken: Token received from previous request to retrieve next page
|
||||
PageToken string `json:"page_token"`
|
||||
|
||||
// Unfiltered: If true, the response will include all models for the provider, regardless of the allowed models (internal bifrost use only, not sent to the provider)
|
||||
Unfiltered bool `json:"-"`
|
||||
|
||||
// ExtraParams: Additional provider-specific query parameters
|
||||
// This allows for flexibility to pass any custom parameters that specific providers might support
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
type BifrostListModelsResponse struct {
|
||||
Data []Model `json:"data"`
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
NextPageToken string `json:"next_page_token,omitempty"` // Token to retrieve next page
|
||||
|
||||
// Key-level status tracking for multi-key providers
|
||||
KeyStatuses []KeyStatus `json:"key_statuses,omitempty"`
|
||||
|
||||
// Anthropic specific fields
|
||||
FirstID *string `json:"-"`
|
||||
LastID *string `json:"-"`
|
||||
HasMore *bool `json:"-"`
|
||||
}
|
||||
|
||||
// ApplyPagination applies offset-based pagination to a BifrostListModelsResponse.
|
||||
// Uses opaque tokens with LastID validation to ensure cursor integrity.
|
||||
// Returns the paginated response with properly set NextPageToken.
|
||||
func (response *BifrostListModelsResponse) ApplyPagination(pageSize int, pageToken string) *BifrostListModelsResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
totalItems := len(response.Data)
|
||||
|
||||
if pageSize <= 0 {
|
||||
return response
|
||||
}
|
||||
|
||||
cursor := decodePaginationCursor(pageToken)
|
||||
offset := cursor.Offset
|
||||
|
||||
// Validate cursor integrity if LastID is present
|
||||
if cursor.LastID != "" && !validatePaginationCursor(cursor, response.Data) {
|
||||
// Invalid cursor: reset to beginning
|
||||
offset = 0
|
||||
}
|
||||
|
||||
if offset >= totalItems {
|
||||
// Return empty page, no next token
|
||||
return &BifrostListModelsResponse{
|
||||
Data: []Model{},
|
||||
ExtraFields: response.ExtraFields,
|
||||
NextPageToken: "",
|
||||
KeyStatuses: response.KeyStatuses,
|
||||
}
|
||||
}
|
||||
|
||||
endIndex := offset + pageSize
|
||||
if endIndex > totalItems {
|
||||
endIndex = totalItems
|
||||
}
|
||||
|
||||
paginatedData := response.Data[offset:endIndex]
|
||||
|
||||
paginatedResponse := &BifrostListModelsResponse{
|
||||
Data: paginatedData,
|
||||
ExtraFields: response.ExtraFields,
|
||||
KeyStatuses: response.KeyStatuses,
|
||||
}
|
||||
|
||||
if endIndex < totalItems {
|
||||
// Get the last item ID for cursor validation
|
||||
var lastID string
|
||||
if len(paginatedData) > 0 {
|
||||
lastID = paginatedData[len(paginatedData)-1].ID
|
||||
}
|
||||
|
||||
nextToken, err := encodePaginationCursor(endIndex, lastID)
|
||||
if err == nil {
|
||||
paginatedResponse.NextPageToken = nextToken
|
||||
}
|
||||
} else {
|
||||
paginatedResponse.NextPageToken = ""
|
||||
}
|
||||
|
||||
return paginatedResponse
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
CanonicalSlug *string `json:"canonical_slug,omitempty"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
Alias *string `json:"alias,omitempty"` // Provider API identifier this model alias maps to (e.g. Azure deployment name, Bedrock ARN)
|
||||
Created *int64 `json:"created,omitempty"`
|
||||
ContextLength *int `json:"context_length,omitempty"`
|
||||
MaxInputTokens *int `json:"max_input_tokens,omitempty"`
|
||||
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
||||
Architecture *Architecture `json:"architecture,omitempty"`
|
||||
Pricing *Pricing `json:"pricing,omitempty"`
|
||||
TopProvider *TopProvider `json:"top_provider,omitempty"`
|
||||
PerRequestLimits *PerRequestLimits `json:"per_request_limits,omitempty"`
|
||||
SupportedParameters []string `json:"supported_parameters,omitempty"`
|
||||
DefaultParameters *DefaultParameters `json:"default_parameters,omitempty"`
|
||||
HuggingFaceID *string `json:"hugging_face_id,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
|
||||
OwnedBy *string `json:"owned_by,omitempty"`
|
||||
SupportedMethods []string `json:"supported_methods,omitempty"`
|
||||
|
||||
// ProviderExtra carries opaque provider-specific data (e.g. Anthropic capabilities)
|
||||
// through the Bifrost pipeline for integration reverse-conversion. Never serialized.
|
||||
ProviderExtra json.RawMessage `json:"-"`
|
||||
}
|
||||
|
||||
type Architecture struct {
|
||||
Modality *string `json:"modality,omitempty"`
|
||||
Tokenizer *string `json:"tokenizer,omitempty"`
|
||||
InstructType *string `json:"instruct_type,omitempty"`
|
||||
InputModalities []string `json:"input_modalities,omitempty"`
|
||||
OutputModalities []string `json:"output_modalities,omitempty"`
|
||||
}
|
||||
|
||||
type Pricing struct {
|
||||
Prompt *string `json:"prompt,omitempty"`
|
||||
Completion *string `json:"completion,omitempty"`
|
||||
Request *string `json:"request,omitempty"`
|
||||
Image *string `json:"image,omitempty"`
|
||||
WebSearch *string `json:"web_search,omitempty"`
|
||||
InternalReasoning *string `json:"internal_reasoning,omitempty"`
|
||||
InputCacheRead *string `json:"input_cache_read,omitempty"`
|
||||
InputCacheWrite *string `json:"input_cache_write,omitempty"`
|
||||
}
|
||||
|
||||
type TopProvider struct {
|
||||
IsModerated *bool `json:"is_moderated,omitempty"`
|
||||
ContextLength *int `json:"context_length,omitempty"`
|
||||
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type PerRequestLimits struct {
|
||||
PromptTokens *int `json:"prompt_tokens,omitempty"`
|
||||
CompletionTokens *int `json:"completion_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type DefaultParameters struct {
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
}
|
||||
|
||||
// paginationCursor represents the internal cursor structure for pagination.
|
||||
type paginationCursor struct {
|
||||
Offset int `json:"o"`
|
||||
LastID string `json:"l,omitempty"`
|
||||
}
|
||||
|
||||
// encodePaginationCursor creates an opaque base64-encoded page token from cursor data.
|
||||
// Returns empty string if offset is 0 or negative.
|
||||
func encodePaginationCursor(offset int, lastID string) (string, error) {
|
||||
if offset <= 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
cursor := paginationCursor{
|
||||
Offset: offset,
|
||||
LastID: lastID,
|
||||
}
|
||||
|
||||
jsonData, err := MarshalSorted(cursor)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal pagination cursor: %w", err)
|
||||
}
|
||||
|
||||
// Use URL-safe base64 encoding without padding for opaque token
|
||||
encoded := base64.RawURLEncoding.EncodeToString(jsonData)
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// decodePaginationCursor extracts cursor data from an opaque base64-encoded page token.
|
||||
// Returns cursor with 0 offset for empty or invalid tokens.
|
||||
func decodePaginationCursor(token string) paginationCursor {
|
||||
if token == "" {
|
||||
return paginationCursor{}
|
||||
}
|
||||
|
||||
// Decode base64
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(token)
|
||||
if err != nil {
|
||||
return paginationCursor{}
|
||||
}
|
||||
|
||||
var cursor paginationCursor
|
||||
if err := Unmarshal(decoded, &cursor); err != nil {
|
||||
return paginationCursor{}
|
||||
}
|
||||
|
||||
if cursor.Offset < 0 {
|
||||
return paginationCursor{}
|
||||
}
|
||||
|
||||
return cursor
|
||||
}
|
||||
|
||||
// validatePaginationCursor validates that the cursor matches the expected position in the data.
|
||||
// Returns true if the cursor is valid, false otherwise.
|
||||
func validatePaginationCursor(cursor paginationCursor, data []Model) bool {
|
||||
if cursor.LastID == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
if cursor.Offset <= 0 || cursor.Offset > len(data) {
|
||||
return false
|
||||
}
|
||||
|
||||
prevIndex := cursor.Offset - 1
|
||||
if prevIndex >= 0 && prevIndex < len(data) {
|
||||
return data[prevIndex].ID == cursor.LastID
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
117
core/schemas/models_test.go
Normal file
117
core/schemas/models_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Raw types that mirror the real KeyStatus → BifrostError → ExtraFields → KeyStatus
|
||||
// chain but without any custom MarshalJSON. Used to reproduce the cycle error
|
||||
// that would occur without the fix.
|
||||
type rawKeyStatusExtraFields struct {
|
||||
KeyStatuses []rawKeyStatus `json:"key_statuses,omitempty"`
|
||||
}
|
||||
|
||||
type rawKeyStatusBifrostError struct {
|
||||
IsBifrostError bool `json:"is_bifrost_error"`
|
||||
Error *ErrorField `json:"error"`
|
||||
ExtraFields rawKeyStatusExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
type rawKeyStatus struct {
|
||||
KeyID string `json:"key_id"`
|
||||
Status KeyStatusType `json:"status"`
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Error *rawKeyStatusBifrostError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// TestKeyStatusMarshalJSON_ReproduceCycle proves that without the custom MarshalJSON,
|
||||
// the circular reference between KeyStatus and BifrostError causes a marshaling failure.
|
||||
func TestKeyStatusMarshalJSON_ReproduceCycle(t *testing.T) {
|
||||
bifrostErr := &rawKeyStatusBifrostError{
|
||||
IsBifrostError: true,
|
||||
Error: &ErrorField{Message: "test error"},
|
||||
}
|
||||
keyStatus := rawKeyStatus{
|
||||
KeyID: "key-1",
|
||||
Status: KeyStatusListModelsFailed,
|
||||
Provider: "test-provider",
|
||||
Error: bifrostErr,
|
||||
}
|
||||
// Create the same cycle that HandleKeylessListModelsRequest creates
|
||||
bifrostErr.ExtraFields.KeyStatuses = []rawKeyStatus{keyStatus}
|
||||
|
||||
// Without any custom MarshalJSON, this must fail with a cycle error
|
||||
_, err := json.Marshal(keyStatus)
|
||||
require.Error(t, err, "expected cycle error without the MarshalJSON fix")
|
||||
assert.Contains(t, err.Error(), "cycle", "error should mention a cycle")
|
||||
}
|
||||
|
||||
func TestKeyStatusMarshalJSON_NoCycle(t *testing.T) {
|
||||
bifrostErr := &BifrostError{
|
||||
IsBifrostError: true,
|
||||
Error: &ErrorField{Message: "test error"},
|
||||
}
|
||||
keyStatus := KeyStatus{
|
||||
KeyID: "key-1",
|
||||
Status: KeyStatusListModelsFailed,
|
||||
Provider: "test-provider",
|
||||
Error: bifrostErr,
|
||||
}
|
||||
// Create the same cycle that HandleKeylessListModelsRequest creates
|
||||
bifrostErr.ExtraFields.KeyStatuses = []KeyStatus{keyStatus}
|
||||
|
||||
data, err := Marshal(keyStatus)
|
||||
require.NoError(t, err, "Marshal should not fail on circular KeyStatus")
|
||||
|
||||
// Verify the output doesn't contain nested key_statuses
|
||||
assert.False(t, bytes.Contains(data, []byte(`"key_statuses"`)),
|
||||
"expected key_statuses to be omitted from nested error")
|
||||
}
|
||||
|
||||
func TestKeyStatusMarshalJSON_NilError(t *testing.T) {
|
||||
keyStatus := KeyStatus{
|
||||
KeyID: "key-2",
|
||||
Status: "success",
|
||||
Provider: "test-provider",
|
||||
}
|
||||
|
||||
data, err := Marshal(keyStatus)
|
||||
require.NoError(t, err, "Marshal should not fail on KeyStatus with nil error")
|
||||
assert.Contains(t, string(data), `"key_id":"key-2"`)
|
||||
assert.NotContains(t, string(data), `"error"`)
|
||||
}
|
||||
|
||||
func TestKeyStatusMarshalJSON_PreservesErrorFields(t *testing.T) {
|
||||
statusCode := 401
|
||||
bifrostErr := &BifrostError{
|
||||
IsBifrostError: true,
|
||||
StatusCode: &statusCode,
|
||||
Error: &ErrorField{Message: "unauthorized"},
|
||||
ExtraFields: BifrostErrorExtraFields{
|
||||
Provider: "openai",
|
||||
OriginalModelRequested: "gpt-4",
|
||||
},
|
||||
}
|
||||
keyStatus := KeyStatus{
|
||||
KeyID: "key-3",
|
||||
Status: KeyStatusListModelsFailed,
|
||||
Provider: "openai",
|
||||
Error: bifrostErr,
|
||||
}
|
||||
// Create cycle
|
||||
bifrostErr.ExtraFields.KeyStatuses = []KeyStatus{keyStatus}
|
||||
|
||||
data, err := Marshal(keyStatus)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Error fields other than key_statuses should be preserved
|
||||
dataStr := string(data)
|
||||
assert.Contains(t, dataStr, `"unauthorized"`)
|
||||
assert.Contains(t, dataStr, `"original_model_requested":"gpt-4"`)
|
||||
assert.Contains(t, dataStr, `"status_code":401`)
|
||||
}
|
||||
2378
core/schemas/mux.go
Normal file
2378
core/schemas/mux.go
Normal file
File diff suppressed because it is too large
Load Diff
702
core/schemas/mux_test.go
Normal file
702
core/schemas/mux_test.go
Normal file
@@ -0,0 +1,702 @@
|
||||
package schemas
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestToChatMessages_PreservesDeveloperRole(t *testing.T) {
|
||||
messages := []ResponsesMessage{
|
||||
{
|
||||
Role: Ptr(ResponsesInputMessageRoleDeveloper),
|
||||
Content: &ResponsesMessageContent{
|
||||
ContentBlocks: []ResponsesMessageContentBlock{
|
||||
{
|
||||
Type: ResponsesInputMessageContentBlockTypeText,
|
||||
Text: Ptr("You are helpful"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chatMessages := ToChatMessages(messages)
|
||||
if len(chatMessages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(chatMessages))
|
||||
}
|
||||
if chatMessages[0].Role != ChatMessageRoleDeveloper {
|
||||
t.Fatalf("expected role %q, got %q", ChatMessageRoleDeveloper, chatMessages[0].Role)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToChatRequest_NormalizesDeveloperRoleToSystemForFallback(t *testing.T) {
|
||||
req := &BifrostResponsesRequest{
|
||||
Input: []ResponsesMessage{
|
||||
{
|
||||
Role: Ptr(ResponsesInputMessageRoleDeveloper),
|
||||
Content: &ResponsesMessageContent{
|
||||
ContentBlocks: []ResponsesMessageContentBlock{
|
||||
{
|
||||
Type: ResponsesInputMessageContentBlockTypeText,
|
||||
Text: Ptr("You are helpful"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Params: &ResponsesParameters{},
|
||||
}
|
||||
|
||||
chatReq := req.ToChatRequest()
|
||||
if chatReq == nil {
|
||||
t.Fatal("expected non-nil chat request")
|
||||
}
|
||||
if len(chatReq.Input) != 1 {
|
||||
t.Fatalf("expected 1 chat message, got %d", len(chatReq.Input))
|
||||
}
|
||||
if chatReq.Input[0].Role != ChatMessageRoleSystem {
|
||||
t.Fatalf("expected role %q in fallback conversion, got %q", ChatMessageRoleSystem, chatReq.Input[0].Role)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToChatMessages_LeavesExistingSupportedRolesUnchanged(t *testing.T) {
|
||||
messages := []ResponsesMessage{
|
||||
{Role: Ptr(ResponsesInputMessageRoleSystem)},
|
||||
{Role: Ptr(ResponsesInputMessageRoleUser)},
|
||||
{Role: Ptr(ResponsesInputMessageRoleAssistant)},
|
||||
}
|
||||
|
||||
chatMessages := ToChatMessages(messages)
|
||||
if len(chatMessages) != len(messages) {
|
||||
t.Fatalf("expected %d messages, got %d", len(messages), len(chatMessages))
|
||||
}
|
||||
|
||||
if chatMessages[0].Role != ChatMessageRoleSystem {
|
||||
t.Fatalf("expected system role, got %q", chatMessages[0].Role)
|
||||
}
|
||||
if chatMessages[1].Role != ChatMessageRoleUser {
|
||||
t.Fatalf("expected user role, got %q", chatMessages[1].Role)
|
||||
}
|
||||
if chatMessages[2].Role != ChatMessageRoleAssistant {
|
||||
t.Fatalf("expected assistant role, got %q", chatMessages[2].Role)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToChatRequest_FiltersUnsupportedResponsesToolsForFallback(t *testing.T) {
|
||||
validName := "valid_tool"
|
||||
invalidName := " "
|
||||
req := &BifrostResponsesRequest{
|
||||
Params: &ResponsesParameters{
|
||||
Tools: []ResponsesTool{
|
||||
{
|
||||
Type: ResponsesToolTypeFunction,
|
||||
Name: &validName,
|
||||
ResponsesToolFunction: &ResponsesToolFunction{
|
||||
Parameters: &ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: &OrderedMap{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: ResponsesToolTypeFunction,
|
||||
Name: &invalidName,
|
||||
},
|
||||
{
|
||||
Type: ResponsesToolTypeMCP,
|
||||
Name: Ptr("mcp_tool"),
|
||||
},
|
||||
{
|
||||
Type: ResponsesToolTypeWebSearch,
|
||||
Name: Ptr("web_search"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chatReq := req.ToChatRequest()
|
||||
if chatReq == nil || chatReq.Params == nil {
|
||||
t.Fatal("expected non-nil chat request params")
|
||||
}
|
||||
if len(chatReq.Params.Tools) != 1 {
|
||||
t.Fatalf("expected 1 valid fallback tool, got %d", len(chatReq.Params.Tools))
|
||||
}
|
||||
if chatReq.Params.Tools[0].Type != ChatToolTypeFunction {
|
||||
t.Fatalf("expected tool type %q, got %q", ChatToolTypeFunction, chatReq.Params.Tools[0].Type)
|
||||
}
|
||||
if chatReq.Params.Tools[0].Function == nil || chatReq.Params.Tools[0].Function.Name != validName {
|
||||
t.Fatalf("expected function tool %q to be preserved", validName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToChatRequest_DropsInvalidToolChoiceForFallback(t *testing.T) {
|
||||
validName := "valid_tool"
|
||||
invalidChoiceName := "missing_tool"
|
||||
req := &BifrostResponsesRequest{
|
||||
Params: &ResponsesParameters{
|
||||
Tools: []ResponsesTool{
|
||||
{
|
||||
Type: ResponsesToolTypeFunction,
|
||||
Name: &validName,
|
||||
},
|
||||
},
|
||||
ToolChoice: &ResponsesToolChoice{
|
||||
ResponsesToolChoiceStruct: &ResponsesToolChoiceStruct{
|
||||
Type: ResponsesToolChoiceTypeFunction,
|
||||
Name: &invalidChoiceName,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chatReq := req.ToChatRequest()
|
||||
if chatReq == nil || chatReq.Params == nil {
|
||||
t.Fatal("expected non-nil chat request params")
|
||||
}
|
||||
if chatReq.Params.ToolChoice != nil {
|
||||
t.Fatal("expected incompatible tool choice to be removed for fallback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToChatRequest_AllNonFunctionToolsDropsToolsAndToolChoice(t *testing.T) {
|
||||
auto := string(ChatToolChoiceTypeAuto)
|
||||
req := &BifrostResponsesRequest{
|
||||
Params: &ResponsesParameters{
|
||||
Tools: []ResponsesTool{
|
||||
{Type: ResponsesToolTypeMCP, Name: Ptr("mcp")},
|
||||
{Type: ResponsesToolTypeWebSearch, Name: Ptr("search")},
|
||||
},
|
||||
ToolChoice: &ResponsesToolChoice{
|
||||
ResponsesToolChoiceStr: &auto,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chatReq := req.ToChatRequest()
|
||||
if chatReq == nil || chatReq.Params == nil {
|
||||
t.Fatal("expected non-nil chat request params")
|
||||
}
|
||||
if chatReq.Params.Tools != nil {
|
||||
t.Fatalf("expected nil tools when all tools are unsupported, got %d", len(chatReq.Params.Tools))
|
||||
}
|
||||
if chatReq.Params.ToolChoice != nil {
|
||||
t.Fatal("expected tool choice to be dropped when no valid tools remain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToChatRequest_DropsAllowedToolsAndCustomToolChoiceForFallback(t *testing.T) {
|
||||
validName := "valid_tool"
|
||||
tests := []ResponsesToolChoiceType{
|
||||
ResponsesToolChoiceTypeAllowedTools,
|
||||
ResponsesToolChoiceTypeCustom,
|
||||
}
|
||||
|
||||
for _, choiceType := range tests {
|
||||
t.Run(string(choiceType), func(t *testing.T) {
|
||||
req := &BifrostResponsesRequest{
|
||||
Params: &ResponsesParameters{
|
||||
Tools: []ResponsesTool{
|
||||
{
|
||||
Type: ResponsesToolTypeFunction,
|
||||
Name: &validName,
|
||||
},
|
||||
},
|
||||
ToolChoice: &ResponsesToolChoice{
|
||||
ResponsesToolChoiceStruct: &ResponsesToolChoiceStruct{
|
||||
Type: choiceType,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chatReq := req.ToChatRequest()
|
||||
if chatReq == nil || chatReq.Params == nil {
|
||||
t.Fatal("expected non-nil chat request params")
|
||||
}
|
||||
if chatReq.Params.ToolChoice != nil {
|
||||
t.Fatalf("expected %q tool choice to be dropped for fallback", choiceType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToChatRequest_PreservesStringToolChoiceAutoAndNone(t *testing.T) {
|
||||
validName := "valid_tool"
|
||||
tests := []string{
|
||||
string(ChatToolChoiceTypeAuto),
|
||||
string(ChatToolChoiceTypeNone),
|
||||
}
|
||||
|
||||
for _, choice := range tests {
|
||||
t.Run(choice, func(t *testing.T) {
|
||||
req := &BifrostResponsesRequest{
|
||||
Params: &ResponsesParameters{
|
||||
Tools: []ResponsesTool{
|
||||
{
|
||||
Type: ResponsesToolTypeFunction,
|
||||
Name: &validName,
|
||||
},
|
||||
},
|
||||
ToolChoice: &ResponsesToolChoice{
|
||||
ResponsesToolChoiceStr: &choice,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chatReq := req.ToChatRequest()
|
||||
if chatReq == nil || chatReq.Params == nil {
|
||||
t.Fatal("expected non-nil chat request params")
|
||||
}
|
||||
if chatReq.Params.ToolChoice == nil || chatReq.Params.ToolChoice.ChatToolChoiceStr == nil {
|
||||
t.Fatal("expected string tool choice to be preserved")
|
||||
}
|
||||
if *chatReq.Params.ToolChoice.ChatToolChoiceStr != choice {
|
||||
t.Fatalf("expected tool choice %q, got %q", choice, *chatReq.Params.ToolChoice.ChatToolChoiceStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBifrostResponsesStreamResponse_PopulatesFinalDoneTextAndCompletedOutput(t *testing.T) {
|
||||
state := AcquireChatToResponsesStreamState()
|
||||
defer ReleaseChatToResponsesStreamState(state)
|
||||
|
||||
makeChunk := func(role *string, content *string, finishReason *string) *BifrostChatResponse {
|
||||
return &BifrostChatResponse{
|
||||
ID: "chatcmpl-test",
|
||||
Model: "test-model",
|
||||
Choices: []BifrostResponseChoice{
|
||||
{
|
||||
FinishReason: finishReason,
|
||||
ChatStreamResponseChoice: &ChatStreamResponseChoice{
|
||||
Delta: &ChatStreamResponseChoiceDelta{
|
||||
Role: role,
|
||||
Content: content,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
role := string(ChatMessageRoleAssistant)
|
||||
part1 := "Hello"
|
||||
part2 := " world"
|
||||
stop := string(BifrostFinishReasonStop)
|
||||
|
||||
var all []*BifrostResponsesStreamResponse
|
||||
all = append(all, makeChunk(&role, nil, nil).ToBifrostResponsesStreamResponse(state)...)
|
||||
all = append(all, makeChunk(nil, &part1, nil).ToBifrostResponsesStreamResponse(state)...)
|
||||
all = append(all, makeChunk(nil, &part2, nil).ToBifrostResponsesStreamResponse(state)...)
|
||||
all = append(all, makeChunk(nil, nil, &stop).ToBifrostResponsesStreamResponse(state)...)
|
||||
|
||||
var outputTextDone *BifrostResponsesStreamResponse
|
||||
var completed *BifrostResponsesStreamResponse
|
||||
for _, evt := range all {
|
||||
if evt == nil {
|
||||
continue
|
||||
}
|
||||
if evt.Type == ResponsesStreamResponseTypeOutputTextDone {
|
||||
outputTextDone = evt
|
||||
}
|
||||
if evt.Type == ResponsesStreamResponseTypeCompleted {
|
||||
completed = evt
|
||||
}
|
||||
}
|
||||
|
||||
if outputTextDone == nil || outputTextDone.Text == nil {
|
||||
t.Fatal("expected response.output_text.done with text")
|
||||
}
|
||||
if *outputTextDone.Text != "Hello world" {
|
||||
t.Fatalf("expected output_text.done text %q, got %q", "Hello world", *outputTextDone.Text)
|
||||
}
|
||||
|
||||
if completed == nil || completed.Response == nil || len(completed.Response.Output) != 1 {
|
||||
t.Fatal("expected response.completed with one output message")
|
||||
}
|
||||
msg := completed.Response.Output[0]
|
||||
if msg.Content == nil || len(msg.Content.ContentBlocks) == 0 || msg.Content.ContentBlocks[0].Text == nil {
|
||||
t.Fatal("expected completed output message to include text content block")
|
||||
}
|
||||
if *msg.Content.ContentBlocks[0].Text != "Hello world" {
|
||||
t.Fatalf("expected completed output text %q, got %q", "Hello world", *msg.Content.ContentBlocks[0].Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBifrostResponsesResponse_MapsLengthToIncomplete(t *testing.T) {
|
||||
length := string(BifrostFinishReasonLength)
|
||||
resp := (&BifrostChatResponse{
|
||||
Choices: []BifrostResponseChoice{
|
||||
{FinishReason: &length},
|
||||
},
|
||||
}).ToBifrostResponsesResponse()
|
||||
|
||||
if resp == nil || resp.Status == nil {
|
||||
t.Fatal("expected status to be set")
|
||||
}
|
||||
if *resp.Status != "incomplete" {
|
||||
t.Fatalf("expected status %q, got %q", "incomplete", *resp.Status)
|
||||
}
|
||||
if resp.IncompleteDetails == nil {
|
||||
t.Fatal("expected incomplete_details to be set")
|
||||
}
|
||||
if resp.IncompleteDetails.Reason != "max_output_tokens" {
|
||||
t.Fatalf("expected incomplete_details.reason %q, got %q", "max_output_tokens", resp.IncompleteDetails.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBifrostResponsesResponse_MapsToolCallsToCompleted(t *testing.T) {
|
||||
toolCalls := string(BifrostFinishReasonToolCalls)
|
||||
resp := (&BifrostChatResponse{
|
||||
Choices: []BifrostResponseChoice{
|
||||
{FinishReason: &toolCalls},
|
||||
},
|
||||
}).ToBifrostResponsesResponse()
|
||||
|
||||
if resp == nil || resp.Status == nil {
|
||||
t.Fatal("expected status to be set")
|
||||
}
|
||||
if *resp.Status != "completed" {
|
||||
t.Fatalf("expected status %q, got %q", "completed", *resp.Status)
|
||||
}
|
||||
if resp.IncompleteDetails != nil {
|
||||
t.Fatal("expected incomplete_details to be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBifrostResponsesResponse_PrioritizesLengthAcrossChoices(t *testing.T) {
|
||||
stop := string(BifrostFinishReasonStop)
|
||||
length := string(BifrostFinishReasonLength)
|
||||
resp := (&BifrostChatResponse{
|
||||
Choices: []BifrostResponseChoice{
|
||||
{FinishReason: &stop},
|
||||
{FinishReason: &length},
|
||||
},
|
||||
}).ToBifrostResponsesResponse()
|
||||
|
||||
if resp == nil || resp.Status == nil {
|
||||
t.Fatal("expected status to be set")
|
||||
}
|
||||
if *resp.Status != "incomplete" {
|
||||
t.Fatalf("expected status %q, got %q", "incomplete", *resp.Status)
|
||||
}
|
||||
if resp.IncompleteDetails == nil || resp.IncompleteDetails.Reason != "max_output_tokens" {
|
||||
t.Fatal("expected max_output_tokens incomplete_details")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBifrostResponsesResponse_UnknownFinishReasonLeavesStatusUnset(t *testing.T) {
|
||||
unknown := "content_filter"
|
||||
resp := (&BifrostChatResponse{
|
||||
Choices: []BifrostResponseChoice{
|
||||
{FinishReason: &unknown},
|
||||
},
|
||||
}).ToBifrostResponsesResponse()
|
||||
|
||||
if resp == nil {
|
||||
t.Fatal("expected non-nil response")
|
||||
}
|
||||
if resp.Status != nil {
|
||||
t.Fatalf("expected status to be nil, got %q", *resp.Status)
|
||||
}
|
||||
if resp.IncompleteDetails != nil {
|
||||
t.Fatal("expected incomplete_details to be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBifrostResponsesStreamResponse_IncludesFunctionCallsInCompletedOutput(t *testing.T) {
|
||||
state := AcquireChatToResponsesStreamState()
|
||||
defer ReleaseChatToResponsesStreamState(state)
|
||||
|
||||
role := string(ChatMessageRoleAssistant)
|
||||
part1 := "Let me help"
|
||||
toolCallsFinish := string(BifrostFinishReasonToolCalls)
|
||||
funcName := "get_weather"
|
||||
toolCallID := "call_abc123"
|
||||
|
||||
var all []*BifrostResponsesStreamResponse
|
||||
|
||||
// Role chunk
|
||||
all = append(all, (&BifrostChatResponse{
|
||||
ID: "chatcmpl-test",
|
||||
Model: "test-model",
|
||||
Choices: []BifrostResponseChoice{
|
||||
{
|
||||
ChatStreamResponseChoice: &ChatStreamResponseChoice{
|
||||
Delta: &ChatStreamResponseChoiceDelta{
|
||||
Role: &role,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}).ToBifrostResponsesStreamResponse(state)...)
|
||||
|
||||
// Text content chunk
|
||||
all = append(all, (&BifrostChatResponse{
|
||||
ID: "chatcmpl-test",
|
||||
Model: "test-model",
|
||||
Choices: []BifrostResponseChoice{
|
||||
{
|
||||
ChatStreamResponseChoice: &ChatStreamResponseChoice{
|
||||
Delta: &ChatStreamResponseChoiceDelta{
|
||||
Content: &part1,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}).ToBifrostResponsesStreamResponse(state)...)
|
||||
|
||||
// Tool call chunk with function name
|
||||
all = append(all, (&BifrostChatResponse{
|
||||
ID: "chatcmpl-test",
|
||||
Model: "test-model",
|
||||
Choices: []BifrostResponseChoice{
|
||||
{
|
||||
ChatStreamResponseChoice: &ChatStreamResponseChoice{
|
||||
Delta: &ChatStreamResponseChoiceDelta{
|
||||
ToolCalls: []ChatAssistantMessageToolCall{
|
||||
{
|
||||
Index: 0,
|
||||
ID: &toolCallID,
|
||||
Function: ChatAssistantMessageToolCallFunction{Name: &funcName, Arguments: `{"city":`},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}).ToBifrostResponsesStreamResponse(state)...)
|
||||
|
||||
// Tool call argument continuation
|
||||
all = append(all, (&BifrostChatResponse{
|
||||
ID: "chatcmpl-test",
|
||||
Model: "test-model",
|
||||
Choices: []BifrostResponseChoice{
|
||||
{
|
||||
ChatStreamResponseChoice: &ChatStreamResponseChoice{
|
||||
Delta: &ChatStreamResponseChoiceDelta{
|
||||
ToolCalls: []ChatAssistantMessageToolCall{
|
||||
{
|
||||
Index: 0,
|
||||
Function: ChatAssistantMessageToolCallFunction{Arguments: `"Paris"}`},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}).ToBifrostResponsesStreamResponse(state)...)
|
||||
|
||||
// Finish with tool_calls
|
||||
all = append(all, (&BifrostChatResponse{
|
||||
ID: "chatcmpl-test",
|
||||
Model: "test-model",
|
||||
Choices: []BifrostResponseChoice{
|
||||
{
|
||||
FinishReason: &toolCallsFinish,
|
||||
ChatStreamResponseChoice: &ChatStreamResponseChoice{
|
||||
Delta: &ChatStreamResponseChoiceDelta{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}).ToBifrostResponsesStreamResponse(state)...)
|
||||
|
||||
var completed *BifrostResponsesStreamResponse
|
||||
for _, evt := range all {
|
||||
if evt != nil && evt.Type == ResponsesStreamResponseTypeCompleted {
|
||||
completed = evt
|
||||
}
|
||||
}
|
||||
|
||||
if completed == nil || completed.Response == nil {
|
||||
t.Fatal("expected response.completed event")
|
||||
}
|
||||
|
||||
output := completed.Response.Output
|
||||
if len(output) < 2 {
|
||||
t.Fatalf("expected at least 2 output items (text + function_call), got %d", len(output))
|
||||
}
|
||||
|
||||
var hasText, hasFunctionCall bool
|
||||
for _, item := range output {
|
||||
if item.Type != nil && *item.Type == ResponsesMessageTypeMessage {
|
||||
hasText = true
|
||||
if item.Content == nil || len(item.Content.ContentBlocks) == 0 || item.Content.ContentBlocks[0].Text == nil {
|
||||
t.Fatal("text message missing content")
|
||||
}
|
||||
if *item.Content.ContentBlocks[0].Text != "Let me help" {
|
||||
t.Fatalf("expected text %q, got %q", "Let me help", *item.Content.ContentBlocks[0].Text)
|
||||
}
|
||||
}
|
||||
if item.Type != nil && *item.Type == ResponsesMessageTypeFunctionCall {
|
||||
hasFunctionCall = true
|
||||
if item.ResponsesToolMessage == nil {
|
||||
t.Fatal("function_call item missing ResponsesToolMessage")
|
||||
}
|
||||
if item.Name == nil || *item.Name != "get_weather" {
|
||||
t.Fatalf("expected function name %q, got %v", "get_weather", item.Name)
|
||||
}
|
||||
if item.Arguments == nil || *item.Arguments != `{"city":"Paris"}` {
|
||||
t.Fatalf("expected arguments %q, got %v", `{"city":"Paris"}`, item.Arguments)
|
||||
}
|
||||
if item.CallID == nil || *item.CallID != toolCallID {
|
||||
t.Fatalf("expected call_id %q, got %v", toolCallID, item.CallID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasText {
|
||||
t.Fatal("expected text message in completed output")
|
||||
}
|
||||
if !hasFunctionCall {
|
||||
t.Fatal("expected function_call item in completed output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBifrostResponsesStreamResponse_ToolCallsOnlyInCompletedOutput(t *testing.T) {
|
||||
state := AcquireChatToResponsesStreamState()
|
||||
defer ReleaseChatToResponsesStreamState(state)
|
||||
|
||||
role := string(ChatMessageRoleAssistant)
|
||||
toolCallsFinish := string(BifrostFinishReasonToolCalls)
|
||||
funcName := "get_weather"
|
||||
toolCallID := "call_xyz789"
|
||||
|
||||
var all []*BifrostResponsesStreamResponse
|
||||
|
||||
// Role chunk
|
||||
all = append(all, (&BifrostChatResponse{
|
||||
ID: "chatcmpl-test",
|
||||
Model: "test-model",
|
||||
Choices: []BifrostResponseChoice{
|
||||
{
|
||||
ChatStreamResponseChoice: &ChatStreamResponseChoice{
|
||||
Delta: &ChatStreamResponseChoiceDelta{
|
||||
Role: &role,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}).ToBifrostResponsesStreamResponse(state)...)
|
||||
|
||||
// Tool call chunk (no text content at all)
|
||||
all = append(all, (&BifrostChatResponse{
|
||||
ID: "chatcmpl-test",
|
||||
Model: "test-model",
|
||||
Choices: []BifrostResponseChoice{
|
||||
{
|
||||
ChatStreamResponseChoice: &ChatStreamResponseChoice{
|
||||
Delta: &ChatStreamResponseChoiceDelta{
|
||||
ToolCalls: []ChatAssistantMessageToolCall{
|
||||
{
|
||||
Index: 0,
|
||||
ID: &toolCallID,
|
||||
Function: ChatAssistantMessageToolCallFunction{Name: &funcName, Arguments: `{"q":"test"}`},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}).ToBifrostResponsesStreamResponse(state)...)
|
||||
|
||||
// Finish with tool_calls
|
||||
all = append(all, (&BifrostChatResponse{
|
||||
ID: "chatcmpl-test",
|
||||
Model: "test-model",
|
||||
Choices: []BifrostResponseChoice{
|
||||
{
|
||||
FinishReason: &toolCallsFinish,
|
||||
ChatStreamResponseChoice: &ChatStreamResponseChoice{
|
||||
Delta: &ChatStreamResponseChoiceDelta{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}).ToBifrostResponsesStreamResponse(state)...)
|
||||
|
||||
var completed *BifrostResponsesStreamResponse
|
||||
for _, evt := range all {
|
||||
if evt != nil && evt.Type == ResponsesStreamResponseTypeCompleted {
|
||||
completed = evt
|
||||
}
|
||||
}
|
||||
|
||||
if completed == nil || completed.Response == nil {
|
||||
t.Fatal("expected response.completed event")
|
||||
}
|
||||
|
||||
output := completed.Response.Output
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected 1 output item (function_call only), got %d", len(output))
|
||||
}
|
||||
|
||||
item := output[0]
|
||||
if item.Type == nil || *item.Type != ResponsesMessageTypeFunctionCall {
|
||||
t.Fatal("expected function_call type")
|
||||
}
|
||||
if item.ResponsesToolMessage == nil {
|
||||
t.Fatal("function_call item missing ResponsesToolMessage")
|
||||
}
|
||||
if item.Name == nil || *item.Name != "get_weather" {
|
||||
t.Fatalf("expected function name %q, got %v", "get_weather", item.Name)
|
||||
}
|
||||
if item.Arguments == nil || *item.Arguments != `{"q":"test"}` {
|
||||
t.Fatalf("expected arguments %q, got %v", `{"q":"test"}`, item.Arguments)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBifrostResponsesStreamResponse_MapsLengthToIncompleteEvent(t *testing.T) {
|
||||
state := AcquireChatToResponsesStreamState()
|
||||
defer ReleaseChatToResponsesStreamState(state)
|
||||
|
||||
makeChunk := func(role *string, content *string, finishReason *string) *BifrostChatResponse {
|
||||
return &BifrostChatResponse{
|
||||
ID: "chatcmpl-test",
|
||||
Model: "test-model",
|
||||
Choices: []BifrostResponseChoice{
|
||||
{
|
||||
FinishReason: finishReason,
|
||||
ChatStreamResponseChoice: &ChatStreamResponseChoice{
|
||||
Delta: &ChatStreamResponseChoiceDelta{
|
||||
Role: role,
|
||||
Content: content,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
role := string(ChatMessageRoleAssistant)
|
||||
part := "Hello"
|
||||
length := string(BifrostFinishReasonLength)
|
||||
|
||||
var all []*BifrostResponsesStreamResponse
|
||||
all = append(all, makeChunk(&role, nil, nil).ToBifrostResponsesStreamResponse(state)...)
|
||||
all = append(all, makeChunk(nil, &part, nil).ToBifrostResponsesStreamResponse(state)...)
|
||||
all = append(all, makeChunk(nil, nil, &length).ToBifrostResponsesStreamResponse(state)...)
|
||||
|
||||
var completed *BifrostResponsesStreamResponse
|
||||
var incomplete *BifrostResponsesStreamResponse
|
||||
for _, evt := range all {
|
||||
if evt == nil {
|
||||
continue
|
||||
}
|
||||
if evt.Type == ResponsesStreamResponseTypeCompleted {
|
||||
completed = evt
|
||||
}
|
||||
if evt.Type == ResponsesStreamResponseTypeIncomplete {
|
||||
incomplete = evt
|
||||
}
|
||||
}
|
||||
|
||||
if completed != nil {
|
||||
t.Fatal("did not expect response.completed for finish_reason=length")
|
||||
}
|
||||
if incomplete == nil || incomplete.Response == nil {
|
||||
t.Fatal("expected response.incomplete with response payload")
|
||||
}
|
||||
if incomplete.Response.Status == nil || *incomplete.Response.Status != "incomplete" {
|
||||
t.Fatal("expected terminal response status to be incomplete")
|
||||
}
|
||||
if incomplete.Response.IncompleteDetails == nil || incomplete.Response.IncompleteDetails.Reason != "max_output_tokens" {
|
||||
t.Fatal("expected incomplete_details.reason to be max_output_tokens")
|
||||
}
|
||||
}
|
||||
99
core/schemas/oauth.go
Normal file
99
core/schemas/oauth.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OauthProvider interface defines OAuth operations
|
||||
type OAuth2Provider interface {
|
||||
// GetAccessToken retrieves the access token for a given oauth_config_id (server-level OAuth)
|
||||
GetAccessToken(ctx context.Context, oauthConfigID string) (string, error)
|
||||
|
||||
// RefreshAccessToken refreshes the access token for a given oauth_config_id
|
||||
RefreshAccessToken(ctx context.Context, oauthConfigID string) error
|
||||
|
||||
// ValidateToken checks if the token is still valid
|
||||
ValidateToken(ctx context.Context, oauthConfigID string) (bool, error)
|
||||
|
||||
// RevokeToken revokes the OAuth token
|
||||
RevokeToken(ctx context.Context, oauthConfigID string) error
|
||||
|
||||
// Per-user OAuth methods
|
||||
|
||||
// GetUserAccessToken retrieves the access token for a per-user OAuth session.
|
||||
// If the token is expired, it automatically attempts a refresh.
|
||||
GetUserAccessToken(ctx context.Context, sessionToken string) (string, error)
|
||||
|
||||
// GetUserAccessTokenByIdentity retrieves the upstream access token for a user
|
||||
// identified by virtualKeyID, userID, or sessionToken (fallback), for a specific
|
||||
// MCP client. Tokens looked up by identity persist across sessions.
|
||||
GetUserAccessTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (string, error)
|
||||
|
||||
// InitiateUserOAuthFlow creates a per-user OAuth session and returns the authorization URL.
|
||||
// Returns (flow initiation details, session ID for polling, error).
|
||||
InitiateUserOAuthFlow(ctx context.Context, oauthConfigID string, mcpClientID string, redirectURI string) (*OAuth2FlowInitiation, string, error)
|
||||
|
||||
// CompleteUserOAuthFlow handles the OAuth callback for a per-user flow.
|
||||
// Returns the session token that the user should send on subsequent requests.
|
||||
CompleteUserOAuthFlow(ctx context.Context, state string, code string) (string, error)
|
||||
|
||||
// RefreshUserAccessToken refreshes a per-user OAuth access token.
|
||||
RefreshUserAccessToken(ctx context.Context, sessionToken string) error
|
||||
|
||||
// RevokeUserToken revokes a per-user OAuth token and marks the session as revoked.
|
||||
RevokeUserToken(ctx context.Context, sessionToken string) error
|
||||
}
|
||||
|
||||
// OauthConfig represents OAuth client configuration
|
||||
type OAuth2Config struct {
|
||||
ID string `json:"id"`
|
||||
ClientID string `json:"client_id,omitempty"` // Optional: Will be obtained via dynamic registration (RFC 7591) if not provided
|
||||
ClientSecret string `json:"client_secret,omitempty"` // Optional: For public clients using PKCE, or obtained via dynamic registration
|
||||
AuthorizeURL string `json:"authorize_url,omitempty"` // Optional: Will be discovered from ServerURL if not provided
|
||||
TokenURL string `json:"token_url,omitempty"` // Optional: Will be discovered from ServerURL if not provided
|
||||
RegistrationURL *string `json:"registration_url,omitempty"` // Optional: For dynamic client registration (RFC 7591), can be discovered
|
||||
RedirectURI string `json:"redirect_uri"` // Required
|
||||
Scopes []string `json:"scopes,omitempty"` // Optional: Can be discovered
|
||||
ServerURL string `json:"server_url"` // MCP server URL for OAuth discovery (required if URLs not provided)
|
||||
UseDiscovery bool `json:"use_discovery,omitempty"` // Deprecated: Discovery now happens automatically when URLs are missing
|
||||
}
|
||||
|
||||
// OauthToken represents OAuth access and refresh tokens
|
||||
type OAuth2Token struct {
|
||||
ID string `json:"id"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
Scopes []string `json:"scopes"`
|
||||
LastRefreshedAt *time.Time `json:"last_refreshed_at,omitempty"`
|
||||
}
|
||||
|
||||
// OauthFlowInitiation represents the response when initiating an OAuth flow
|
||||
type OAuth2FlowInitiation struct {
|
||||
OauthConfigID string `json:"oauth_config_id"`
|
||||
AuthorizeURL string `json:"authorize_url"`
|
||||
State string `json:"state"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
// OAuth2TokenExchangeRequest represents the OAuth token exchange request
|
||||
type OAuth2TokenExchangeRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
Code string `json:"code,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri,omitempty"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
CodeVerifier string `json:"code_verifier,omitempty"` // PKCE verifier for authorization_code grant
|
||||
}
|
||||
|
||||
// OAuth2TokenExchangeResponse represents the OAuth token exchange response
|
||||
type OAuth2TokenExchangeResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
89
core/schemas/ocr.go
Normal file
89
core/schemas/ocr.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package schemas
|
||||
|
||||
// OCRDocumentType specifies the type of document input for an OCR request.
|
||||
type OCRDocumentType string
|
||||
|
||||
const (
|
||||
// OCRDocumentTypeDocumentURL represents a document URL input (e.g., PDF URL or base64 data URL).
|
||||
OCRDocumentTypeDocumentURL OCRDocumentType = "document_url"
|
||||
// OCRDocumentTypeImageURL represents an image URL input.
|
||||
OCRDocumentTypeImageURL OCRDocumentType = "image_url"
|
||||
)
|
||||
|
||||
// OCRDocument represents the document input for an OCR request.
|
||||
type OCRDocument struct {
|
||||
Type OCRDocumentType `json:"type"`
|
||||
DocumentURL *string `json:"document_url,omitempty"`
|
||||
ImageURL *string `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
// OCRParameters contains optional parameters for an OCR request.
|
||||
type OCRParameters struct {
|
||||
IncludeImageBase64 *bool `json:"include_image_base64,omitempty"`
|
||||
Pages []int `json:"pages,omitempty"`
|
||||
ImageLimit *int `json:"image_limit,omitempty"`
|
||||
ImageMinSize *int `json:"image_min_size,omitempty"`
|
||||
TableFormat *string `json:"table_format,omitempty"`
|
||||
ExtractHeader *bool `json:"extract_header,omitempty"`
|
||||
ExtractFooter *bool `json:"extract_footer,omitempty"`
|
||||
BBoxAnnotationFormat *string `json:"bbox_annotation_format,omitempty"`
|
||||
DocumentAnnotationFormat *string `json:"document_annotation_format,omitempty"`
|
||||
DocumentAnnotationPrompt *string `json:"document_annotation_prompt,omitempty"`
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// BifrostOCRRequest represents a request to perform OCR on a document.
|
||||
type BifrostOCRRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
ID *string `json:"id,omitempty"`
|
||||
Document OCRDocument `json:"document"`
|
||||
Params *OCRParameters `json:"params,omitempty"`
|
||||
Fallbacks []Fallback `json:"fallbacks,omitempty"`
|
||||
RawRequestBody []byte `json:"-"`
|
||||
}
|
||||
|
||||
// GetRawRequestBody returns the raw request body for the OCR request.
|
||||
func (r *BifrostOCRRequest) GetRawRequestBody() []byte {
|
||||
return r.RawRequestBody
|
||||
}
|
||||
|
||||
// OCRPageImage represents an extracted image from an OCR page.
|
||||
type OCRPageImage struct {
|
||||
ID string `json:"id"`
|
||||
TopLeftX float64 `json:"top_left_x"`
|
||||
TopLeftY float64 `json:"top_left_y"`
|
||||
BottomRightX float64 `json:"bottom_right_x"`
|
||||
BottomRightY float64 `json:"bottom_right_y"`
|
||||
ImageBase64 *string `json:"image_base64,omitempty"`
|
||||
}
|
||||
|
||||
// OCRPageDimensions represents the dimensions of an OCR page.
|
||||
type OCRPageDimensions struct {
|
||||
DPI int `json:"dpi"`
|
||||
Height int `json:"height"`
|
||||
Width int `json:"width"`
|
||||
}
|
||||
|
||||
// OCRPage represents a single processed page from an OCR response.
|
||||
type OCRPage struct {
|
||||
Index int `json:"index"`
|
||||
Markdown string `json:"markdown"`
|
||||
Images []OCRPageImage `json:"images,omitempty"`
|
||||
Dimensions *OCRPageDimensions `json:"dimensions,omitempty"`
|
||||
}
|
||||
|
||||
// OCRUsageInfo represents usage information from an OCR response.
|
||||
type OCRUsageInfo struct {
|
||||
PagesProcessed int `json:"pages_processed"`
|
||||
DocSizeBytes int `json:"doc_size_bytes"`
|
||||
}
|
||||
|
||||
// BifrostOCRResponse represents the response from an OCR request.
|
||||
type BifrostOCRResponse struct {
|
||||
Model string `json:"model"`
|
||||
Pages []OCRPage `json:"pages"`
|
||||
UsageInfo *OCRUsageInfo `json:"usage_info,omitempty"`
|
||||
DocumentAnnotation *string `json:"document_annotation,omitempty"`
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
634
core/schemas/orderedmap.go
Normal file
634
core/schemas/orderedmap.go
Normal file
@@ -0,0 +1,634 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// OrderedMap is a map that preserves insertion order of keys.
|
||||
// It stores key-value pairs and maintains the order in which keys were first inserted.
|
||||
// It is NOT safe for concurrent use.
|
||||
type OrderedMap struct {
|
||||
keys []string
|
||||
values map[string]interface{}
|
||||
}
|
||||
|
||||
// Pair is a key-value pair for constructing OrderedMaps with order preserved.
|
||||
type Pair struct {
|
||||
Key string
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
// KV is a shorthand constructor for Pair.
|
||||
func KV(key string, value interface{}) Pair {
|
||||
return Pair{Key: key, Value: value}
|
||||
}
|
||||
|
||||
// NewOrderedMap creates a new empty OrderedMap.
|
||||
func NewOrderedMap() *OrderedMap {
|
||||
return &OrderedMap{
|
||||
values: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
// NewOrderedMapWithCapacity creates a new empty OrderedMap with preallocated capacity.
|
||||
func NewOrderedMapWithCapacity(cap int) *OrderedMap {
|
||||
return &OrderedMap{
|
||||
keys: make([]string, 0, cap),
|
||||
values: make(map[string]interface{}, cap),
|
||||
}
|
||||
}
|
||||
|
||||
// NewOrderedMapFromPairs creates an OrderedMap from key-value pairs, preserving the given order.
|
||||
func NewOrderedMapFromPairs(pairs ...Pair) *OrderedMap {
|
||||
om := &OrderedMap{
|
||||
keys: make([]string, 0, len(pairs)),
|
||||
values: make(map[string]interface{}, len(pairs)),
|
||||
}
|
||||
for _, p := range pairs {
|
||||
om.Set(p.Key, p.Value)
|
||||
}
|
||||
return om
|
||||
}
|
||||
|
||||
// OrderedMapFromMap creates an OrderedMap from a plain map.
|
||||
// Key order is NOT guaranteed since Go maps have undefined iteration order.
|
||||
// Use this only when insertion order doesn't matter (e.g., for hashing).
|
||||
func OrderedMapFromMap(m map[string]interface{}) *OrderedMap {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
om := &OrderedMap{
|
||||
keys: make([]string, 0, len(m)),
|
||||
values: make(map[string]interface{}, len(m)),
|
||||
}
|
||||
for k, v := range m {
|
||||
om.keys = append(om.keys, k)
|
||||
om.values[k] = v
|
||||
}
|
||||
return om
|
||||
}
|
||||
|
||||
// Get returns the value associated with the key and whether the key exists.
|
||||
func (om *OrderedMap) Get(key string) (interface{}, bool) {
|
||||
if om == nil {
|
||||
return nil, false
|
||||
}
|
||||
v, ok := om.values[key]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// Set sets the value for a key. If the key is new, it is appended to the end.
|
||||
// If the key already exists, its value is updated in place without changing order.
|
||||
func (om *OrderedMap) Set(key string, value interface{}) {
|
||||
if om.values == nil {
|
||||
om.values = make(map[string]interface{})
|
||||
}
|
||||
if _, exists := om.values[key]; !exists {
|
||||
om.keys = append(om.keys, key)
|
||||
}
|
||||
om.values[key] = value
|
||||
}
|
||||
|
||||
// Delete removes a key and its value. The key is also removed from the ordered keys list.
|
||||
func (om *OrderedMap) Delete(key string) {
|
||||
if om == nil {
|
||||
return
|
||||
}
|
||||
if _, exists := om.values[key]; !exists {
|
||||
return
|
||||
}
|
||||
delete(om.values, key)
|
||||
for i, k := range om.keys {
|
||||
if k == key {
|
||||
om.keys = append(om.keys[:i], om.keys[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Len returns the number of key-value pairs.
|
||||
func (om *OrderedMap) Len() int {
|
||||
if om == nil {
|
||||
return 0
|
||||
}
|
||||
return len(om.keys)
|
||||
}
|
||||
|
||||
// Keys returns the keys in insertion order. The returned slice is a copy.
|
||||
func (om *OrderedMap) Keys() []string {
|
||||
if om == nil {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, len(om.keys))
|
||||
copy(out, om.keys)
|
||||
return out
|
||||
}
|
||||
|
||||
// Range iterates over key-value pairs in insertion order.
|
||||
// If fn returns false, iteration stops.
|
||||
func (om *OrderedMap) Range(fn func(key string, value interface{}) bool) {
|
||||
if om == nil {
|
||||
return
|
||||
}
|
||||
for _, k := range om.keys {
|
||||
if !fn(k, om.values[k]) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clone creates a shallow copy of the OrderedMap (keys and top-level values are copied,
|
||||
// but nested values share references).
|
||||
func (om *OrderedMap) Clone() *OrderedMap {
|
||||
if om == nil {
|
||||
return nil
|
||||
}
|
||||
clone := &OrderedMap{
|
||||
keys: make([]string, len(om.keys)),
|
||||
values: make(map[string]interface{}, len(om.values)),
|
||||
}
|
||||
copy(clone.keys, om.keys)
|
||||
for k, v := range om.values {
|
||||
clone.values[k] = v
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
// ToMap returns a plain map[string]interface{} with the same key-value pairs.
|
||||
// The returned map does not preserve insertion order.
|
||||
func (om *OrderedMap) ToMap() map[string]interface{} {
|
||||
if om == nil {
|
||||
return nil
|
||||
}
|
||||
m := make(map[string]interface{}, len(om.values))
|
||||
for k, v := range om.values {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// MarshalJSON serializes the OrderedMap to JSON, preserving insertion order of keys.
|
||||
// Uses a value receiver so that both OrderedMap and *OrderedMap invoke this method
|
||||
// (critical for []OrderedMap slices like AnyOf/OneOf/AllOf in ToolFunctionParameters).
|
||||
func (om OrderedMap) MarshalJSON() ([]byte, error) {
|
||||
if om.values == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.WriteByte('{')
|
||||
|
||||
for i, k := range om.keys {
|
||||
if i > 0 {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
|
||||
// key
|
||||
keyBytes, err := MarshalSorted(k)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf.Write(keyBytes)
|
||||
buf.WriteByte(':')
|
||||
|
||||
// value — nested *OrderedMap values will use their own MarshalJSON
|
||||
valBytes, err := MarshalSorted(om.values[k])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf.Write(valBytes)
|
||||
}
|
||||
|
||||
buf.WriteByte('}')
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// MarshalSorted serializes the OrderedMap to JSON with keys sorted alphabetically.
|
||||
// Use this when deterministic output is needed regardless of insertion order (e.g., hashing).
|
||||
func (om *OrderedMap) MarshalSorted() ([]byte, error) {
|
||||
if om == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
keys := make([]string, len(om.keys))
|
||||
copy(keys, om.keys)
|
||||
sort.Strings(keys)
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.WriteByte('{')
|
||||
|
||||
for i, k := range keys {
|
||||
if i > 0 {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
|
||||
keyBytes, err := MarshalSorted(k)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf.Write(keyBytes)
|
||||
buf.WriteByte(':')
|
||||
|
||||
valBytes, err := MarshalSorted(om.values[k])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf.Write(valBytes)
|
||||
}
|
||||
|
||||
buf.WriteByte('}')
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON deserializes JSON into the OrderedMap, preserving the key order
|
||||
// from the JSON document. Nested objects are also deserialized as *OrderedMap.
|
||||
// Note: uses encoding/json.Decoder (not sonic) because token-by-token decoding
|
||||
// is required to preserve key order from the JSON document.
|
||||
func (om *OrderedMap) UnmarshalJSON(data []byte) error {
|
||||
// Handle null
|
||||
trimmed := bytes.TrimSpace(data)
|
||||
if bytes.Equal(trimmed, []byte("null")) {
|
||||
om.keys = nil
|
||||
om.values = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
dec := json.NewDecoder(bytes.NewReader(data))
|
||||
|
||||
// Read opening brace
|
||||
t, err := dec.Token()
|
||||
if err != nil {
|
||||
return fmt.Errorf("orderedmap: expected '{': %w", err)
|
||||
}
|
||||
delim, ok := t.(json.Delim)
|
||||
if !ok || delim != '{' {
|
||||
return fmt.Errorf("orderedmap: expected '{', got %v", t)
|
||||
}
|
||||
|
||||
om.keys = om.keys[:0]
|
||||
if om.values == nil {
|
||||
om.values = make(map[string]interface{})
|
||||
} else {
|
||||
for k := range om.values {
|
||||
delete(om.values, k)
|
||||
}
|
||||
}
|
||||
|
||||
for dec.More() {
|
||||
// Read key
|
||||
keyToken, err := dec.Token()
|
||||
if err != nil {
|
||||
return fmt.Errorf("orderedmap: reading key: %w", err)
|
||||
}
|
||||
key, ok := keyToken.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("orderedmap: expected string key, got %T", keyToken)
|
||||
}
|
||||
|
||||
// Read value, preserving nested object order
|
||||
value, err := decodeOrderedValue(dec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("orderedmap: reading value for key %q: %w", key, err)
|
||||
}
|
||||
|
||||
om.Set(key, value)
|
||||
}
|
||||
|
||||
// Read closing brace
|
||||
if _, err := dec.Token(); err != nil {
|
||||
return fmt.Errorf("orderedmap: expected '}': %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// jsonSchemaPriority maps JSON Schema keywords to their preferred
|
||||
// serialization position. Keys present in this map are emitted first
|
||||
// (in the given order), followed by all remaining keys alphabetically.
|
||||
// This matches the optimal ordering for LLM tool schemas: the model
|
||||
// sees type and description before properties, constraints, etc.
|
||||
var jsonSchemaPriority = map[string]int{
|
||||
"type": 0,
|
||||
"description": 1,
|
||||
"properties": 2,
|
||||
"required": 3,
|
||||
}
|
||||
|
||||
// SortKeys sorts the keys of this OrderedMap using JSON Schema priority
|
||||
// ordering (type, description, properties, required first), with remaining
|
||||
// keys sorted alphabetically. Nested *OrderedMap values are also sorted
|
||||
// recursively.
|
||||
func (om *OrderedMap) SortKeys() {
|
||||
if om == nil || len(om.keys) == 0 {
|
||||
return
|
||||
}
|
||||
sort.Slice(om.keys, func(i, j int) bool {
|
||||
pi, okI := jsonSchemaPriority[om.keys[i]]
|
||||
pj, okJ := jsonSchemaPriority[om.keys[j]]
|
||||
switch {
|
||||
case okI && okJ:
|
||||
return pi < pj
|
||||
case okI:
|
||||
return true
|
||||
case okJ:
|
||||
return false
|
||||
default:
|
||||
return om.keys[i] < om.keys[j]
|
||||
}
|
||||
})
|
||||
for k, v := range om.values {
|
||||
switch nested := v.(type) {
|
||||
case *OrderedMap:
|
||||
nested.SortKeys()
|
||||
case map[string]interface{}:
|
||||
converted := OrderedMapFromMap(nested)
|
||||
converted.SortKeys()
|
||||
om.values[k] = converted
|
||||
case []interface{}:
|
||||
sortOrderedMapsInSlice(nested)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sortOrderedMapsInSlice(s []interface{}) {
|
||||
for i, item := range s {
|
||||
switch v := item.(type) {
|
||||
case *OrderedMap:
|
||||
v.SortKeys()
|
||||
case map[string]interface{}:
|
||||
converted := OrderedMapFromMap(v)
|
||||
converted.SortKeys()
|
||||
s[i] = converted
|
||||
case []interface{}:
|
||||
sortOrderedMapsInSlice(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SortedCopy returns a new OrderedMap with keys sorted using JSON Schema
|
||||
// priority ordering. Nested *OrderedMap values are recursively copied and
|
||||
// sorted. Primitive values (strings, numbers, bools) are shared, not cloned.
|
||||
// This is much cheaper than a full JSON marshal/unmarshal Clone because it
|
||||
// only allocates new key slices and value maps.
|
||||
func (om *OrderedMap) SortedCopy() *OrderedMap {
|
||||
if om == nil {
|
||||
return nil
|
||||
}
|
||||
if len(om.keys) == 0 {
|
||||
return &OrderedMap{values: make(map[string]interface{})}
|
||||
}
|
||||
|
||||
newKeys := make([]string, len(om.keys))
|
||||
copy(newKeys, om.keys)
|
||||
sort.Slice(newKeys, func(i, j int) bool {
|
||||
pi, okI := jsonSchemaPriority[newKeys[i]]
|
||||
pj, okJ := jsonSchemaPriority[newKeys[j]]
|
||||
switch {
|
||||
case okI && okJ:
|
||||
return pi < pj
|
||||
case okI:
|
||||
return true
|
||||
case okJ:
|
||||
return false
|
||||
default:
|
||||
return newKeys[i] < newKeys[j]
|
||||
}
|
||||
})
|
||||
|
||||
newValues := make(map[string]interface{}, len(om.values))
|
||||
for k, v := range om.values {
|
||||
switch nested := v.(type) {
|
||||
case *OrderedMap:
|
||||
newValues[k] = nested.SortedCopy()
|
||||
case map[string]interface{}:
|
||||
newValues[k] = OrderedMapFromMap(nested).SortedCopy()
|
||||
case []interface{}:
|
||||
newValues[k] = sortedCopySlice(nested)
|
||||
default:
|
||||
newValues[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return &OrderedMap{keys: newKeys, values: newValues}
|
||||
}
|
||||
|
||||
func sortedCopySlice(s []interface{}) []interface{} {
|
||||
out := make([]interface{}, len(s))
|
||||
for i, item := range s {
|
||||
switch v := item.(type) {
|
||||
case *OrderedMap:
|
||||
out[i] = v.SortedCopy()
|
||||
case map[string]interface{}:
|
||||
out[i] = OrderedMapFromMap(v).SortedCopy()
|
||||
case []interface{}:
|
||||
out[i] = sortedCopySlice(v)
|
||||
default:
|
||||
out[i] = item
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// SortedCopyPreservingProperties is like SortedCopy but preserves the key
|
||||
// order of user-defined property names inside "properties" maps. Structural
|
||||
// JSON Schema keys (type, description, properties, required) are still sorted
|
||||
// by priority, and all other keys alphabetically. When the key "properties"
|
||||
// is encountered, its value (an OrderedMap of user-defined field names) has
|
||||
// its top-level key order preserved while each nested schema value is
|
||||
// recursively processed with the same property-aware logic.
|
||||
//
|
||||
// This ensures deterministic serialization for prompt caching (structural keys
|
||||
// are always in the same order) while preserving the client's intended field
|
||||
// generation order for LLM structured output.
|
||||
func (om *OrderedMap) SortedCopyPreservingProperties() *OrderedMap {
|
||||
if om == nil {
|
||||
return nil
|
||||
}
|
||||
if len(om.keys) == 0 {
|
||||
return &OrderedMap{values: make(map[string]interface{})}
|
||||
}
|
||||
|
||||
newKeys := make([]string, len(om.keys))
|
||||
copy(newKeys, om.keys)
|
||||
sort.Slice(newKeys, func(i, j int) bool {
|
||||
pi, okI := jsonSchemaPriority[newKeys[i]]
|
||||
pj, okJ := jsonSchemaPriority[newKeys[j]]
|
||||
switch {
|
||||
case okI && okJ:
|
||||
return pi < pj
|
||||
case okI:
|
||||
return true
|
||||
case okJ:
|
||||
return false
|
||||
default:
|
||||
return newKeys[i] < newKeys[j]
|
||||
}
|
||||
})
|
||||
|
||||
newValues := make(map[string]interface{}, len(om.values))
|
||||
for k, v := range om.values {
|
||||
if k == "properties" {
|
||||
// User-defined property names: preserve key order, sort nested schemas
|
||||
newValues[k] = preserveKeysOrderedCopyWithAwareness(v)
|
||||
} else {
|
||||
switch nested := v.(type) {
|
||||
case *OrderedMap:
|
||||
newValues[k] = nested.SortedCopyPreservingProperties()
|
||||
case map[string]interface{}:
|
||||
newValues[k] = OrderedMapFromMap(nested).SortedCopyPreservingProperties()
|
||||
case []interface{}:
|
||||
newValues[k] = sortedCopySlicePreservingProperties(nested)
|
||||
default:
|
||||
newValues[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &OrderedMap{keys: newKeys, values: newValues}
|
||||
}
|
||||
|
||||
// preserveKeysOrderedCopyWithAwareness copies an OrderedMap preserving its
|
||||
// top-level key order (these are user-defined property names) while
|
||||
// recursively applying SortedCopyPreservingProperties to each value (each
|
||||
// value is a schema that may itself contain "properties").
|
||||
// If the input is not an *OrderedMap, it falls back to SortedCopyPreservingProperties.
|
||||
func preserveKeysOrderedCopyWithAwareness(v interface{}) interface{} {
|
||||
switch om := v.(type) {
|
||||
case *OrderedMap:
|
||||
return om.preserveKeysWithPropertyAwareness()
|
||||
case map[string]interface{}:
|
||||
// Plain maps have non-deterministic iteration order in Go;
|
||||
// convert and sort since we can't preserve an order that doesn't exist.
|
||||
return OrderedMapFromMap(om).SortedCopyPreservingProperties()
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// preserveKeysWithPropertyAwareness preserves the top-level key order of this
|
||||
// OrderedMap while recursively applying SortedCopyPreservingProperties to each
|
||||
// nested value.
|
||||
func (om *OrderedMap) preserveKeysWithPropertyAwareness() *OrderedMap {
|
||||
if om == nil {
|
||||
return nil
|
||||
}
|
||||
if len(om.keys) == 0 {
|
||||
return &OrderedMap{values: make(map[string]interface{})}
|
||||
}
|
||||
|
||||
// Preserve original key order (no sorting)
|
||||
newKeys := make([]string, len(om.keys))
|
||||
copy(newKeys, om.keys)
|
||||
|
||||
newValues := make(map[string]interface{}, len(om.values))
|
||||
for k, v := range om.values {
|
||||
switch nested := v.(type) {
|
||||
case *OrderedMap:
|
||||
newValues[k] = nested.SortedCopyPreservingProperties()
|
||||
case map[string]interface{}:
|
||||
newValues[k] = OrderedMapFromMap(nested).SortedCopyPreservingProperties()
|
||||
case []interface{}:
|
||||
newValues[k] = sortedCopySlicePreservingProperties(nested)
|
||||
default:
|
||||
newValues[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return &OrderedMap{keys: newKeys, values: newValues}
|
||||
}
|
||||
|
||||
func sortedCopySlicePreservingProperties(s []interface{}) []interface{} {
|
||||
out := make([]interface{}, len(s))
|
||||
for i, item := range s {
|
||||
switch v := item.(type) {
|
||||
case *OrderedMap:
|
||||
out[i] = v.SortedCopyPreservingProperties()
|
||||
case map[string]interface{}:
|
||||
out[i] = OrderedMapFromMap(v).SortedCopyPreservingProperties()
|
||||
case []interface{}:
|
||||
out[i] = sortedCopySlicePreservingProperties(v)
|
||||
default:
|
||||
out[i] = item
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// decodeOrderedValue reads a single JSON value from the decoder.
|
||||
// Objects are decoded as *OrderedMap (preserving key order).
|
||||
// Arrays are decoded as []interface{} with each element recursively decoded.
|
||||
// Primitives are returned as their Go equivalents.
|
||||
func decodeOrderedValue(dec *json.Decoder) (interface{}, error) {
|
||||
t, err := dec.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch v := t.(type) {
|
||||
case json.Delim:
|
||||
if v == '{' {
|
||||
// Recursively parse nested object as *OrderedMap
|
||||
nested := NewOrderedMap()
|
||||
for dec.More() {
|
||||
keyToken, err := dec.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key, ok := keyToken.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected string key, got %T", keyToken)
|
||||
}
|
||||
val, err := decodeOrderedValue(dec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nested.Set(key, val)
|
||||
}
|
||||
// Consume closing '}'
|
||||
if _, err := dec.Token(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nested, nil
|
||||
}
|
||||
if v == '[' {
|
||||
// Parse array elements recursively
|
||||
var arr []interface{}
|
||||
for dec.More() {
|
||||
val, err := decodeOrderedValue(dec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arr = append(arr, val)
|
||||
}
|
||||
// Consume closing ']'
|
||||
if _, err := dec.Token(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if arr == nil {
|
||||
arr = []interface{}{}
|
||||
}
|
||||
return arr, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected delimiter: %v", v)
|
||||
|
||||
case string:
|
||||
return v, nil
|
||||
case float64:
|
||||
return v, nil
|
||||
case bool:
|
||||
return v, nil
|
||||
case nil:
|
||||
return nil, nil
|
||||
case json.Number:
|
||||
f, err := v.Float64()
|
||||
if err != nil {
|
||||
return v.String(), nil
|
||||
}
|
||||
return f, nil
|
||||
default:
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
619
core/schemas/orderedmap_test.go
Normal file
619
core/schemas/orderedmap_test.go
Normal file
@@ -0,0 +1,619 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewOrderedMap(t *testing.T) {
|
||||
om := NewOrderedMap()
|
||||
assert.NotNil(t, om)
|
||||
assert.Equal(t, 0, om.Len())
|
||||
assert.Empty(t, om.Keys())
|
||||
}
|
||||
|
||||
func TestNewOrderedMapFromPairs(t *testing.T) {
|
||||
om := NewOrderedMapFromPairs(
|
||||
KV("b", 2),
|
||||
KV("a", 1),
|
||||
KV("c", 3),
|
||||
)
|
||||
assert.Equal(t, 3, om.Len())
|
||||
assert.Equal(t, []string{"b", "a", "c"}, om.Keys())
|
||||
|
||||
v, ok := om.Get("a")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 1, v)
|
||||
}
|
||||
|
||||
func TestOrderedMap_SetPreservesInsertionOrder(t *testing.T) {
|
||||
om := NewOrderedMap()
|
||||
om.Set("z", 1)
|
||||
om.Set("a", 2)
|
||||
om.Set("m", 3)
|
||||
|
||||
assert.Equal(t, []string{"z", "a", "m"}, om.Keys())
|
||||
}
|
||||
|
||||
func TestOrderedMap_SetUpdateInPlace(t *testing.T) {
|
||||
om := NewOrderedMap()
|
||||
om.Set("a", 1)
|
||||
om.Set("b", 2)
|
||||
om.Set("a", 10) // update, not re-append
|
||||
|
||||
assert.Equal(t, []string{"a", "b"}, om.Keys())
|
||||
v, ok := om.Get("a")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 10, v)
|
||||
}
|
||||
|
||||
func TestOrderedMap_Delete(t *testing.T) {
|
||||
om := NewOrderedMapFromPairs(
|
||||
KV("a", 1),
|
||||
KV("b", 2),
|
||||
KV("c", 3),
|
||||
)
|
||||
om.Delete("b")
|
||||
|
||||
assert.Equal(t, 2, om.Len())
|
||||
assert.Equal(t, []string{"a", "c"}, om.Keys())
|
||||
|
||||
_, ok := om.Get("b")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestOrderedMap_DeleteNonExistent(t *testing.T) {
|
||||
om := NewOrderedMapFromPairs(KV("a", 1))
|
||||
om.Delete("b") // should not panic
|
||||
assert.Equal(t, 1, om.Len())
|
||||
}
|
||||
|
||||
func TestOrderedMap_Range(t *testing.T) {
|
||||
om := NewOrderedMapFromPairs(
|
||||
KV("x", 1),
|
||||
KV("y", 2),
|
||||
KV("z", 3),
|
||||
)
|
||||
|
||||
var keys []string
|
||||
var vals []interface{}
|
||||
om.Range(func(key string, value interface{}) bool {
|
||||
keys = append(keys, key)
|
||||
vals = append(vals, value)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, []string{"x", "y", "z"}, keys)
|
||||
assert.Equal(t, []interface{}{1, 2, 3}, vals)
|
||||
}
|
||||
|
||||
func TestOrderedMap_RangeEarlyStop(t *testing.T) {
|
||||
om := NewOrderedMapFromPairs(
|
||||
KV("a", 1),
|
||||
KV("b", 2),
|
||||
KV("c", 3),
|
||||
)
|
||||
|
||||
var keys []string
|
||||
om.Range(func(key string, _ interface{}) bool {
|
||||
keys = append(keys, key)
|
||||
return key != "b" // stop after "b"
|
||||
})
|
||||
|
||||
assert.Equal(t, []string{"a", "b"}, keys)
|
||||
}
|
||||
|
||||
func TestOrderedMap_Clone(t *testing.T) {
|
||||
om := NewOrderedMapFromPairs(
|
||||
KV("a", 1),
|
||||
KV("b", 2),
|
||||
)
|
||||
|
||||
clone := om.Clone()
|
||||
assert.Equal(t, om.Keys(), clone.Keys())
|
||||
|
||||
// Modifying clone doesn't affect original
|
||||
clone.Set("c", 3)
|
||||
assert.Equal(t, 2, om.Len())
|
||||
assert.Equal(t, 3, clone.Len())
|
||||
}
|
||||
|
||||
func TestOrderedMap_ToMap(t *testing.T) {
|
||||
om := NewOrderedMapFromPairs(
|
||||
KV("a", 1),
|
||||
KV("b", "hello"),
|
||||
)
|
||||
|
||||
m := om.ToMap()
|
||||
assert.Equal(t, map[string]interface{}{"a": 1, "b": "hello"}, m)
|
||||
}
|
||||
|
||||
func TestOrderedMap_NilSafety(t *testing.T) {
|
||||
var om *OrderedMap
|
||||
|
||||
assert.Equal(t, 0, om.Len())
|
||||
assert.Nil(t, om.Keys())
|
||||
assert.Nil(t, om.Clone())
|
||||
assert.Nil(t, om.ToMap())
|
||||
|
||||
v, ok := om.Get("key")
|
||||
assert.Nil(t, v)
|
||||
assert.False(t, ok)
|
||||
|
||||
// Range on nil should not panic
|
||||
om.Range(func(key string, value interface{}) bool {
|
||||
t.Fatal("should not be called")
|
||||
return true
|
||||
})
|
||||
|
||||
// Delete on nil should not panic
|
||||
om.Delete("key")
|
||||
}
|
||||
|
||||
func TestOrderedMap_MarshalJSON_PreservesOrder(t *testing.T) {
|
||||
om := NewOrderedMapFromPairs(
|
||||
KV("z_last", 1),
|
||||
KV("a_first", 2),
|
||||
KV("m_middle", 3),
|
||||
)
|
||||
|
||||
data, err := json.Marshal(om)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `{"z_last":1,"a_first":2,"m_middle":3}`, string(data))
|
||||
}
|
||||
|
||||
func TestOrderedMap_MarshalJSON_Empty(t *testing.T) {
|
||||
om := NewOrderedMap()
|
||||
data, err := json.Marshal(om)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `{}`, string(data))
|
||||
}
|
||||
|
||||
func TestOrderedMap_MarshalJSON_NilValues(t *testing.T) {
|
||||
om := OrderedMap{} // zero value, values is nil
|
||||
data, err := json.Marshal(om)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `null`, string(data))
|
||||
}
|
||||
|
||||
func TestOrderedMap_UnmarshalJSON_PreservesOrder(t *testing.T) {
|
||||
input := `{"z_last":1,"a_first":"two","m_middle":true}`
|
||||
|
||||
var om OrderedMap
|
||||
err := json.Unmarshal([]byte(input), &om)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, []string{"z_last", "a_first", "m_middle"}, om.Keys())
|
||||
|
||||
v, ok := om.Get("z_last")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, float64(1), v) // JSON numbers are float64
|
||||
|
||||
v, ok = om.Get("a_first")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "two", v)
|
||||
|
||||
v, ok = om.Get("m_middle")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, true, v)
|
||||
}
|
||||
|
||||
func TestOrderedMap_UnmarshalJSON_NestedObjects(t *testing.T) {
|
||||
input := `{"outer_b":{"inner_z":1,"inner_a":2},"outer_a":"simple"}`
|
||||
|
||||
var om OrderedMap
|
||||
err := json.Unmarshal([]byte(input), &om)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, []string{"outer_b", "outer_a"}, om.Keys())
|
||||
|
||||
nested, ok := om.Get("outer_b")
|
||||
assert.True(t, ok)
|
||||
|
||||
nestedOM, ok := nested.(*OrderedMap)
|
||||
require.True(t, ok, "nested object should be *OrderedMap, got %T", nested)
|
||||
assert.Equal(t, []string{"inner_z", "inner_a"}, nestedOM.Keys())
|
||||
}
|
||||
|
||||
func TestOrderedMap_UnmarshalJSON_Null(t *testing.T) {
|
||||
var om OrderedMap
|
||||
err := json.Unmarshal([]byte("null"), &om)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, om.Len())
|
||||
}
|
||||
|
||||
func TestOrderedMap_JSONRoundTrip(t *testing.T) {
|
||||
original := NewOrderedMapFromPairs(
|
||||
KV("answer", map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The answer to the question",
|
||||
}),
|
||||
KV("chain_of_thought", map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Reasoning chain",
|
||||
}),
|
||||
KV("citations", map[string]interface{}{
|
||||
"type": "array",
|
||||
"description": "Sources",
|
||||
}),
|
||||
)
|
||||
|
||||
data, err := json.Marshal(original)
|
||||
require.NoError(t, err)
|
||||
|
||||
var restored OrderedMap
|
||||
err = json.Unmarshal(data, &restored)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, original.Keys(), restored.Keys())
|
||||
}
|
||||
|
||||
func TestOrderedMap_MarshalSorted(t *testing.T) {
|
||||
om := NewOrderedMapFromPairs(
|
||||
KV("z", 1),
|
||||
KV("a", 2),
|
||||
KV("m", 3),
|
||||
)
|
||||
|
||||
data, err := om.MarshalSorted()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `{"a":2,"m":3,"z":1}`, string(data))
|
||||
}
|
||||
|
||||
func TestOrderedMap_MarshalSorted_Nil(t *testing.T) {
|
||||
var om *OrderedMap
|
||||
data, err := om.MarshalSorted()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `null`, string(data))
|
||||
}
|
||||
|
||||
func TestOrderedMapFromMap(t *testing.T) {
|
||||
m := map[string]interface{}{"a": 1, "b": 2}
|
||||
om := OrderedMapFromMap(m)
|
||||
assert.Equal(t, 2, om.Len())
|
||||
|
||||
v, ok := om.Get("a")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 1, v)
|
||||
}
|
||||
|
||||
func TestOrderedMapFromMap_Nil(t *testing.T) {
|
||||
om := OrderedMapFromMap(nil)
|
||||
assert.Nil(t, om)
|
||||
}
|
||||
|
||||
func TestOrderedMap_NestedOrderedMapMarshal(t *testing.T) {
|
||||
inner := NewOrderedMapFromPairs(
|
||||
KV("z_prop", "last"),
|
||||
KV("a_prop", "first"),
|
||||
)
|
||||
outer := NewOrderedMapFromPairs(
|
||||
KV("properties", inner),
|
||||
KV("type", "object"),
|
||||
)
|
||||
|
||||
data, err := json.Marshal(outer)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `{"properties":{"z_prop":"last","a_prop":"first"},"type":"object"}`, string(data))
|
||||
}
|
||||
|
||||
func TestOrderedMap_UnmarshalThenMarshalPreservesOrder(t *testing.T) {
|
||||
// This is the core use case: JSON comes in with a specific order,
|
||||
// we deserialize, then re-serialize and the order is preserved.
|
||||
input := `{"answer":"string","chain_of_thought":"string","citations":"array","is_unanswered":"boolean"}`
|
||||
|
||||
var om OrderedMap
|
||||
err := json.Unmarshal([]byte(input), &om)
|
||||
require.NoError(t, err)
|
||||
|
||||
output, err := json.Marshal(om)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, input, string(output))
|
||||
}
|
||||
|
||||
func TestOrderedMap_SortKeys_PlainMapValues(t *testing.T) {
|
||||
om := NewOrderedMapFromPairs(
|
||||
KV("properties", map[string]interface{}{
|
||||
"z_field": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "last field",
|
||||
},
|
||||
"a_field": map[string]interface{}{
|
||||
"type": "number",
|
||||
"description": "first field",
|
||||
},
|
||||
}),
|
||||
KV("type", "object"),
|
||||
)
|
||||
|
||||
om.SortKeys()
|
||||
|
||||
assert.Equal(t, []string{"type", "properties"}, om.Keys(), "top-level keys should be schema-sorted")
|
||||
|
||||
props, ok := om.Get("properties")
|
||||
require.True(t, ok)
|
||||
propsOM, ok := props.(*OrderedMap)
|
||||
require.True(t, ok, "plain map should be converted to *OrderedMap, got %T", props)
|
||||
|
||||
assert.Equal(t, []string{"a_field", "z_field"}, propsOM.Keys(), "nested plain map keys should be sorted")
|
||||
|
||||
zField, ok := propsOM.Get("z_field")
|
||||
require.True(t, ok)
|
||||
zFieldOM, ok := zField.(*OrderedMap)
|
||||
require.True(t, ok, "deeply nested plain map should be converted to *OrderedMap, got %T", zField)
|
||||
assert.Equal(t, []string{"type", "description"}, zFieldOM.Keys(), "deeply nested keys should be schema-sorted")
|
||||
}
|
||||
|
||||
func TestOrderedMap_SortKeys_PlainMapInSlice(t *testing.T) {
|
||||
om := NewOrderedMapFromPairs(
|
||||
KV("anyOf", []interface{}{
|
||||
map[string]interface{}{
|
||||
"description": "first option",
|
||||
"type": "string",
|
||||
},
|
||||
}),
|
||||
)
|
||||
|
||||
om.SortKeys()
|
||||
|
||||
anyOf, ok := om.Get("anyOf")
|
||||
require.True(t, ok)
|
||||
slice, ok := anyOf.([]interface{})
|
||||
require.True(t, ok)
|
||||
require.Len(t, slice, 1)
|
||||
|
||||
elemOM, ok := slice[0].(*OrderedMap)
|
||||
require.True(t, ok, "plain map in slice should be converted to *OrderedMap, got %T", slice[0])
|
||||
assert.Equal(t, []string{"type", "description"}, elemOM.Keys())
|
||||
}
|
||||
|
||||
func TestOrderedMap_SortedCopy_PlainMapValues(t *testing.T) {
|
||||
om := NewOrderedMapFromPairs(
|
||||
KV("properties", map[string]interface{}{
|
||||
"z_field": "string",
|
||||
"a_field": "number",
|
||||
}),
|
||||
KV("type", "object"),
|
||||
)
|
||||
|
||||
sorted := om.SortedCopy()
|
||||
|
||||
assert.Equal(t, []string{"type", "properties"}, sorted.Keys())
|
||||
|
||||
props, ok := sorted.Get("properties")
|
||||
require.True(t, ok)
|
||||
propsOM, ok := props.(*OrderedMap)
|
||||
require.True(t, ok, "plain map should be converted to *OrderedMap in SortedCopy, got %T", props)
|
||||
assert.Equal(t, []string{"a_field", "z_field"}, propsOM.Keys())
|
||||
|
||||
origProps, _ := om.Get("properties")
|
||||
_, isMap := origProps.(map[string]interface{})
|
||||
assert.True(t, isMap, "original should be unmodified (still a plain map)")
|
||||
}
|
||||
|
||||
// --- SortedCopyPreservingProperties tests ---
|
||||
|
||||
func TestOrderedMap_SortedCopyPreservingProperties_Basic(t *testing.T) {
|
||||
// Schema with structural keys in wrong order and user properties in non-alpha order
|
||||
om := NewOrderedMapFromPairs(
|
||||
KV("required", []interface{}{"chain_of_thought"}),
|
||||
KV("properties", NewOrderedMapFromPairs(
|
||||
KV("chain_of_thought", NewOrderedMapFromPairs(
|
||||
KV("description", "Reasoning steps"),
|
||||
KV("type", "string"),
|
||||
)),
|
||||
KV("answer", NewOrderedMapFromPairs(
|
||||
KV("description", "The answer"),
|
||||
KV("type", "string"),
|
||||
)),
|
||||
KV("citations", NewOrderedMapFromPairs(
|
||||
KV("type", "array"),
|
||||
)),
|
||||
)),
|
||||
KV("type", "object"),
|
||||
)
|
||||
|
||||
result := om.SortedCopyPreservingProperties()
|
||||
|
||||
// Caching: structural keys sorted by JSON Schema priority
|
||||
assert.Equal(t, []string{"type", "properties", "required"}, result.Keys())
|
||||
|
||||
// CoT: user-defined property names preserved in original order
|
||||
props, ok := result.Get("properties")
|
||||
require.True(t, ok)
|
||||
propsOM := props.(*OrderedMap)
|
||||
assert.Equal(t, []string{"chain_of_thought", "answer", "citations"}, propsOM.Keys())
|
||||
|
||||
// Caching: structural keys within each property are sorted
|
||||
cot, _ := propsOM.Get("chain_of_thought")
|
||||
cotOM := cot.(*OrderedMap)
|
||||
assert.Equal(t, []string{"type", "description"}, cotOM.Keys(), "structural keys within property should be sorted")
|
||||
|
||||
// Immutability: original unchanged
|
||||
assert.Equal(t, []string{"required", "properties", "type"}, om.Keys())
|
||||
origProps, _ := om.Get("properties")
|
||||
origPropsOM := origProps.(*OrderedMap)
|
||||
assert.Equal(t, []string{"chain_of_thought", "answer", "citations"}, origPropsOM.Keys())
|
||||
}
|
||||
|
||||
func TestOrderedMap_SortedCopyPreservingProperties_NestedObjects(t *testing.T) {
|
||||
// Schema where a property is itself an object with nested properties
|
||||
om := NewOrderedMapFromPairs(
|
||||
KV("type", "object"),
|
||||
KV("properties", NewOrderedMapFromPairs(
|
||||
KV("reasoning", NewOrderedMapFromPairs(
|
||||
KV("type", "string"),
|
||||
)),
|
||||
KV("address", NewOrderedMapFromPairs(
|
||||
KV("type", "object"),
|
||||
KV("properties", NewOrderedMapFromPairs(
|
||||
KV("street", NewOrderedMapFromPairs(KV("type", "string"))),
|
||||
KV("city", NewOrderedMapFromPairs(KV("type", "string"))),
|
||||
KV("zip", NewOrderedMapFromPairs(KV("type", "string"))),
|
||||
)),
|
||||
)),
|
||||
KV("answer", NewOrderedMapFromPairs(
|
||||
KV("type", "string"),
|
||||
)),
|
||||
)),
|
||||
)
|
||||
|
||||
result := om.SortedCopyPreservingProperties()
|
||||
|
||||
// CoT: outer property names preserved
|
||||
props, _ := result.Get("properties")
|
||||
propsOM := props.(*OrderedMap)
|
||||
assert.Equal(t, []string{"reasoning", "address", "answer"}, propsOM.Keys())
|
||||
|
||||
// CoT: inner nested property names preserved
|
||||
addr, _ := propsOM.Get("address")
|
||||
addrOM := addr.(*OrderedMap)
|
||||
innerProps, _ := addrOM.Get("properties")
|
||||
innerPropsOM := innerProps.(*OrderedMap)
|
||||
assert.Equal(t, []string{"street", "city", "zip"}, innerPropsOM.Keys())
|
||||
}
|
||||
|
||||
func TestOrderedMap_SortedCopyPreservingProperties_ThreeLevelNesting(t *testing.T) {
|
||||
om := NewOrderedMapFromPairs(
|
||||
KV("type", "object"),
|
||||
KV("properties", NewOrderedMapFromPairs(
|
||||
KV("organization", NewOrderedMapFromPairs(
|
||||
KV("type", "object"),
|
||||
KV("properties", NewOrderedMapFromPairs(
|
||||
KV("department", NewOrderedMapFromPairs(
|
||||
KV("type", "object"),
|
||||
KV("properties", NewOrderedMapFromPairs(
|
||||
KV("team_lead", NewOrderedMapFromPairs(KV("type", "string"))),
|
||||
KV("team_name", NewOrderedMapFromPairs(KV("type", "string"))),
|
||||
)),
|
||||
)),
|
||||
KV("budget", NewOrderedMapFromPairs(KV("type", "number"))),
|
||||
)),
|
||||
)),
|
||||
KV("summary", NewOrderedMapFromPairs(KV("type", "string"))),
|
||||
)),
|
||||
)
|
||||
|
||||
result := om.SortedCopyPreservingProperties()
|
||||
|
||||
// Level 1: organization, summary
|
||||
props, _ := result.Get("properties")
|
||||
propsOM := props.(*OrderedMap)
|
||||
assert.Equal(t, []string{"organization", "summary"}, propsOM.Keys())
|
||||
|
||||
// Level 2: department, budget
|
||||
org, _ := propsOM.Get("organization")
|
||||
orgProps, _ := org.(*OrderedMap).Get("properties")
|
||||
orgPropsOM := orgProps.(*OrderedMap)
|
||||
assert.Equal(t, []string{"department", "budget"}, orgPropsOM.Keys())
|
||||
|
||||
// Level 3: team_lead, team_name
|
||||
dept, _ := orgPropsOM.Get("department")
|
||||
deptProps, _ := dept.(*OrderedMap).Get("properties")
|
||||
deptPropsOM := deptProps.(*OrderedMap)
|
||||
assert.Equal(t, []string{"team_lead", "team_name"}, deptPropsOM.Keys())
|
||||
}
|
||||
|
||||
func TestOrderedMap_SortedCopyPreservingProperties_WithDefs(t *testing.T) {
|
||||
// $defs definition names should be sorted (for caching), but properties
|
||||
// within each definition should be preserved
|
||||
om := NewOrderedMapFromPairs(
|
||||
KV("$defs", NewOrderedMapFromPairs(
|
||||
KV("Metadata", NewOrderedMapFromPairs(
|
||||
KV("type", "object"),
|
||||
KV("properties", NewOrderedMapFromPairs(
|
||||
KV("latency_ms", NewOrderedMapFromPairs(KV("type", "number"))),
|
||||
KV("model_version", NewOrderedMapFromPairs(KV("type", "string"))),
|
||||
)),
|
||||
)),
|
||||
KV("Citation", NewOrderedMapFromPairs(
|
||||
KV("type", "object"),
|
||||
KV("properties", NewOrderedMapFromPairs(
|
||||
KV("url", NewOrderedMapFromPairs(KV("type", "string"))),
|
||||
KV("text", NewOrderedMapFromPairs(KV("type", "string"))),
|
||||
)),
|
||||
)),
|
||||
)),
|
||||
KV("type", "object"),
|
||||
KV("properties", NewOrderedMapFromPairs(
|
||||
KV("answer", NewOrderedMapFromPairs(KV("type", "string"))),
|
||||
)),
|
||||
)
|
||||
|
||||
result := om.SortedCopyPreservingProperties()
|
||||
|
||||
// Caching: $defs definition names are sorted alphabetically
|
||||
defs, _ := result.Get("$defs")
|
||||
defsOM := defs.(*OrderedMap)
|
||||
assert.Equal(t, []string{"Citation", "Metadata"}, defsOM.Keys())
|
||||
|
||||
// CoT: properties within each $def are preserved
|
||||
meta, _ := defsOM.Get("Metadata")
|
||||
metaProps, _ := meta.(*OrderedMap).Get("properties")
|
||||
metaPropsOM := metaProps.(*OrderedMap)
|
||||
assert.Equal(t, []string{"latency_ms", "model_version"}, metaPropsOM.Keys())
|
||||
|
||||
citation, _ := defsOM.Get("Citation")
|
||||
citProps, _ := citation.(*OrderedMap).Get("properties")
|
||||
citPropsOM := citProps.(*OrderedMap)
|
||||
assert.Equal(t, []string{"url", "text"}, citPropsOM.Keys())
|
||||
}
|
||||
|
||||
func TestOrderedMap_SortedCopyPreservingProperties_NilAndEmpty(t *testing.T) {
|
||||
// nil returns nil
|
||||
var nilOM *OrderedMap
|
||||
assert.Nil(t, nilOM.SortedCopyPreservingProperties())
|
||||
|
||||
// Empty returns empty
|
||||
empty := &OrderedMap{keys: []string{}, values: make(map[string]interface{})}
|
||||
result := empty.SortedCopyPreservingProperties()
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, 0, result.Len())
|
||||
|
||||
// No "properties" key behaves like SortedCopy
|
||||
noProps := NewOrderedMapFromPairs(
|
||||
KV("required", []interface{}{"a"}),
|
||||
KV("type", "object"),
|
||||
KV("description", "test"),
|
||||
)
|
||||
result = noProps.SortedCopyPreservingProperties()
|
||||
assert.Equal(t, []string{"type", "description", "required"}, result.Keys())
|
||||
}
|
||||
|
||||
func TestOrderedMap_SortedCopyPreservingProperties_PlainMapInsideProperties(t *testing.T) {
|
||||
// When properties contains a plain map (not OrderedMap), it should be
|
||||
// converted and have its nested values processed
|
||||
om := NewOrderedMapFromPairs(
|
||||
KV("type", "object"),
|
||||
KV("properties", map[string]interface{}{
|
||||
"field_a": map[string]interface{}{"description": "first", "type": "string"},
|
||||
}),
|
||||
)
|
||||
|
||||
result := om.SortedCopyPreservingProperties()
|
||||
|
||||
// The properties value should be converted to *OrderedMap
|
||||
props, ok := result.Get("properties")
|
||||
require.True(t, ok)
|
||||
_, isOM := props.(*OrderedMap)
|
||||
assert.True(t, isOM, "plain map should be converted to *OrderedMap")
|
||||
}
|
||||
|
||||
func TestOrderedMap_EmptyArray(t *testing.T) {
|
||||
input := `{"items":[],"name":"test"}`
|
||||
|
||||
var om OrderedMap
|
||||
err := json.Unmarshal([]byte(input), &om)
|
||||
require.NoError(t, err)
|
||||
|
||||
v, ok := om.Get("items")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []interface{}{}, v)
|
||||
|
||||
output, err := json.Marshal(om)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, input, string(output))
|
||||
}
|
||||
63
core/schemas/pagination.go
Normal file
63
core/schemas/pagination.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// SerialCursor tracks pagination state for serial key exhaustion.
|
||||
// When paginating across multiple keys, we exhaust all pages from one key
|
||||
// before moving to the next, ensuring only one API call per pagination request.
|
||||
type SerialCursor struct {
|
||||
Version int `json:"v"` // Version for compatibility
|
||||
KeyIndex int `json:"i"` // Current key index in sorted keys array
|
||||
Cursor string `json:"c"` // Native cursor for current key (empty = start fresh)
|
||||
}
|
||||
|
||||
// EncodeSerialCursor encodes a SerialCursor to a base64 string for transport.
|
||||
func EncodeSerialCursor(cursor *SerialCursor) string {
|
||||
if cursor == nil {
|
||||
return ""
|
||||
}
|
||||
data, err := json.Marshal(cursor)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(data)
|
||||
}
|
||||
|
||||
// DecodeSerialCursor decodes a base64 string back to a SerialCursor.
|
||||
// Returns (nil, nil) if the encoded string is empty; returns an error for invalid data.
|
||||
func DecodeSerialCursor(encoded string) (*SerialCursor, error) {
|
||||
if encoded == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := base64.URLEncoding.DecodeString(encoded)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode cursor: %w", err)
|
||||
}
|
||||
|
||||
var cursor SerialCursor
|
||||
if err := json.Unmarshal(data, &cursor); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal cursor: %w", err)
|
||||
}
|
||||
|
||||
// Validate version
|
||||
if cursor.Version != 1 {
|
||||
return nil, fmt.Errorf("unsupported cursor version: %d", cursor.Version)
|
||||
}
|
||||
|
||||
return &cursor, nil
|
||||
}
|
||||
|
||||
// NewSerialCursor creates a new SerialCursor with version 1.
|
||||
func NewSerialCursor(keyIndex int, cursor string) *SerialCursor {
|
||||
return &SerialCursor{
|
||||
Version: 1,
|
||||
KeyIndex: keyIndex,
|
||||
Cursor: cursor,
|
||||
}
|
||||
}
|
||||
|
||||
26
core/schemas/passthrough.go
Normal file
26
core/schemas/passthrough.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package schemas
|
||||
|
||||
type BifrostPassthroughRequest struct {
|
||||
Provider ModelProvider // provider extracted from path or body, used for key selection when non-empty
|
||||
Model string // model extracted from path or body, used for key selection when non-empty
|
||||
Method string
|
||||
Path string // stripped path, e.g. "/v1/fine-tuning/jobs"
|
||||
RawQuery string // raw query string, no "?"
|
||||
Body []byte
|
||||
SafeHeaders map[string]string // client headers, auth already stripped
|
||||
}
|
||||
|
||||
type BifrostPassthroughResponse struct {
|
||||
StatusCode int
|
||||
Headers map[string]string
|
||||
Body []byte
|
||||
BodyTruncated bool
|
||||
ExtraFields BifrostResponseExtraFields
|
||||
}
|
||||
|
||||
type PassthroughLogParams struct {
|
||||
Method string `json:"method"`
|
||||
Path string `json:"path"` // stripped path, e.g. "/v1/fine-tuning/jobs"
|
||||
RawQuery string `json:"raw_query"` // raw query string, no "?"
|
||||
StatusCode int `json:"status_code"`
|
||||
}
|
||||
327
core/schemas/plugin.go
Normal file
327
core/schemas/plugin.go
Normal file
@@ -0,0 +1,327 @@
|
||||
// Package schemas defines the core schemas and types used by the Bifrost system.
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// PluginStatus constants
|
||||
const (
|
||||
PluginStatusActive = "active"
|
||||
PluginStatusError = "error"
|
||||
PluginStatusDisabled = "disabled"
|
||||
PluginStatusLoading = "loading"
|
||||
PluginStatusUninitialized = "uninitialized"
|
||||
PluginStatusUnloaded = "unloaded"
|
||||
PluginStatusLoaded = "loaded"
|
||||
)
|
||||
|
||||
// PluginStatus represents the status of a plugin.
|
||||
type PluginStatus struct {
|
||||
Name string `json:"name"` // Display name of the plugin
|
||||
Status string `json:"status"`
|
||||
Logs []string `json:"logs"`
|
||||
Types []PluginType `json:"types"` // Plugin types (LLM, MCP, HTTP)
|
||||
}
|
||||
|
||||
// PluginType represents the type of plugin.
|
||||
type PluginType string
|
||||
|
||||
const (
|
||||
PluginTypeLLM PluginType = "llm"
|
||||
PluginTypeMCP PluginType = "mcp"
|
||||
PluginTypeHTTP PluginType = "http"
|
||||
)
|
||||
|
||||
// HTTPRequest is a serializable representation of an HTTP request.
|
||||
// Used for plugin HTTP transport interception (supports both native .so and WASM plugins).
|
||||
// This type is pooled for allocation control - use AcquireHTTPRequest and ReleaseHTTPRequest.
|
||||
type HTTPRequest struct {
|
||||
Method string `json:"method"`
|
||||
Path string `json:"path"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Query map[string]string `json:"query"`
|
||||
Body []byte `json:"body"`
|
||||
PathParams map[string]string `json:"path_params"` // Path variables extracted from the URL pattern (e.g., {model})
|
||||
}
|
||||
|
||||
// CaseInsensitiveHeaderLookup looks up a header key in a case-insensitive manner
|
||||
func (req *HTTPRequest) CaseInsensitiveHeaderLookup(key string) string {
|
||||
return caseInsensitiveLookup(req.Headers, key)
|
||||
}
|
||||
|
||||
// CaseInsensitiveQueryLookup looks up a query key in a case-insensitive manner
|
||||
func (req *HTTPRequest) CaseInsensitiveQueryLookup(key string) string {
|
||||
return caseInsensitiveLookup(req.Query, key)
|
||||
}
|
||||
|
||||
// CaseInsensitivePathParamLookup looks up a path parameter key in a case-insensitive manner
|
||||
func (req *HTTPRequest) CaseInsensitivePathParamLookup(key string) string {
|
||||
return caseInsensitiveLookup(req.PathParams, key)
|
||||
}
|
||||
|
||||
// CaseInsensitiveLookup looks up a key in a case-insensitive manner for a map of strings
|
||||
// Returns the value if found, otherwise an empty string
|
||||
func caseInsensitiveLookup(data map[string]string, key string) string {
|
||||
if data == nil || key == "" {
|
||||
return ""
|
||||
}
|
||||
// exact match
|
||||
if v, ok := data[key]; ok {
|
||||
return v
|
||||
}
|
||||
// lower key checks
|
||||
lowerKey := strings.ToLower(key)
|
||||
if v, ok := data[lowerKey]; ok {
|
||||
return v
|
||||
}
|
||||
// case-insensitive iteration
|
||||
for k, v := range data {
|
||||
if strings.EqualFold(k, key) {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// HTTPResponse is a serializable representation of an HTTP response.
|
||||
// Used for short-circuit responses in plugin HTTP transport interception.
|
||||
// This type is pooled for allocation control - use AcquireHTTPResponse and ReleaseHTTPResponse.
|
||||
type HTTPResponse struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Body []byte `json:"body"`
|
||||
}
|
||||
|
||||
// httpRequestPool is the pool for HTTPRequest objects to reduce allocations.
|
||||
var httpRequestPool = sync.Pool{
|
||||
New: func() any {
|
||||
return &HTTPRequest{
|
||||
Headers: make(map[string]string, 16),
|
||||
Query: make(map[string]string, 8),
|
||||
PathParams: make(map[string]string, 4),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// AcquireHTTPRequest gets an HTTPRequest from the pool.
|
||||
// The returned HTTPRequest is ready to use with pre-allocated maps.
|
||||
// Call ReleaseHTTPRequest when done to return it to the pool.
|
||||
func AcquireHTTPRequest() *HTTPRequest {
|
||||
return httpRequestPool.Get().(*HTTPRequest)
|
||||
}
|
||||
|
||||
// ReleaseHTTPRequest returns an HTTPRequest to the pool.
|
||||
// The HTTPRequest is reset before being returned to the pool.
|
||||
// Do not use the HTTPRequest after calling this function.
|
||||
func ReleaseHTTPRequest(req *HTTPRequest) {
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
// Clear the maps
|
||||
clear(req.Headers)
|
||||
clear(req.Query)
|
||||
clear(req.PathParams)
|
||||
// Reset fields
|
||||
req.Method = ""
|
||||
req.Path = ""
|
||||
req.Body = nil
|
||||
httpRequestPool.Put(req)
|
||||
}
|
||||
|
||||
// httpResponsePool is the pool for HTTPResponse objects to reduce allocations.
|
||||
var httpResponsePool = sync.Pool{
|
||||
New: func() any {
|
||||
return &HTTPResponse{
|
||||
Headers: make(map[string]string, 8),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// AcquireHTTPResponse gets an HTTPResponse from the pool.
|
||||
// The returned HTTPResponse is ready to use with a pre-allocated Headers map.
|
||||
// Call ReleaseHTTPResponse when done to return it to the pool.
|
||||
func AcquireHTTPResponse() *HTTPResponse {
|
||||
return httpResponsePool.Get().(*HTTPResponse)
|
||||
}
|
||||
|
||||
// ReleaseHTTPResponse returns an HTTPResponse to the pool.
|
||||
// The HTTPResponse is reset before being returned to the pool.
|
||||
// Do not use the HTTPResponse after calling this function.
|
||||
func ReleaseHTTPResponse(resp *HTTPResponse) {
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
// Clear the map
|
||||
clear(resp.Headers)
|
||||
// Reset fields
|
||||
resp.StatusCode = 0
|
||||
resp.Body = nil
|
||||
httpResponsePool.Put(resp)
|
||||
}
|
||||
|
||||
// Plugin defines the interface for Bifrost plugins.
|
||||
// Plugins can intercept and modify requests and responses at different stages
|
||||
// of the processing pipeline.
|
||||
// User can provide multiple plugins in the BifrostConfig.
|
||||
// PreHooks are executed in the order they are registered.
|
||||
// PostHooks are executed in the reverse order of PreHooks.
|
||||
//
|
||||
// Execution order:
|
||||
// 1. HTTPTransportPreHook (HTTP transport only, executed in registration order)
|
||||
// 2. PreLLMHook (executed in registration order)
|
||||
// 3. Provider call
|
||||
// 4. PostLLMHook (executed in reverse order of PreHooks)
|
||||
// 5. HTTPTransportPostHook (HTTP transport only, executed in reverse order)
|
||||
// 5a. HTTPTransportStreamChunkHook (for streaming responses, called per-chunk in reverse order)
|
||||
//
|
||||
// Common use cases: rate limiting, caching, logging, monitoring, request transformation, governance.
|
||||
//
|
||||
// Plugin error handling:
|
||||
// - No Plugin errors are returned to the caller; they are logged as warnings by the Bifrost instance.
|
||||
// - PreLLMHook and PostLLMHook can both modify the request/response and the error. Plugins can recover from errors (set error to nil and provide a response), or invalidate a response (set response to nil and provide an error).
|
||||
// - PostLLMHook is always called with both the current response and error, and should handle either being nil.
|
||||
// - Only truly empty errors (no message, no error, no status code, no type) are treated as recoveries by the pipeline.
|
||||
// - If a PreLLMHook returns a LLMPluginShortCircuit, the provider call may be skipped and only the PostLLMHook methods of plugins that had their PreLLMHook executed are called in reverse order.
|
||||
// - The plugin pipeline ensures symmetry: for every PreLLMHook executed, the corresponding PostLLMHook will be called in reverse order.
|
||||
//
|
||||
// IMPORTANT: When returning BifrostError from PreLLMHook or PostLLMHook:
|
||||
// - You can set the AllowFallbacks field to control fallback behavior
|
||||
// - AllowFallbacks = &true: Allow Bifrost to try fallback providers
|
||||
// - AllowFallbacks = &false: Do not try fallbacks, return error immediately
|
||||
// - AllowFallbacks = nil: Treated as true by default (allow fallbacks for resilience)
|
||||
//
|
||||
// Plugin authors should ensure their hooks are robust to both response and error being nil, and should not assume either is always present.
|
||||
|
||||
type BasePlugin interface {
|
||||
// GetName returns the name of the plugin.
|
||||
GetName() string
|
||||
|
||||
// Cleanup is called on bifrost shutdown.
|
||||
// It allows plugins to clean up any resources they have allocated.
|
||||
// Returns any error that occurred during cleanup, which will be logged as a warning by the Bifrost instance.
|
||||
Cleanup() error
|
||||
}
|
||||
|
||||
type HTTPTransportPlugin interface {
|
||||
BasePlugin
|
||||
|
||||
// HTTPTransportPreHook is called at the HTTP transport layer before requests enter Bifrost core.
|
||||
// It receives a serializable HTTPRequest and allows plugins to modify it in-place.
|
||||
// Only invoked when using HTTP transport (bifrost-http), not when using Bifrost as a Go SDK directly.
|
||||
// Works with both native .so plugins and WASM plugins due to serializable types.
|
||||
//
|
||||
// Return values:
|
||||
// - (nil, nil): Continue to next plugin/handler, request modifications are applied
|
||||
// - (*HTTPResponse, nil): Short-circuit with this response, skip remaining plugins and provider call
|
||||
// - (nil, error): Short-circuit with error response
|
||||
//
|
||||
// Return nil for both values if the plugin doesn't need HTTP transport interception.
|
||||
HTTPTransportPreHook(ctx *BifrostContext, req *HTTPRequest) (*HTTPResponse, error)
|
||||
|
||||
// HTTPTransportPostHook is called at the HTTP transport layer after requests exit Bifrost core.
|
||||
// It receives a serializable HTTPRequest and HTTPResponse and allows plugins to modify it in-place.
|
||||
// Only invoked when using HTTP transport (bifrost-http), not when using Bifrost as a Go SDK directly.
|
||||
// Works with both native .so plugins and WASM plugins due to serializable types.
|
||||
// NOTE: This hook is NOT called for streaming responses. Use HTTPTransportStreamChunkHook instead.
|
||||
// NOTE: For large streamed responses (non-streaming APIs that switch to body streaming for memory safety),
|
||||
// resp.Body may be nil by design while StatusCode and Headers remain populated.
|
||||
//
|
||||
// Return values:
|
||||
// - nil: Continue to next plugin/handler, response modifications are applied
|
||||
// - error: Short-circuit with error response and skip remaining plugins
|
||||
//
|
||||
// Return nil if the plugin doesn't need HTTP transport interception.
|
||||
HTTPTransportPostHook(ctx *BifrostContext, req *HTTPRequest, resp *HTTPResponse) error
|
||||
|
||||
// HTTPTransportStreamChunkHook is called for each chunk during streaming responses.
|
||||
// It receives the BifrostStreamChunk BEFORE they are written to the client.
|
||||
// Only invoked for streaming responses when using HTTP transport (bifrost-http).
|
||||
// Works with both native .so plugins and WASM plugins due to serializable types.
|
||||
//
|
||||
// Plugins can modify the chunk by returning a different BifrostStreamChunk.
|
||||
// Return the original chunk unchanged if no modification is needed.
|
||||
//
|
||||
// Return values:
|
||||
// - (*BifrostStreamChunk, nil): Continue with the (potentially modified) BifrostStreamChunk
|
||||
// - (nil, nil): Skip this BifrostStreamChunk entirely (don't send to client)
|
||||
// - (*BifrostStreamChunk, error): Log warning and continue with the BifrostStreamChunk
|
||||
// - (nil, error): Send back error to the client and stop the streaming
|
||||
//
|
||||
// Return (*BifrostStreamChunk, nil) unchanged if the plugin doesn't need streaming chunk interception.
|
||||
HTTPTransportStreamChunkHook(ctx *BifrostContext, req *HTTPRequest, chunk *BifrostStreamChunk) (*BifrostStreamChunk, error)
|
||||
}
|
||||
|
||||
type LLMPlugin interface {
|
||||
BasePlugin
|
||||
|
||||
PreLLMHook(ctx *BifrostContext, req *BifrostRequest) (*BifrostRequest, *LLMPluginShortCircuit, error)
|
||||
PostLLMHook(ctx *BifrostContext, resp *BifrostResponse, bifrostErr *BifrostError) (*BifrostResponse, *BifrostError, error)
|
||||
}
|
||||
|
||||
type MCPPlugin interface {
|
||||
BasePlugin
|
||||
|
||||
PreMCPHook(ctx *BifrostContext, req *BifrostMCPRequest) (*BifrostMCPRequest, *MCPPluginShortCircuit, error)
|
||||
PostMCPHook(ctx *BifrostContext, resp *BifrostMCPResponse, bifrostErr *BifrostError) (*BifrostMCPResponse, *BifrostError, error)
|
||||
}
|
||||
|
||||
// Plugin placement constants control where custom plugins execute relative to built-in plugins.
|
||||
type PluginPlacement string
|
||||
|
||||
const (
|
||||
PluginPlacementPostBuiltin PluginPlacement = "post_builtin"
|
||||
PluginPlacementPreBuiltin PluginPlacement = "pre_builtin"
|
||||
PluginPlacementBuiltin PluginPlacement = "builtin"
|
||||
PluginPlacementDefault PluginPlacement = PluginPlacementPostBuiltin
|
||||
)
|
||||
|
||||
// PluginConfig is the configuration for a plugin.
|
||||
// It contains the name of the plugin, whether it is enabled, and the configuration for the plugin.
|
||||
type PluginConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Name string `json:"name"`
|
||||
Path *string `json:"path,omitempty"`
|
||||
Version *int16 `json:"version,omitempty"`
|
||||
Config any `json:"config,omitempty"`
|
||||
Placement *PluginPlacement `json:"placement,omitempty"` // "pre_builtin" or "post_builtin". Default: "post_builtin"
|
||||
Order *int `json:"order,omitempty"` // Position within placement group. Lower = earlier. Default: 0
|
||||
}
|
||||
|
||||
// ObservabilityPlugin is an interface for plugins that receive completed traces
|
||||
// for forwarding to observability backends (e.g., OTEL collectors, Datadog, etc.)
|
||||
//
|
||||
// ObservabilityPlugins are called asynchronously after the HTTP response has been
|
||||
// written to the wire, ensuring they don't add latency to the client response.
|
||||
//
|
||||
// Plugins implementing this interface will:
|
||||
// 1. Continue to work as regular plugins via PreLLMHook/PostLLMHook
|
||||
// 2. Additionally receive completed traces via the Inject method
|
||||
//
|
||||
// Example backends: OpenTelemetry collectors, Datadog, Jaeger, Maxim, etc.
|
||||
//
|
||||
// Note: Go type assertion (plugin.(ObservabilityPlugin)) is used to identify
|
||||
// plugins implementing this interface - no marker method is needed.
|
||||
type ObservabilityPlugin interface {
|
||||
BasePlugin
|
||||
|
||||
// Inject receives a completed trace for forwarding to observability backends.
|
||||
// This method is called asynchronously after the response has been written to the client.
|
||||
// The trace contains all spans that were added during request processing.
|
||||
//
|
||||
// Implementations should:
|
||||
// - Convert the trace to their backend's format
|
||||
// - Send the trace to the backend (can be async, but see retention note below)
|
||||
// - Handle errors gracefully (log and continue)
|
||||
//
|
||||
// The context passed is a fresh background context, not the request context.
|
||||
//
|
||||
// Retention: implementations MUST NOT retain the *Trace pointer after Inject
|
||||
// returns. The caller releases the trace back to a sync.Pool immediately after
|
||||
// Inject completes, so any background goroutine that still references it will
|
||||
// race with pool reuse. If a plugin needs to forward the trace asynchronously,
|
||||
// it must copy the data it needs before returning.
|
||||
Inject(ctx context.Context, trace *Trace) error
|
||||
}
|
||||
36
core/schemas/plugin_native.go
Normal file
36
core/schemas/plugin_native.go
Normal file
@@ -0,0 +1,36 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// BifrostHTTPMiddleware is a middleware function for the Bifrost HTTP transport.
|
||||
// It follows the standard pattern: receives the next handler and returns a new handler.
|
||||
// Used internally for CORS, Auth, Tracing middleware. Plugins use HTTPTransportIntercept instead.
|
||||
type BifrostHTTPMiddleware func(next fasthttp.RequestHandler) fasthttp.RequestHandler
|
||||
|
||||
// EventBroadcaster is a generic callback for broadcasting typed events to connected clients (e.g., via WebSocket).
|
||||
// Any plugin or subsystem can use this to push real-time updates to the frontend.
|
||||
// eventType identifies the message (e.g., "governance_update"), data is the JSON-serializable payload.
|
||||
type EventBroadcaster func(eventType string, data interface{})
|
||||
|
||||
// LLMPluginShortCircuit represents a plugin's decision to short-circuit the normal flow.
|
||||
// It can contain either a response (success short-circuit), a stream (streaming short-circuit), or an error (error short-circuit).
|
||||
type LLMPluginShortCircuit struct {
|
||||
Response *BifrostResponse // If set, short-circuit with this response (skips provider call)
|
||||
Stream chan *BifrostStreamChunk // If set, short-circuit with this stream (skips provider call)
|
||||
Error *BifrostError // If set, short-circuit with this error (can set AllowFallbacks field)
|
||||
}
|
||||
|
||||
// MCPPluginShortCircuit represents a plugin's decision to short-circuit the normal flow.
|
||||
// It can contain either a response (success short-circuit), or an error (error short-circuit).
|
||||
type MCPPluginShortCircuit struct {
|
||||
Response *BifrostMCPResponse // If set, short-circuit with this response (skips MCP call)
|
||||
Error *BifrostError // If set, short-circuit with this error (can set AllowFallbacks field)
|
||||
}
|
||||
|
||||
// PluginShortCircuit is the legacy name for LLMPluginShortCircuit (v1.3.x compatibility).
|
||||
// Deprecated: Use LLMPluginShortCircuit instead.
|
||||
type PluginShortCircuit = LLMPluginShortCircuit
|
||||
220
core/schemas/plugin_test.go
Normal file
220
core/schemas/plugin_test.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCaseInsensitiveLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data map[string]string
|
||||
key string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "nil map returns empty string",
|
||||
data: nil,
|
||||
key: "Content-Type",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty key returns empty string",
|
||||
data: map[string]string{"Content-Type": "application/json"},
|
||||
key: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "key not found returns empty string",
|
||||
data: map[string]string{"Content-Type": "application/json"},
|
||||
key: "Authorization",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "exact match",
|
||||
data: map[string]string{"Content-Type": "application/json"},
|
||||
key: "Content-Type",
|
||||
expected: "application/json",
|
||||
},
|
||||
{
|
||||
name: "lowercase key match - map has lowercase key",
|
||||
data: map[string]string{"content-type": "application/json"},
|
||||
key: "Content-Type",
|
||||
expected: "application/json",
|
||||
},
|
||||
{
|
||||
name: "lowercase key match - query is lowercase",
|
||||
data: map[string]string{"content-type": "application/json"},
|
||||
key: "content-type",
|
||||
expected: "application/json",
|
||||
},
|
||||
{
|
||||
name: "case-insensitive iteration - map has mixed case",
|
||||
data: map[string]string{"Content-Type": "application/json"},
|
||||
key: "content-type",
|
||||
expected: "application/json",
|
||||
},
|
||||
{
|
||||
name: "case-insensitive iteration - uppercase query",
|
||||
data: map[string]string{"Content-Type": "application/json"},
|
||||
key: "CONTENT-TYPE",
|
||||
expected: "application/json",
|
||||
},
|
||||
{
|
||||
name: "multiple keys - finds correct one",
|
||||
data: map[string]string{"Accept": "text/html", "Content-Type": "application/json"},
|
||||
key: "content-type",
|
||||
expected: "application/json",
|
||||
},
|
||||
// x-bf-vk header variations
|
||||
{
|
||||
name: "x-bf-vk exact match lowercase",
|
||||
data: map[string]string{"x-bf-vk": "sk-bf-test123"},
|
||||
key: "x-bf-vk",
|
||||
expected: "sk-bf-test123",
|
||||
},
|
||||
{
|
||||
name: "x-bf-vk mixed case in map",
|
||||
data: map[string]string{"X-Bf-Vk": "sk-bf-test123"},
|
||||
key: "x-bf-vk",
|
||||
expected: "sk-bf-test123",
|
||||
},
|
||||
{
|
||||
name: "x-bf-vk uppercase in map",
|
||||
data: map[string]string{"X-BF-VK": "sk-bf-test123"},
|
||||
key: "x-bf-vk",
|
||||
expected: "sk-bf-test123",
|
||||
},
|
||||
// authorization header variations
|
||||
{
|
||||
name: "authorization exact match lowercase",
|
||||
data: map[string]string{"authorization": "Bearer sk-bf-test123"},
|
||||
key: "authorization",
|
||||
expected: "Bearer sk-bf-test123",
|
||||
},
|
||||
{
|
||||
name: "authorization capitalized in map",
|
||||
data: map[string]string{"Authorization": "Bearer sk-bf-test123"},
|
||||
key: "authorization",
|
||||
expected: "Bearer sk-bf-test123",
|
||||
},
|
||||
{
|
||||
name: "authorization uppercase in map",
|
||||
data: map[string]string{"AUTHORIZATION": "Bearer sk-bf-test123"},
|
||||
key: "authorization",
|
||||
expected: "Bearer sk-bf-test123",
|
||||
},
|
||||
// x-api-key header variations
|
||||
{
|
||||
name: "x-api-key exact match lowercase",
|
||||
data: map[string]string{"x-api-key": "sk-bf-apikey123"},
|
||||
key: "x-api-key",
|
||||
expected: "sk-bf-apikey123",
|
||||
},
|
||||
{
|
||||
name: "x-api-key mixed case in map",
|
||||
data: map[string]string{"X-Api-Key": "sk-bf-apikey123"},
|
||||
key: "x-api-key",
|
||||
expected: "sk-bf-apikey123",
|
||||
},
|
||||
{
|
||||
name: "x-api-key uppercase in map",
|
||||
data: map[string]string{"X-API-KEY": "sk-bf-apikey123"},
|
||||
key: "x-api-key",
|
||||
expected: "sk-bf-apikey123",
|
||||
},
|
||||
// x-goog-api-key header variations
|
||||
{
|
||||
name: "x-goog-api-key exact match lowercase",
|
||||
data: map[string]string{"x-goog-api-key": "sk-bf-google123"},
|
||||
key: "x-goog-api-key",
|
||||
expected: "sk-bf-google123",
|
||||
},
|
||||
{
|
||||
name: "x-goog-api-key mixed case in map",
|
||||
data: map[string]string{"X-Goog-Api-Key": "sk-bf-google123"},
|
||||
key: "x-goog-api-key",
|
||||
expected: "sk-bf-google123",
|
||||
},
|
||||
{
|
||||
name: "x-goog-api-key uppercase in map",
|
||||
data: map[string]string{"X-GOOG-API-KEY": "sk-bf-google123"},
|
||||
key: "x-goog-api-key",
|
||||
expected: "sk-bf-google123",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := caseInsensitiveLookup(tt.data, tt.key)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPRequest_CaseInsensitiveHeaderLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
key string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
headers: map[string]string{"Content-Type": "application/json"},
|
||||
key: "Content-Type",
|
||||
expected: "application/json",
|
||||
},
|
||||
{
|
||||
name: "case-insensitive match",
|
||||
headers: map[string]string{"Content-Type": "application/json"},
|
||||
key: "content-type",
|
||||
expected: "application/json",
|
||||
},
|
||||
{
|
||||
name: "authorization header",
|
||||
headers: map[string]string{"Authorization": "Bearer token123"},
|
||||
key: "authorization",
|
||||
expected: "Bearer token123",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := &HTTPRequest{Headers: tt.headers}
|
||||
result := req.CaseInsensitiveHeaderLookup(tt.key)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPRequest_CaseInsensitiveQueryLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query map[string]string
|
||||
key string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
query: map[string]string{"apiKey": "test123"},
|
||||
key: "apiKey",
|
||||
expected: "test123",
|
||||
},
|
||||
{
|
||||
name: "case-insensitive match",
|
||||
query: map[string]string{"ApiKey": "test123"},
|
||||
key: "apikey",
|
||||
expected: "test123",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := &HTTPRequest{Query: tt.query}
|
||||
result := req.CaseInsensitiveQueryLookup(tt.key)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
15
core/schemas/plugin_wasm.go
Normal file
15
core/schemas/plugin_wasm.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build tinygo || wasm
|
||||
|
||||
package schemas
|
||||
|
||||
// LLMPluginShortCircuit represents a plugin's decision to short-circuit the normal flow.
|
||||
// It can contain either a response (success short-circuit), a stream (streaming short-circuit), or an error (error short-circuit).
|
||||
// Streams are not supported in WASM plugins.
|
||||
type LLMPluginShortCircuit struct {
|
||||
Response *BifrostResponse // If set, short-circuit with this response (skips provider call)
|
||||
Error *BifrostError // If set, short-circuit with this error (can set AllowFallbacks field)
|
||||
}
|
||||
|
||||
// PluginShortCircuit is the legacy name for LLMPluginShortCircuit (v1.3.x compatibility).
|
||||
// Deprecated: Use LLMPluginShortCircuit instead.
|
||||
type PluginShortCircuit = LLMPluginShortCircuit
|
||||
614
core/schemas/provider.go
Normal file
614
core/schemas/provider.go
Normal file
@@ -0,0 +1,614 @@
|
||||
// Package schemas defines the core schemas and types used by the Bifrost system.
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"maps"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMaxRetries = 0
|
||||
DefaultRetryBackoffInitial = 500 * time.Millisecond
|
||||
DefaultRetryBackoffMax = 5 * time.Second
|
||||
DefaultRequestTimeoutInSeconds = 30
|
||||
DefaultMaxConnDurationInSeconds = 300 // 5 minutes — forces connection recycling to prevent stale connections from NAT/LB silent drops
|
||||
DefaultBufferSize = 5000
|
||||
DefaultConcurrency = 1000
|
||||
DefaultStreamBufferSize = 256
|
||||
DefaultStreamIdleTimeoutInSeconds = 60 // Idle timeout per stream chunk — if no data for this many seconds, bifrost closes the connection
|
||||
DefaultMaxConnsPerHost = 5000
|
||||
MaxConnsPerHostUpperBound = 10000
|
||||
DefaultMaxIdleConnsPerHost = 40
|
||||
)
|
||||
|
||||
// Pre-defined errors for provider operations
|
||||
const (
|
||||
ErrProviderRequestTimedOut = "request timed out (default is 30 seconds). You can increase it by setting the default_request_timeout_in_seconds in the network_config or in UI - Providers > Provider Name > Network Config."
|
||||
ErrRequestCancelled = "request cancelled by caller"
|
||||
ErrRequestBodyConversion = "failed to convert bifrost request to the expected provider request body"
|
||||
ErrProviderRequestMarshal = "failed to marshal request body to JSON"
|
||||
ErrProviderCreateRequest = "failed to create HTTP request to provider API"
|
||||
ErrProviderDoRequest = "failed to execute HTTP request to provider API"
|
||||
ErrProviderNetworkError = "network error occurred while connecting to provider API (DNS lookup, connection refused, etc.)"
|
||||
ErrProviderResponseDecode = "failed to decode response body from provider API"
|
||||
ErrProviderResponseUnmarshal = "failed to unmarshal response from provider API"
|
||||
ErrProviderResponseEmpty = "empty response received from provider"
|
||||
ErrProviderResponseHTML = "HTML response received from provider"
|
||||
ErrProviderRawRequestUnmarshal = "failed to unmarshal raw request from provider API"
|
||||
ErrProviderRawResponseUnmarshal = "failed to unmarshal raw response from provider API"
|
||||
ErrProviderResponseDecompress = "failed to decompress provider's response"
|
||||
)
|
||||
|
||||
// NetworkConfig represents the network configuration for provider connections.
|
||||
// ExtraHeaders is automatically copied during provider initialization to prevent data races.
|
||||
//
|
||||
// RetryBackoffInitial and RetryBackoffMax are stored internally as time.Duration (nanoseconds),
|
||||
// but are serialized/deserialized to/from JSON as milliseconds (integers).
|
||||
// This means:
|
||||
// - In JSON: values are represented as milliseconds (e.g., 1000 means 1000ms)
|
||||
// - In Go: values are time.Duration (e.g., 1000ms = 1000000000 nanoseconds)
|
||||
// - When unmarshaling from JSON: a value of 1000 is interpreted as 1000ms, not 1000ns
|
||||
// - When marshaling to JSON: a time.Duration is converted to milliseconds
|
||||
type NetworkConfig struct {
|
||||
// BaseURL is supported for OpenAI, Anthropic, Cohere, Mistral, and Ollama providers (required for Ollama)
|
||||
BaseURL string `json:"base_url,omitempty"` // Base URL for the provider (optional)
|
||||
ExtraHeaders map[string]string `json:"extra_headers,omitempty"` // Additional headers to include in requests (optional)
|
||||
DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"` // Default timeout for requests
|
||||
MaxRetries int `json:"max_retries"` // Maximum number of retries
|
||||
RetryBackoffInitial time.Duration `json:"retry_backoff_initial"` // Initial backoff duration (stored as nanoseconds, JSON as milliseconds)
|
||||
RetryBackoffMax time.Duration `json:"retry_backoff_max"` // Maximum backoff duration (stored as nanoseconds, JSON as milliseconds)
|
||||
InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"` // Disables TLS certificate verification for provider connections
|
||||
CACertPEM string `json:"ca_cert_pem,omitempty"` // PEM-encoded CA certificate to trust for provider endpoint connections
|
||||
StreamIdleTimeoutInSeconds int `json:"stream_idle_timeout_in_seconds,omitempty"` // Idle timeout per stream chunk (0 = use default 60s)
|
||||
MaxConnsPerHost int `json:"max_conns_per_host,omitempty"` // Max TCP connections per provider host (default: 5000)
|
||||
EnforceHTTP2 bool `json:"enforce_http2,omitempty"` // Force HTTP/2 on provider connections (relevant for net/http-based providers like Bedrock)
|
||||
BetaHeaderOverrides map[string]bool `json:"beta_header_overrides,omitempty"` // Override default beta header support per provider (keys are prefixes like "redact-thinking-")
|
||||
}
|
||||
|
||||
// UnmarshalJSON customizes JSON unmarshaling for NetworkConfig.
|
||||
// RetryBackoffInitial and RetryBackoffMax are interpreted as milliseconds in JSON,
|
||||
// but stored as time.Duration (nanoseconds) internally.
|
||||
func (nc *NetworkConfig) UnmarshalJSON(data []byte) error {
|
||||
// Use an alias type to avoid infinite recursion
|
||||
type NetworkConfigAlias struct {
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
|
||||
DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"`
|
||||
MaxRetries int `json:"max_retries"`
|
||||
RetryBackoffInitial int64 `json:"retry_backoff_initial"` // milliseconds in JSON
|
||||
RetryBackoffMax int64 `json:"retry_backoff_max"` // milliseconds in JSON
|
||||
InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"`
|
||||
CACertPEM string `json:"ca_cert_pem,omitempty"`
|
||||
StreamIdleTimeoutInSeconds int `json:"stream_idle_timeout_in_seconds,omitempty"`
|
||||
MaxConnsPerHost int `json:"max_conns_per_host,omitempty"`
|
||||
EnforceHTTP2 bool `json:"enforce_http2,omitempty"`
|
||||
BetaHeaderOverrides map[string]bool `json:"beta_header_overrides,omitempty"`
|
||||
}
|
||||
|
||||
var alias NetworkConfigAlias
|
||||
if err := json.Unmarshal(data, &alias); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Copy all fields
|
||||
nc.BaseURL = alias.BaseURL
|
||||
nc.ExtraHeaders = alias.ExtraHeaders
|
||||
nc.DefaultRequestTimeoutInSeconds = alias.DefaultRequestTimeoutInSeconds
|
||||
nc.MaxRetries = alias.MaxRetries
|
||||
nc.InsecureSkipVerify = alias.InsecureSkipVerify
|
||||
nc.CACertPEM = alias.CACertPEM
|
||||
nc.StreamIdleTimeoutInSeconds = alias.StreamIdleTimeoutInSeconds
|
||||
nc.MaxConnsPerHost = alias.MaxConnsPerHost
|
||||
nc.EnforceHTTP2 = alias.EnforceHTTP2
|
||||
nc.BetaHeaderOverrides = alias.BetaHeaderOverrides
|
||||
|
||||
// Convert milliseconds to time.Duration (nanoseconds)
|
||||
// Only convert if value is greater than 0
|
||||
if alias.RetryBackoffInitial > 0 {
|
||||
nc.RetryBackoffInitial = time.Duration(alias.RetryBackoffInitial) * time.Millisecond
|
||||
}
|
||||
if alias.RetryBackoffMax > 0 {
|
||||
nc.RetryBackoffMax = time.Duration(alias.RetryBackoffMax) * time.Millisecond
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON customizes JSON marshaling for NetworkConfig.
|
||||
// RetryBackoffInitial and RetryBackoffMax are converted from time.Duration (nanoseconds)
|
||||
// to milliseconds (integers) in JSON.
|
||||
func (nc NetworkConfig) MarshalJSON() ([]byte, error) {
|
||||
// Use an alias type to avoid infinite recursion
|
||||
type NetworkConfigAlias struct {
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
|
||||
DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"`
|
||||
MaxRetries int `json:"max_retries"`
|
||||
RetryBackoffInitial int64 `json:"retry_backoff_initial"` // milliseconds in JSON
|
||||
RetryBackoffMax int64 `json:"retry_backoff_max"` // milliseconds in JSON
|
||||
InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"`
|
||||
CACertPEM string `json:"ca_cert_pem,omitempty"`
|
||||
StreamIdleTimeoutInSeconds int `json:"stream_idle_timeout_in_seconds,omitempty"`
|
||||
MaxConnsPerHost int `json:"max_conns_per_host,omitempty"`
|
||||
EnforceHTTP2 bool `json:"enforce_http2,omitempty"`
|
||||
BetaHeaderOverrides map[string]bool `json:"beta_header_overrides,omitempty"`
|
||||
}
|
||||
|
||||
alias := NetworkConfigAlias{
|
||||
BaseURL: nc.BaseURL,
|
||||
ExtraHeaders: nc.ExtraHeaders,
|
||||
DefaultRequestTimeoutInSeconds: nc.DefaultRequestTimeoutInSeconds,
|
||||
MaxRetries: nc.MaxRetries,
|
||||
// Convert time.Duration (nanoseconds) to milliseconds
|
||||
RetryBackoffInitial: int64(nc.RetryBackoffInitial / time.Millisecond),
|
||||
RetryBackoffMax: int64(nc.RetryBackoffMax / time.Millisecond),
|
||||
InsecureSkipVerify: nc.InsecureSkipVerify,
|
||||
CACertPEM: nc.CACertPEM,
|
||||
StreamIdleTimeoutInSeconds: nc.StreamIdleTimeoutInSeconds,
|
||||
MaxConnsPerHost: nc.MaxConnsPerHost,
|
||||
EnforceHTTP2: nc.EnforceHTTP2,
|
||||
BetaHeaderOverrides: nc.BetaHeaderOverrides,
|
||||
}
|
||||
|
||||
return json.Marshal(alias)
|
||||
}
|
||||
|
||||
// Redacted returns a redacted copy of the network configuration with CACertPEM masked.
|
||||
func (nc *NetworkConfig) Redacted() *NetworkConfig {
|
||||
if nc == nil {
|
||||
return nil
|
||||
}
|
||||
redacted := *nc
|
||||
if nc.CACertPEM != "" {
|
||||
redacted.CACertPEM = "<REDACTED>"
|
||||
}
|
||||
return &redacted
|
||||
}
|
||||
|
||||
// DefaultNetworkConfig is the default network configuration for provider connections.
|
||||
var DefaultNetworkConfig = NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: DefaultRequestTimeoutInSeconds,
|
||||
MaxRetries: DefaultMaxRetries,
|
||||
RetryBackoffInitial: DefaultRetryBackoffInitial,
|
||||
RetryBackoffMax: DefaultRetryBackoffMax,
|
||||
StreamIdleTimeoutInSeconds: DefaultStreamIdleTimeoutInSeconds,
|
||||
MaxConnsPerHost: DefaultMaxConnsPerHost,
|
||||
}
|
||||
|
||||
// ConcurrencyAndBufferSize represents configuration for concurrent operations and buffer sizes.
|
||||
type ConcurrencyAndBufferSize struct {
|
||||
Concurrency int `json:"concurrency"` // Number of concurrent operations. Also used as the initial pool size for the provider reponses.
|
||||
BufferSize int `json:"buffer_size"` // Size of the buffer
|
||||
}
|
||||
|
||||
// DefaultConcurrencyAndBufferSize is the default concurrency and buffer size for provider operations.
|
||||
var DefaultConcurrencyAndBufferSize = ConcurrencyAndBufferSize{
|
||||
Concurrency: DefaultConcurrency,
|
||||
BufferSize: DefaultBufferSize,
|
||||
}
|
||||
|
||||
// ProxyType defines the type of proxy to use for connections.
|
||||
type ProxyType string
|
||||
|
||||
const (
|
||||
// NoProxy indicates no proxy should be used
|
||||
NoProxy ProxyType = "none"
|
||||
// HTTPProxy indicates an HTTP proxy should be used
|
||||
HTTPProxy ProxyType = "http"
|
||||
// Socks5Proxy indicates a SOCKS5 proxy should be used
|
||||
Socks5Proxy ProxyType = "socks5"
|
||||
// EnvProxy indicates the proxy should be read from environment variables
|
||||
EnvProxy ProxyType = "environment"
|
||||
)
|
||||
|
||||
// ProxyConfig holds the configuration for proxy settings.
|
||||
type ProxyConfig struct {
|
||||
Type ProxyType `json:"type"` // Type of proxy to use
|
||||
URL string `json:"url"` // URL of the proxy server
|
||||
Username string `json:"username"` // Username for proxy authentication
|
||||
Password string `json:"password"` // Password for proxy authentication
|
||||
CACertPEM string `json:"ca_cert_pem"` // PEM-encoded CA certificate to trust for TLS connections through the proxy
|
||||
}
|
||||
|
||||
// IsRedactedValue returns true if the value is redacted.
|
||||
func (pc *ProxyConfig) IsRedactedValue(value string) bool {
|
||||
return value == "<REDACTED>" || value == "********"
|
||||
}
|
||||
|
||||
// Redacted returns a redacted copy of the proxy configuration.
|
||||
func (pc *ProxyConfig) Redacted() *ProxyConfig {
|
||||
// Create redacted config with same structure but redacted values
|
||||
redactedConfig := ProxyConfig{
|
||||
Type: pc.Type,
|
||||
URL: pc.URL,
|
||||
Username: pc.Username,
|
||||
}
|
||||
if pc.Password != "" {
|
||||
redactedConfig.Password = "<REDACTED>"
|
||||
}
|
||||
if pc.CACertPEM != "" {
|
||||
redactedConfig.CACertPEM = "<REDACTED>"
|
||||
}
|
||||
return &redactedConfig
|
||||
}
|
||||
|
||||
// AllowedRequests controls which operations are permitted.
|
||||
// A nil *AllowedRequests means "all operations allowed."
|
||||
// A non-nil value only allows fields set to true; omitted or false fields are disallowed.
|
||||
type AllowedRequests struct {
|
||||
ListModels bool `json:"list_models"`
|
||||
TextCompletion bool `json:"text_completion"`
|
||||
TextCompletionStream bool `json:"text_completion_stream"`
|
||||
ChatCompletion bool `json:"chat_completion"`
|
||||
ChatCompletionStream bool `json:"chat_completion_stream"`
|
||||
Responses bool `json:"responses"`
|
||||
ResponsesStream bool `json:"responses_stream"`
|
||||
CountTokens bool `json:"count_tokens"`
|
||||
Embedding bool `json:"embedding"`
|
||||
Rerank bool `json:"rerank"`
|
||||
OCR bool `json:"ocr"`
|
||||
Speech bool `json:"speech"`
|
||||
SpeechStream bool `json:"speech_stream"`
|
||||
Transcription bool `json:"transcription"`
|
||||
TranscriptionStream bool `json:"transcription_stream"`
|
||||
ImageGeneration bool `json:"image_generation"`
|
||||
ImageGenerationStream bool `json:"image_generation_stream"`
|
||||
ImageEdit bool `json:"image_edit"`
|
||||
ImageEditStream bool `json:"image_edit_stream"`
|
||||
ImageVariation bool `json:"image_variation"`
|
||||
VideoGeneration bool `json:"video_generation"`
|
||||
VideoRetrieve bool `json:"video_retrieve"`
|
||||
VideoDownload bool `json:"video_download"`
|
||||
VideoDelete bool `json:"video_delete"`
|
||||
VideoList bool `json:"video_list"`
|
||||
VideoRemix bool `json:"video_remix"`
|
||||
BatchCreate bool `json:"batch_create"`
|
||||
BatchList bool `json:"batch_list"`
|
||||
BatchRetrieve bool `json:"batch_retrieve"`
|
||||
BatchCancel bool `json:"batch_cancel"`
|
||||
BatchDelete bool `json:"batch_delete"`
|
||||
BatchResults bool `json:"batch_results"`
|
||||
FileUpload bool `json:"file_upload"`
|
||||
FileList bool `json:"file_list"`
|
||||
FileRetrieve bool `json:"file_retrieve"`
|
||||
FileDelete bool `json:"file_delete"`
|
||||
FileContent bool `json:"file_content"`
|
||||
ContainerCreate bool `json:"container_create"`
|
||||
ContainerList bool `json:"container_list"`
|
||||
ContainerRetrieve bool `json:"container_retrieve"`
|
||||
ContainerDelete bool `json:"container_delete"`
|
||||
ContainerFileCreate bool `json:"container_file_create"`
|
||||
ContainerFileList bool `json:"container_file_list"`
|
||||
ContainerFileRetrieve bool `json:"container_file_retrieve"`
|
||||
ContainerFileContent bool `json:"container_file_content"`
|
||||
ContainerFileDelete bool `json:"container_file_delete"`
|
||||
Passthrough bool `json:"passthrough"`
|
||||
PassthroughStream bool `json:"passthrough_stream"`
|
||||
WebSocketResponses bool `json:"websocket_responses"`
|
||||
Realtime bool `json:"realtime"`
|
||||
}
|
||||
|
||||
// IsOperationAllowed checks if a specific operation is allowed
|
||||
func (ar *AllowedRequests) IsOperationAllowed(operation RequestType) bool {
|
||||
if ar == nil {
|
||||
return true // Default to allowed if no restrictions
|
||||
}
|
||||
|
||||
switch operation {
|
||||
case ListModelsRequest:
|
||||
return ar.ListModels
|
||||
case TextCompletionRequest:
|
||||
return ar.TextCompletion
|
||||
case TextCompletionStreamRequest:
|
||||
return ar.TextCompletionStream
|
||||
case ChatCompletionRequest:
|
||||
return ar.ChatCompletion
|
||||
case ChatCompletionStreamRequest:
|
||||
return ar.ChatCompletionStream
|
||||
case ResponsesRequest:
|
||||
return ar.Responses
|
||||
case ResponsesStreamRequest:
|
||||
return ar.ResponsesStream
|
||||
case CountTokensRequest:
|
||||
return ar.CountTokens
|
||||
case EmbeddingRequest:
|
||||
return ar.Embedding
|
||||
case RerankRequest:
|
||||
return ar.Rerank
|
||||
case OCRRequest:
|
||||
return ar.OCR
|
||||
case SpeechRequest:
|
||||
return ar.Speech
|
||||
case SpeechStreamRequest:
|
||||
return ar.SpeechStream
|
||||
case TranscriptionRequest:
|
||||
return ar.Transcription
|
||||
case TranscriptionStreamRequest:
|
||||
return ar.TranscriptionStream
|
||||
case ImageGenerationRequest:
|
||||
return ar.ImageGeneration
|
||||
case ImageGenerationStreamRequest:
|
||||
return ar.ImageGenerationStream
|
||||
case ImageEditRequest:
|
||||
return ar.ImageEdit
|
||||
case ImageEditStreamRequest:
|
||||
return ar.ImageEditStream
|
||||
case ImageVariationRequest:
|
||||
return ar.ImageVariation
|
||||
case VideoGenerationRequest:
|
||||
return ar.VideoGeneration
|
||||
case VideoRetrieveRequest:
|
||||
return ar.VideoRetrieve
|
||||
case VideoDownloadRequest:
|
||||
return ar.VideoDownload
|
||||
case VideoDeleteRequest:
|
||||
return ar.VideoDelete
|
||||
case VideoListRequest:
|
||||
return ar.VideoList
|
||||
case VideoRemixRequest:
|
||||
return ar.VideoRemix
|
||||
case BatchCreateRequest:
|
||||
return ar.BatchCreate
|
||||
case BatchListRequest:
|
||||
return ar.BatchList
|
||||
case BatchRetrieveRequest:
|
||||
return ar.BatchRetrieve
|
||||
case BatchCancelRequest:
|
||||
return ar.BatchCancel
|
||||
case BatchDeleteRequest:
|
||||
return ar.BatchDelete
|
||||
case BatchResultsRequest:
|
||||
return ar.BatchResults
|
||||
case FileUploadRequest:
|
||||
return ar.FileUpload
|
||||
case FileListRequest:
|
||||
return ar.FileList
|
||||
case FileRetrieveRequest:
|
||||
return ar.FileRetrieve
|
||||
case FileDeleteRequest:
|
||||
return ar.FileDelete
|
||||
case FileContentRequest:
|
||||
return ar.FileContent
|
||||
case ContainerCreateRequest:
|
||||
return ar.ContainerCreate
|
||||
case ContainerListRequest:
|
||||
return ar.ContainerList
|
||||
case ContainerRetrieveRequest:
|
||||
return ar.ContainerRetrieve
|
||||
case ContainerDeleteRequest:
|
||||
return ar.ContainerDelete
|
||||
case ContainerFileCreateRequest:
|
||||
return ar.ContainerFileCreate
|
||||
case ContainerFileListRequest:
|
||||
return ar.ContainerFileList
|
||||
case ContainerFileRetrieveRequest:
|
||||
return ar.ContainerFileRetrieve
|
||||
case ContainerFileContentRequest:
|
||||
return ar.ContainerFileContent
|
||||
case ContainerFileDeleteRequest:
|
||||
return ar.ContainerFileDelete
|
||||
case PassthroughRequest:
|
||||
return ar.Passthrough
|
||||
case PassthroughStreamRequest:
|
||||
return ar.PassthroughStream
|
||||
case WebSocketResponsesRequest:
|
||||
return ar.WebSocketResponses
|
||||
case RealtimeRequest:
|
||||
return ar.Realtime
|
||||
default:
|
||||
return false // Default to not allowed for unknown operations
|
||||
}
|
||||
}
|
||||
|
||||
type CustomProviderConfig struct {
|
||||
CustomProviderKey string `json:"-"` // Custom provider key, internally set by Bifrost
|
||||
IsKeyLess bool `json:"is_key_less"` // Whether the custom provider requires a key (not allowed for Bedrock)
|
||||
BaseProviderType ModelProvider `json:"base_provider_type"` // Base provider type
|
||||
AllowedRequests *AllowedRequests `json:"allowed_requests,omitempty"` // Allowed requests for the custom provider
|
||||
RequestPathOverrides map[RequestType]string `json:"request_path_overrides,omitempty"` // Mapping of request type to its custom path which will override the default path of the provider (not allowed for Bedrock)
|
||||
}
|
||||
|
||||
// IsOperationAllowed checks if a specific operation is allowed for this custom provider
|
||||
func (cpc *CustomProviderConfig) IsOperationAllowed(operation RequestType) bool {
|
||||
if cpc == nil || cpc.AllowedRequests == nil {
|
||||
return true // Default to allowed if no restrictions
|
||||
}
|
||||
return cpc.AllowedRequests.IsOperationAllowed(operation)
|
||||
}
|
||||
|
||||
// ProviderConfig represents the complete configuration for a provider.
|
||||
// An array of ProviderConfig needs to be provided in GetConfigForProvider
|
||||
// in your account interface implementation.
|
||||
type ProviderConfig struct {
|
||||
NetworkConfig NetworkConfig `json:"network_config"` // Network configuration
|
||||
ConcurrencyAndBufferSize ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` // Concurrency settings
|
||||
// Logger instance, can be provided by the user or bifrost default logger is used if not provided
|
||||
Logger Logger `json:"-"`
|
||||
ProxyConfig *ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration
|
||||
SendBackRawRequest bool `json:"send_back_raw_request"` // Send raw request back in the bifrost response (default: false)
|
||||
SendBackRawResponse bool `json:"send_back_raw_response"` // Send raw response back in the bifrost response (default: false)
|
||||
StoreRawRequestResponse bool `json:"store_raw_request_response"` // Capture raw request/response for internal logging only; strip from API responses returned to clients (default: false)
|
||||
CustomProviderConfig *CustomProviderConfig `json:"custom_provider_config,omitempty"`
|
||||
OpenAIConfig *OpenAIConfig `json:"openai_config,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIConfig holds OpenAI-specific provider configuration.
|
||||
type OpenAIConfig struct {
|
||||
DisableStore bool `json:"disable_store"` // When true, forces store=false on all outgoing OpenAI requests (default: false)
|
||||
}
|
||||
|
||||
func (config *ProviderConfig) CheckAndSetDefaults() {
|
||||
if config.ConcurrencyAndBufferSize.Concurrency == 0 {
|
||||
config.ConcurrencyAndBufferSize.Concurrency = DefaultConcurrency
|
||||
}
|
||||
|
||||
if config.ConcurrencyAndBufferSize.BufferSize == 0 {
|
||||
config.ConcurrencyAndBufferSize.BufferSize = DefaultBufferSize
|
||||
}
|
||||
|
||||
if config.NetworkConfig.DefaultRequestTimeoutInSeconds <= 0 {
|
||||
config.NetworkConfig.DefaultRequestTimeoutInSeconds = DefaultRequestTimeoutInSeconds
|
||||
}
|
||||
|
||||
if config.NetworkConfig.MaxRetries == 0 {
|
||||
config.NetworkConfig.MaxRetries = DefaultMaxRetries
|
||||
}
|
||||
|
||||
if config.NetworkConfig.RetryBackoffInitial == 0 {
|
||||
config.NetworkConfig.RetryBackoffInitial = DefaultRetryBackoffInitial
|
||||
}
|
||||
|
||||
if config.NetworkConfig.RetryBackoffMax == 0 {
|
||||
config.NetworkConfig.RetryBackoffMax = DefaultRetryBackoffMax
|
||||
}
|
||||
|
||||
if config.NetworkConfig.StreamIdleTimeoutInSeconds <= 0 {
|
||||
config.NetworkConfig.StreamIdleTimeoutInSeconds = DefaultStreamIdleTimeoutInSeconds
|
||||
}
|
||||
|
||||
if config.NetworkConfig.MaxConnsPerHost <= 0 {
|
||||
config.NetworkConfig.MaxConnsPerHost = DefaultMaxConnsPerHost
|
||||
} else if config.NetworkConfig.MaxConnsPerHost > MaxConnsPerHostUpperBound {
|
||||
config.NetworkConfig.MaxConnsPerHost = MaxConnsPerHostUpperBound
|
||||
}
|
||||
|
||||
// Create a defensive copy of ExtraHeaders to prevent data races
|
||||
if config.NetworkConfig.ExtraHeaders != nil {
|
||||
headersCopy := make(map[string]string, len(config.NetworkConfig.ExtraHeaders))
|
||||
maps.Copy(headersCopy, config.NetworkConfig.ExtraHeaders)
|
||||
config.NetworkConfig.ExtraHeaders = headersCopy
|
||||
}
|
||||
|
||||
// Create a defensive copy of BetaHeaderOverrides to prevent data races
|
||||
if config.NetworkConfig.BetaHeaderOverrides != nil {
|
||||
overridesCopy := make(map[string]bool, len(config.NetworkConfig.BetaHeaderOverrides))
|
||||
maps.Copy(overridesCopy, config.NetworkConfig.BetaHeaderOverrides)
|
||||
config.NetworkConfig.BetaHeaderOverrides = overridesCopy
|
||||
}
|
||||
}
|
||||
|
||||
type PostHookRunner func(ctx *BifrostContext, result *BifrostResponse, err *BifrostError) (*BifrostResponse, *BifrostError)
|
||||
|
||||
// Provider defines the interface for AI model providers.
|
||||
type Provider interface {
|
||||
// GetProviderKey returns the provider's identifier
|
||||
GetProviderKey() ModelProvider
|
||||
// ListModels performs a list models request
|
||||
ListModels(ctx *BifrostContext, keys []Key, request *BifrostListModelsRequest) (*BifrostListModelsResponse, *BifrostError)
|
||||
// TextCompletion performs a text completion request
|
||||
TextCompletion(ctx *BifrostContext, key Key, request *BifrostTextCompletionRequest) (*BifrostTextCompletionResponse, *BifrostError)
|
||||
// TextCompletionStream performs a text completion stream request.
|
||||
// postHookSpanFinalizer is invoked by the provider's stream goroutine on stream completion
|
||||
// (or on its panic-recovery defer) to finalize aggregated post-hook spans and release the
|
||||
// per-attempt plugin pipeline. Pass nil if the caller does not need finalization.
|
||||
TextCompletionStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostTextCompletionRequest) (chan *BifrostStreamChunk, *BifrostError)
|
||||
// ChatCompletion performs a chat completion request
|
||||
ChatCompletion(ctx *BifrostContext, key Key, request *BifrostChatRequest) (*BifrostChatResponse, *BifrostError)
|
||||
// ChatCompletionStream performs a chat completion stream request
|
||||
ChatCompletionStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostChatRequest) (chan *BifrostStreamChunk, *BifrostError)
|
||||
// Responses performs a completion request using the Responses API (uses chat completion request internally for non-openai providers)
|
||||
Responses(ctx *BifrostContext, key Key, request *BifrostResponsesRequest) (*BifrostResponsesResponse, *BifrostError)
|
||||
// ResponsesStream performs a completion request using the Responses API stream (uses chat completion stream request internally for non-openai providers)
|
||||
ResponsesStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostResponsesRequest) (chan *BifrostStreamChunk, *BifrostError)
|
||||
// CountTokens performs a count tokens request
|
||||
CountTokens(ctx *BifrostContext, key Key, request *BifrostResponsesRequest) (*BifrostCountTokensResponse, *BifrostError)
|
||||
// Embedding performs an embedding request
|
||||
Embedding(ctx *BifrostContext, key Key, request *BifrostEmbeddingRequest) (*BifrostEmbeddingResponse, *BifrostError)
|
||||
// Rerank performs a rerank request to reorder documents by relevance to a query
|
||||
Rerank(ctx *BifrostContext, key Key, request *BifrostRerankRequest) (*BifrostRerankResponse, *BifrostError)
|
||||
// OCR performs an optical character recognition request on a document
|
||||
OCR(ctx *BifrostContext, key Key, request *BifrostOCRRequest) (*BifrostOCRResponse, *BifrostError)
|
||||
// Speech performs a text to speech request
|
||||
Speech(ctx *BifrostContext, key Key, request *BifrostSpeechRequest) (*BifrostSpeechResponse, *BifrostError)
|
||||
// SpeechStream performs a text to speech stream request
|
||||
SpeechStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostSpeechRequest) (chan *BifrostStreamChunk, *BifrostError)
|
||||
// Transcription performs a transcription request
|
||||
Transcription(ctx *BifrostContext, key Key, request *BifrostTranscriptionRequest) (*BifrostTranscriptionResponse, *BifrostError)
|
||||
// TranscriptionStream performs a transcription stream request
|
||||
TranscriptionStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostTranscriptionRequest) (chan *BifrostStreamChunk, *BifrostError)
|
||||
// ImageGeneration performs an image generation request
|
||||
ImageGeneration(ctx *BifrostContext, key Key, request *BifrostImageGenerationRequest) (
|
||||
*BifrostImageGenerationResponse, *BifrostError)
|
||||
// ImageGenerationStream performs an image generation stream request
|
||||
ImageGenerationStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key,
|
||||
request *BifrostImageGenerationRequest) (chan *BifrostStreamChunk, *BifrostError)
|
||||
// ImageEdit performs an image edit request
|
||||
ImageEdit(ctx *BifrostContext, key Key, request *BifrostImageEditRequest) (*BifrostImageGenerationResponse, *BifrostError)
|
||||
// ImageEditStream performs an image edit stream request
|
||||
ImageEditStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key,
|
||||
request *BifrostImageEditRequest) (chan *BifrostStreamChunk, *BifrostError)
|
||||
// ImageVariation performs an image variation request
|
||||
ImageVariation(ctx *BifrostContext, key Key, request *BifrostImageVariationRequest) (*BifrostImageGenerationResponse, *BifrostError)
|
||||
// VideoGeneration performs a video generation request
|
||||
VideoGeneration(ctx *BifrostContext, key Key, request *BifrostVideoGenerationRequest) (*BifrostVideoGenerationResponse, *BifrostError)
|
||||
// VideoRetrieve retrieves a video from the provider
|
||||
VideoRetrieve(ctx *BifrostContext, key Key, request *BifrostVideoRetrieveRequest) (*BifrostVideoGenerationResponse, *BifrostError)
|
||||
// VideoDownload downloads a video from the provider
|
||||
VideoDownload(ctx *BifrostContext, key Key, request *BifrostVideoDownloadRequest) (*BifrostVideoDownloadResponse, *BifrostError)
|
||||
// VideoDelete deletes a video from the provider
|
||||
VideoDelete(ctx *BifrostContext, key Key, request *BifrostVideoDeleteRequest) (*BifrostVideoDeleteResponse, *BifrostError)
|
||||
// VideoList lists videos from the provider
|
||||
VideoList(ctx *BifrostContext, key Key, request *BifrostVideoListRequest) (*BifrostVideoListResponse, *BifrostError)
|
||||
// VideoRemix remixes a video from the provider
|
||||
VideoRemix(ctx *BifrostContext, key Key, request *BifrostVideoRemixRequest) (*BifrostVideoGenerationResponse, *BifrostError)
|
||||
// BatchCreate creates a new batch job for asynchronous processing
|
||||
BatchCreate(ctx *BifrostContext, key Key, request *BifrostBatchCreateRequest) (*BifrostBatchCreateResponse, *BifrostError)
|
||||
// BatchList lists batch jobs
|
||||
BatchList(ctx *BifrostContext, keys []Key, request *BifrostBatchListRequest) (*BifrostBatchListResponse, *BifrostError)
|
||||
// BatchRetrieve retrieves a specific batch job
|
||||
BatchRetrieve(ctx *BifrostContext, keys []Key, request *BifrostBatchRetrieveRequest) (*BifrostBatchRetrieveResponse, *BifrostError)
|
||||
// BatchCancel cancels a batch job
|
||||
BatchCancel(ctx *BifrostContext, keys []Key, request *BifrostBatchCancelRequest) (*BifrostBatchCancelResponse, *BifrostError)
|
||||
// BatchDelete deletes a batch job
|
||||
BatchDelete(ctx *BifrostContext, keys []Key, request *BifrostBatchDeleteRequest) (*BifrostBatchDeleteResponse, *BifrostError)
|
||||
// BatchResults retrieves results from a completed batch job
|
||||
BatchResults(ctx *BifrostContext, keys []Key, request *BifrostBatchResultsRequest) (*BifrostBatchResultsResponse, *BifrostError)
|
||||
// FileUpload uploads a file to the provider
|
||||
FileUpload(ctx *BifrostContext, key Key, request *BifrostFileUploadRequest) (*BifrostFileUploadResponse, *BifrostError)
|
||||
// FileList lists files from the provider
|
||||
FileList(ctx *BifrostContext, keys []Key, request *BifrostFileListRequest) (*BifrostFileListResponse, *BifrostError)
|
||||
// FileRetrieve retrieves file metadata from the provider
|
||||
FileRetrieve(ctx *BifrostContext, keys []Key, request *BifrostFileRetrieveRequest) (*BifrostFileRetrieveResponse, *BifrostError)
|
||||
// FileDelete deletes a file from the provider
|
||||
FileDelete(ctx *BifrostContext, keys []Key, request *BifrostFileDeleteRequest) (*BifrostFileDeleteResponse, *BifrostError)
|
||||
// FileContent downloads file content from the provider
|
||||
FileContent(ctx *BifrostContext, keys []Key, request *BifrostFileContentRequest) (*BifrostFileContentResponse, *BifrostError)
|
||||
// ContainerCreate creates a new container
|
||||
ContainerCreate(ctx *BifrostContext, key Key, request *BifrostContainerCreateRequest) (*BifrostContainerCreateResponse, *BifrostError)
|
||||
// ContainerList lists containers
|
||||
ContainerList(ctx *BifrostContext, keys []Key, request *BifrostContainerListRequest) (*BifrostContainerListResponse, *BifrostError)
|
||||
// ContainerRetrieve retrieves a specific container
|
||||
ContainerRetrieve(ctx *BifrostContext, keys []Key, request *BifrostContainerRetrieveRequest) (*BifrostContainerRetrieveResponse, *BifrostError)
|
||||
// ContainerDelete deletes a container
|
||||
ContainerDelete(ctx *BifrostContext, keys []Key, request *BifrostContainerDeleteRequest) (*BifrostContainerDeleteResponse, *BifrostError)
|
||||
// ContainerFileCreate creates a file in a container
|
||||
ContainerFileCreate(ctx *BifrostContext, key Key, request *BifrostContainerFileCreateRequest) (*BifrostContainerFileCreateResponse, *BifrostError)
|
||||
// ContainerFileList lists files in a container
|
||||
ContainerFileList(ctx *BifrostContext, keys []Key, request *BifrostContainerFileListRequest) (*BifrostContainerFileListResponse, *BifrostError)
|
||||
// ContainerFileRetrieve retrieves a file from a container
|
||||
ContainerFileRetrieve(ctx *BifrostContext, keys []Key, request *BifrostContainerFileRetrieveRequest) (*BifrostContainerFileRetrieveResponse, *BifrostError)
|
||||
// ContainerFileContent retrieves the content of a file from a container
|
||||
ContainerFileContent(ctx *BifrostContext, keys []Key, request *BifrostContainerFileContentRequest) (*BifrostContainerFileContentResponse, *BifrostError)
|
||||
// ContainerFileDelete deletes a file from a container
|
||||
ContainerFileDelete(ctx *BifrostContext, keys []Key, request *BifrostContainerFileDeleteRequest) (*BifrostContainerFileDeleteResponse, *BifrostError)
|
||||
// Passthrough executes a non-streaming passthrough; body is fully buffered.
|
||||
Passthrough(ctx *BifrostContext, key Key, req *BifrostPassthroughRequest) (*BifrostPassthroughResponse, *BifrostError)
|
||||
// PassthroughStream executes a streaming passthrough, forwarding raw response bytes as BifrostStreamChunks.
|
||||
PassthroughStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, req *BifrostPassthroughRequest) (chan *BifrostStreamChunk, *BifrostError)
|
||||
}
|
||||
|
||||
// WebSocketCapableProvider is an optional interface that providers can implement
|
||||
// to indicate support for the OpenAI Responses API WebSocket Mode.
|
||||
// Checked via type assertion: provider.(WebSocketCapableProvider).
|
||||
// Providers that implement this interface will have native WS upstream connections
|
||||
// instead of the HTTP bridge fallback for Responses WS mode.
|
||||
type WebSocketCapableProvider interface {
|
||||
// SupportsWebSocketMode returns true if the provider supports the Responses API WebSocket Mode.
|
||||
SupportsWebSocketMode() bool
|
||||
// WebSocketResponsesURL returns the WebSocket URL for the Responses API.
|
||||
WebSocketResponsesURL(key Key) string
|
||||
// WebSocketHeaders returns the headers required for the upstream WebSocket connection.
|
||||
WebSocketHeaders(key Key) map[string]string
|
||||
}
|
||||
310
core/schemas/realtime.go
Normal file
310
core/schemas/realtime.go
Normal file
@@ -0,0 +1,310 @@
|
||||
package schemas
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// RealtimeEventType represents the type of a Bifrost unified Realtime event.
|
||||
type RealtimeEventType string
|
||||
|
||||
// Client-to-server event types (sent by the client through Bifrost)
|
||||
const (
|
||||
RTEventSessionUpdate RealtimeEventType = "session.update"
|
||||
RTEventConversationItemCreate RealtimeEventType = "conversation.item.create"
|
||||
RTEventConversationItemDelete RealtimeEventType = "conversation.item.delete"
|
||||
RTEventResponseCreate RealtimeEventType = "response.create"
|
||||
RTEventResponseCancel RealtimeEventType = "response.cancel"
|
||||
RTEventInputAudioAppend RealtimeEventType = "input_audio_buffer.append"
|
||||
RTEventInputAudioCommit RealtimeEventType = "input_audio_buffer.commit"
|
||||
RTEventInputAudioClear RealtimeEventType = "input_audio_buffer.clear"
|
||||
)
|
||||
|
||||
// Server-to-client event types (received from the provider, forwarded to client)
|
||||
const (
|
||||
RTEventSessionCreated RealtimeEventType = "session.created"
|
||||
RTEventSessionUpdated RealtimeEventType = "session.updated"
|
||||
RTEventConversationCreated RealtimeEventType = "conversation.created"
|
||||
RTEventConversationItemAdded RealtimeEventType = "conversation.item.added"
|
||||
RTEventConversationItemCreated RealtimeEventType = "conversation.item.created"
|
||||
RTEventConversationItemRetrieved RealtimeEventType = "conversation.item.retrieved"
|
||||
RTEventConversationItemDone RealtimeEventType = "conversation.item.done"
|
||||
RTEventResponseCreated RealtimeEventType = "response.created"
|
||||
RTEventResponseDone RealtimeEventType = "response.done"
|
||||
RTEventResponseTextDelta RealtimeEventType = "response.text.delta"
|
||||
RTEventResponseTextDone RealtimeEventType = "response.text.done"
|
||||
RTEventResponseAudioDelta RealtimeEventType = "response.audio.delta"
|
||||
RTEventResponseAudioDone RealtimeEventType = "response.audio.done"
|
||||
RTEventResponseAudioTransDelta RealtimeEventType = "response.audio_transcript.delta"
|
||||
RTEventResponseAudioTransDone RealtimeEventType = "response.audio_transcript.done"
|
||||
RTEventResponseOutputItemAdded RealtimeEventType = "response.output_item.added"
|
||||
RTEventResponseOutputItemDone RealtimeEventType = "response.output_item.done"
|
||||
RTEventResponseContentPartAdded RealtimeEventType = "response.content_part.added"
|
||||
RTEventResponseContentPartDone RealtimeEventType = "response.content_part.done"
|
||||
RTEventRateLimitsUpdated RealtimeEventType = "rate_limits.updated"
|
||||
RTEventInputAudioTransCompleted RealtimeEventType = "conversation.item.input_audio_transcription.completed"
|
||||
RTEventInputAudioTransDelta RealtimeEventType = "conversation.item.input_audio_transcription.delta"
|
||||
RTEventInputAudioTransFailed RealtimeEventType = "conversation.item.input_audio_transcription.failed"
|
||||
RTEventInputAudioBufferCommitted RealtimeEventType = "input_audio_buffer.committed"
|
||||
RTEventInputAudioBufferCleared RealtimeEventType = "input_audio_buffer.cleared"
|
||||
RTEventInputAudioSpeechStarted RealtimeEventType = "input_audio_buffer.speech_started"
|
||||
RTEventInputAudioSpeechStopped RealtimeEventType = "input_audio_buffer.speech_stopped"
|
||||
RTEventError RealtimeEventType = "error"
|
||||
)
|
||||
|
||||
// IsRealtimeConversationItemEventType reports whether the event carries a
|
||||
// canonical conversation item payload after provider translation.
|
||||
func IsRealtimeConversationItemEventType(eventType RealtimeEventType) bool {
|
||||
switch eventType {
|
||||
case RTEventConversationItemCreate,
|
||||
RTEventConversationItemAdded,
|
||||
RTEventConversationItemCreated,
|
||||
RTEventConversationItemRetrieved,
|
||||
RTEventConversationItemDone:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsRealtimeUserInputEvent reports whether the event represents a finalized
|
||||
// user input item in the canonical Bifrost realtime schema.
|
||||
func IsRealtimeUserInputEvent(event *BifrostRealtimeEvent) bool {
|
||||
return event != nil &&
|
||||
event.Item != nil &&
|
||||
event.Item.Role == "user" &&
|
||||
IsRealtimeConversationItemEventType(event.Type)
|
||||
}
|
||||
|
||||
// IsRealtimeToolOutputEvent reports whether the event represents a finalized
|
||||
// tool output item in the canonical Bifrost realtime schema.
|
||||
func IsRealtimeToolOutputEvent(event *BifrostRealtimeEvent) bool {
|
||||
return event != nil &&
|
||||
event.Item != nil &&
|
||||
event.Item.Type == "function_call_output" &&
|
||||
IsRealtimeConversationItemEventType(event.Type)
|
||||
}
|
||||
|
||||
// IsRealtimeInputTranscriptEvent reports whether the event carries a finalized
|
||||
// input-audio transcript in the canonical Bifrost realtime schema.
|
||||
func IsRealtimeInputTranscriptEvent(event *BifrostRealtimeEvent) bool {
|
||||
return event != nil && event.Type == RTEventInputAudioTransCompleted
|
||||
}
|
||||
|
||||
// BifrostRealtimeEvent is the unified Bifrost envelope for all Realtime events.
|
||||
// Provider converters translate between this format and the provider-native protocol.
|
||||
type BifrostRealtimeEvent struct {
|
||||
Type RealtimeEventType `json:"type"`
|
||||
EventID string `json:"event_id,omitempty"`
|
||||
|
||||
Session *RealtimeSession `json:"session,omitempty"`
|
||||
Item *RealtimeItem `json:"item,omitempty"`
|
||||
Delta *RealtimeDelta `json:"delta,omitempty"`
|
||||
Audio []byte `json:"audio,omitempty"`
|
||||
Error *RealtimeError `json:"error,omitempty"`
|
||||
|
||||
// ExtraParams preserves provider-specific top-level event fields that are not
|
||||
// promoted into the common Bifrost schema.
|
||||
ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"`
|
||||
|
||||
// RawData preserves the original provider event for pass-through or debugging.
|
||||
RawData json.RawMessage `json:"raw_data,omitempty"`
|
||||
}
|
||||
|
||||
// RealtimeSession describes session configuration for the Realtime connection.
|
||||
type RealtimeSession struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Modalities []string `json:"modalities,omitempty"`
|
||||
Instructions string `json:"instructions,omitempty"`
|
||||
Voice string `json:"voice,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
MaxOutputTokens json.RawMessage `json:"max_output_tokens,omitempty"`
|
||||
TurnDetection json.RawMessage `json:"turn_detection,omitempty"`
|
||||
InputAudioFormat string `json:"input_audio_format,omitempty"`
|
||||
OutputAudioType string `json:"output_audio_type,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"`
|
||||
ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"`
|
||||
}
|
||||
|
||||
// RealtimeItem represents a conversation item in the Realtime protocol.
|
||||
type RealtimeItem struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
CallID string `json:"call_id,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
Output string `json:"output,omitempty"`
|
||||
ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"`
|
||||
}
|
||||
|
||||
// RealtimeDelta carries incremental content for streaming events.
|
||||
type RealtimeDelta struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Audio string `json:"audio,omitempty"`
|
||||
Transcript string `json:"transcript,omitempty"`
|
||||
ItemID string `json:"item_id,omitempty"`
|
||||
OutputIdx *int `json:"output_index,omitempty"`
|
||||
ContentIdx *int `json:"content_index,omitempty"`
|
||||
ResponseID string `json:"response_id,omitempty"`
|
||||
}
|
||||
|
||||
// RealtimeError describes an error from the Realtime API.
|
||||
type RealtimeError struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Param string `json:"param,omitempty"`
|
||||
ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"`
|
||||
}
|
||||
|
||||
// RealtimeSessionEndpointType identifies the public ephemeral-token endpoint
|
||||
// shape the client called so providers can preserve versioned behavior.
|
||||
type RealtimeSessionEndpointType string
|
||||
|
||||
const (
|
||||
RealtimeSessionEndpointClientSecrets RealtimeSessionEndpointType = "client_secrets"
|
||||
RealtimeSessionEndpointSessions RealtimeSessionEndpointType = "sessions"
|
||||
)
|
||||
|
||||
// RealtimeSessionRoute describes a provider-registered public route for
|
||||
// ephemeral-token creation.
|
||||
type RealtimeSessionRoute struct {
|
||||
Path string
|
||||
EndpointType RealtimeSessionEndpointType
|
||||
DefaultProvider ModelProvider
|
||||
}
|
||||
|
||||
// RealtimeProvider is an optional interface that providers can implement to
|
||||
// indicate support for bidirectional Realtime API (audio/text streaming).
|
||||
// Checked via type assertion: provider.(RealtimeProvider).
|
||||
type RealtimeProvider interface {
|
||||
SupportsRealtimeAPI() bool
|
||||
RealtimeWebSocketURL(key Key, model string) string
|
||||
RealtimeHeaders(key Key) map[string]string
|
||||
// SupportsRealtimeWebRTC reports whether the provider supports WebRTC SDP exchange.
|
||||
SupportsRealtimeWebRTC() bool
|
||||
// ExchangeRealtimeWebRTCSDP performs the provider-specific SDP signaling exchange.
|
||||
// The provider owns the HTTP specifics (URL, headers, body format).
|
||||
// session may be nil if the signaling format doesn't include session config.
|
||||
ExchangeRealtimeWebRTCSDP(ctx *BifrostContext, key Key, model string, sdp string, session json.RawMessage) (string, *BifrostError)
|
||||
ToBifrostRealtimeEvent(providerEvent json.RawMessage) (*BifrostRealtimeEvent, error)
|
||||
ToProviderRealtimeEvent(bifrostEvent *BifrostRealtimeEvent) (json.RawMessage, error)
|
||||
// ShouldStartRealtimeTurn reports whether the canonical client-side event
|
||||
// should start pre-hooks. Providers without an explicit turn-start signal
|
||||
// return false and rely on finalize-time fallback hooks.
|
||||
ShouldStartRealtimeTurn(event *BifrostRealtimeEvent) bool
|
||||
// RealtimeTurnFinalEvent returns the canonical provider event that completes
|
||||
// a turn and should trigger post-hooks.
|
||||
RealtimeTurnFinalEvent() RealtimeEventType
|
||||
RealtimeWebRTCDataChannelLabel() string
|
||||
RealtimeWebSocketSubprotocol() string
|
||||
ShouldForwardRealtimeEvent(event *BifrostRealtimeEvent) bool
|
||||
ShouldAccumulateRealtimeOutput(eventType RealtimeEventType) bool
|
||||
}
|
||||
|
||||
// RealtimeLegacyWebRTCProvider is an optional interface for providers that
|
||||
// support the beta WebRTC handshake (e.g., OpenAI's /v1/realtime).
|
||||
// Only checked for legacy integration routes via type assertion.
|
||||
// Takes SDP offer + optional session JSON, same as ExchangeRealtimeWebRTCSDP
|
||||
// but targets the provider's legacy/beta endpoint.
|
||||
type RealtimeLegacyWebRTCProvider interface {
|
||||
ExchangeLegacyRealtimeWebRTCSDP(ctx *BifrostContext, key Key, sdp string, session json.RawMessage, model string) (string, *BifrostError)
|
||||
}
|
||||
|
||||
// RealtimeUsageExtractor lets providers parse terminal-turn usage/output from
|
||||
// their native wire payloads without coupling handlers to a specific protocol.
|
||||
type RealtimeUsageExtractor interface {
|
||||
ExtractRealtimeTurnUsage(terminalEventRaw []byte) *BifrostLLMUsage
|
||||
ExtractRealtimeTurnOutput(terminalEventRaw []byte) *ChatMessage
|
||||
}
|
||||
|
||||
// RealtimeSessionProvider is an optional interface for providers that can mint
|
||||
// short-lived client secrets for browser/client-side Realtime connections.
|
||||
// Checked via type assertion: provider.(RealtimeSessionProvider).
|
||||
type RealtimeSessionProvider interface {
|
||||
CreateRealtimeClientSecret(ctx *BifrostContext, key Key, endpointType RealtimeSessionEndpointType, rawRequest json.RawMessage) (*BifrostPassthroughResponse, *BifrostError)
|
||||
}
|
||||
|
||||
// ParseRealtimeEvent decodes a client/provider realtime event while preserving
|
||||
// unknown top-level fields in ExtraParams for provider-specific round-tripping.
|
||||
func ParseRealtimeEvent(raw []byte) (*BifrostRealtimeEvent, error) {
|
||||
type realtimeEventAlias struct {
|
||||
Type RealtimeEventType `json:"type"`
|
||||
EventID string `json:"event_id,omitempty"`
|
||||
Session *RealtimeSession `json:"session,omitempty"`
|
||||
Item *RealtimeItem `json:"item,omitempty"`
|
||||
Delta *RealtimeDelta `json:"delta,omitempty"`
|
||||
Audio []byte `json:"audio,omitempty"`
|
||||
Error *RealtimeError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
var alias realtimeEventAlias
|
||||
if err := Unmarshal(raw, &alias); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
event := &BifrostRealtimeEvent{
|
||||
Type: alias.Type,
|
||||
EventID: alias.EventID,
|
||||
Session: alias.Session,
|
||||
Item: alias.Item,
|
||||
Delta: alias.Delta,
|
||||
Audio: alias.Audio,
|
||||
Error: alias.Error,
|
||||
}
|
||||
|
||||
var root map[string]json.RawMessage
|
||||
if err := Unmarshal(raw, &root); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
savedSession := root["session"]
|
||||
savedItem := root["item"]
|
||||
savedError := root["error"]
|
||||
for _, key := range []string{"type", "event_id", "session", "item", "delta", "audio", "error", "raw_data"} {
|
||||
delete(root, key)
|
||||
}
|
||||
if len(root) > 0 {
|
||||
event.ExtraParams = root
|
||||
}
|
||||
if event.Session != nil {
|
||||
var sessionRoot map[string]json.RawMessage
|
||||
if len(savedSession) > 0 && Unmarshal(savedSession, &sessionRoot) == nil {
|
||||
for _, key := range []string{
|
||||
"id", "model", "modalities", "instructions", "voice", "temperature",
|
||||
"max_output_tokens", "turn_detection", "input_audio_format", "output_audio_type", "tools",
|
||||
} {
|
||||
delete(sessionRoot, key)
|
||||
}
|
||||
if len(sessionRoot) > 0 {
|
||||
event.Session.ExtraParams = sessionRoot
|
||||
}
|
||||
}
|
||||
}
|
||||
if event.Item != nil {
|
||||
var itemRoot map[string]json.RawMessage
|
||||
if len(savedItem) > 0 && Unmarshal(savedItem, &itemRoot) == nil {
|
||||
for _, key := range []string{
|
||||
"id", "type", "role", "status", "content", "name", "call_id", "arguments", "output",
|
||||
} {
|
||||
delete(itemRoot, key)
|
||||
}
|
||||
if len(itemRoot) > 0 {
|
||||
event.Item.ExtraParams = itemRoot
|
||||
}
|
||||
}
|
||||
}
|
||||
if event.Error != nil {
|
||||
var errorRoot map[string]json.RawMessage
|
||||
if len(savedError) > 0 && Unmarshal(savedError, &errorRoot) == nil {
|
||||
for _, key := range []string{"type", "code", "message", "param"} {
|
||||
delete(errorRoot, key)
|
||||
}
|
||||
if len(errorRoot) > 0 {
|
||||
event.Error.ExtraParams = errorRoot
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return event, nil
|
||||
}
|
||||
66
core/schemas/realtime_client_secrets.go
Normal file
66
core/schemas/realtime_client_secrets.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ParseRealtimeClientSecretBody parses a realtime client-secret request body
|
||||
// into a mutable raw JSON map while preserving unknown fields.
|
||||
func ParseRealtimeClientSecretBody(raw json.RawMessage) (map[string]json.RawMessage, *BifrostError) {
|
||||
var root map[string]json.RawMessage
|
||||
if err := Unmarshal(raw, &root); err != nil {
|
||||
return nil, NewRealtimeClientSecretBodyError(400, "invalid_request_error", "invalid JSON body", err)
|
||||
}
|
||||
return root, nil
|
||||
}
|
||||
|
||||
// ExtractRealtimeClientSecretModel extracts the model from either session.model
|
||||
// or the legacy top-level model field.
|
||||
func ExtractRealtimeClientSecretModel(root map[string]json.RawMessage) (string, *BifrostError) {
|
||||
if sessionJSON, ok := root["session"]; ok && len(sessionJSON) > 0 && !bytes.Equal(sessionJSON, []byte("null")) {
|
||||
var session map[string]json.RawMessage
|
||||
if err := Unmarshal(sessionJSON, &session); err != nil {
|
||||
return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "session must be an object", err)
|
||||
}
|
||||
if modelJSON, ok := session["model"]; ok {
|
||||
var sessionModel string
|
||||
if err := Unmarshal(modelJSON, &sessionModel); err != nil {
|
||||
return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "session.model must be a string", err)
|
||||
}
|
||||
if strings.TrimSpace(sessionModel) != "" {
|
||||
return strings.TrimSpace(sessionModel), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if modelJSON, ok := root["model"]; ok {
|
||||
var model string
|
||||
if err := Unmarshal(modelJSON, &model); err != nil {
|
||||
return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "model must be a string", err)
|
||||
}
|
||||
if strings.TrimSpace(model) != "" {
|
||||
return strings.TrimSpace(model), nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "session.model or model is required", nil)
|
||||
}
|
||||
|
||||
// NewRealtimeClientSecretBodyError builds a standard invalid-request style error
|
||||
// for HTTP realtime client-secret request parsing/validation.
|
||||
func NewRealtimeClientSecretBodyError(status int, errorType, message string, err error) *BifrostError {
|
||||
return &BifrostError{
|
||||
IsBifrostError: false,
|
||||
StatusCode: Ptr(status),
|
||||
Error: &ErrorField{
|
||||
Type: Ptr(errorType),
|
||||
Message: message,
|
||||
Error: err,
|
||||
},
|
||||
ExtraFields: BifrostErrorExtraFields{
|
||||
RequestType: RealtimeRequest,
|
||||
},
|
||||
}
|
||||
}
|
||||
40
core/schemas/realtime_client_secrets_test.go
Normal file
40
core/schemas/realtime_client_secrets_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractRealtimeClientSecretModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
root, err := ParseRealtimeClientSecretBody(json.RawMessage(`{"session":{"model":"openai/gpt-4o-realtime-preview"}}`))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseRealtimeClientSecretBody() error = %v", err)
|
||||
}
|
||||
|
||||
model, err := ExtractRealtimeClientSecretModel(root)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractRealtimeClientSecretModel() error = %v", err)
|
||||
}
|
||||
if model != "openai/gpt-4o-realtime-preview" {
|
||||
t.Fatalf("model = %q, want %q", model, "openai/gpt-4o-realtime-preview")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRealtimeClientSecretModelFallbackTopLevel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
root, err := ParseRealtimeClientSecretBody(json.RawMessage(`{"model":"gpt-4o-realtime-preview"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseRealtimeClientSecretBody() error = %v", err)
|
||||
}
|
||||
|
||||
model, err := ExtractRealtimeClientSecretModel(root)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractRealtimeClientSecretModel() error = %v", err)
|
||||
}
|
||||
if model != "gpt-4o-realtime-preview" {
|
||||
t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview")
|
||||
}
|
||||
}
|
||||
68
core/schemas/realtime_test.go
Normal file
68
core/schemas/realtime_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package schemas
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIsRealtimeConversationItemEventType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
eventType RealtimeEventType
|
||||
want bool
|
||||
}{
|
||||
{name: "create", eventType: RTEventConversationItemCreate, want: true},
|
||||
{name: "added", eventType: RTEventConversationItemAdded, want: true},
|
||||
{name: "created", eventType: RTEventConversationItemCreated, want: true},
|
||||
{name: "retrieved", eventType: RTEventConversationItemRetrieved, want: true},
|
||||
{name: "done", eventType: RTEventConversationItemDone, want: true},
|
||||
{name: "response done", eventType: RTEventResponseDone, want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := IsRealtimeConversationItemEventType(tt.eventType); got != tt.want {
|
||||
t.Fatalf("IsRealtimeConversationItemEventType(%q) = %v, want %v", tt.eventType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealtimeCanonicalEventClassifiers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
userEvent := &BifrostRealtimeEvent{
|
||||
Type: RTEventConversationItemAdded,
|
||||
Item: &RealtimeItem{
|
||||
Role: "user",
|
||||
Type: "message",
|
||||
},
|
||||
}
|
||||
if !IsRealtimeUserInputEvent(userEvent) {
|
||||
t.Fatal("expected conversation.item.added user event to be classified as realtime user input")
|
||||
}
|
||||
if IsRealtimeToolOutputEvent(userEvent) {
|
||||
t.Fatal("did not expect conversation.item.added user event to be classified as realtime tool output")
|
||||
}
|
||||
|
||||
toolEvent := &BifrostRealtimeEvent{
|
||||
Type: RTEventConversationItemRetrieved,
|
||||
Item: &RealtimeItem{
|
||||
Type: "function_call_output",
|
||||
},
|
||||
}
|
||||
if !IsRealtimeToolOutputEvent(toolEvent) {
|
||||
t.Fatal("expected function_call_output item to be classified as realtime tool output")
|
||||
}
|
||||
if IsRealtimeUserInputEvent(toolEvent) {
|
||||
t.Fatal("did not expect function_call_output item to be classified as realtime user input")
|
||||
}
|
||||
|
||||
transcriptEvent := &BifrostRealtimeEvent{Type: RTEventInputAudioTransCompleted}
|
||||
if !IsRealtimeInputTranscriptEvent(transcriptEvent) {
|
||||
t.Fatal("expected input audio transcription completion to be classified as transcript event")
|
||||
}
|
||||
if IsRealtimeInputTranscriptEvent(&BifrostRealtimeEvent{Type: RTEventInputAudioTransDelta}) {
|
||||
t.Fatal("did not expect input audio transcription delta to be classified as transcript event")
|
||||
}
|
||||
}
|
||||
49
core/schemas/rerank.go
Normal file
49
core/schemas/rerank.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package schemas
|
||||
|
||||
// RerankDocument represents a document to be reranked.
|
||||
type RerankDocument struct {
|
||||
Text string `json:"text"`
|
||||
ID *string `json:"id,omitempty"`
|
||||
Meta map[string]interface{} `json:"meta,omitempty"`
|
||||
}
|
||||
|
||||
// RerankParameters contains optional parameters for a rerank request.
|
||||
type RerankParameters struct {
|
||||
TopN *int `json:"top_n,omitempty"`
|
||||
MaxTokensPerDoc *int `json:"max_tokens_per_doc,omitempty"`
|
||||
Priority *int `json:"priority,omitempty"`
|
||||
ReturnDocuments *bool `json:"return_documents,omitempty"`
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// BifrostRerankRequest represents a request to rerank documents by relevance to a query.
|
||||
type BifrostRerankRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Query string `json:"query"`
|
||||
Documents []RerankDocument `json:"documents"`
|
||||
Params *RerankParameters `json:"params,omitempty"`
|
||||
Fallbacks []Fallback `json:"fallbacks,omitempty"`
|
||||
RawRequestBody []byte `json:"-"`
|
||||
}
|
||||
|
||||
// GetRawRequestBody returns the raw request body for the rerank request.
|
||||
func (r *BifrostRerankRequest) GetRawRequestBody() []byte {
|
||||
return r.RawRequestBody
|
||||
}
|
||||
|
||||
// RerankResult represents a single reranked document with its relevance score.
|
||||
type RerankResult struct {
|
||||
Index int `json:"index"`
|
||||
RelevanceScore float64 `json:"relevance_score"`
|
||||
Document *RerankDocument `json:"document,omitempty"`
|
||||
}
|
||||
|
||||
// BifrostRerankResponse represents the response from a rerank request.
|
||||
type BifrostRerankResponse struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Results []RerankResult `json:"results"`
|
||||
Model string `json:"model"`
|
||||
Usage *BifrostLLMUsage `json:"usage,omitempty"`
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
2345
core/schemas/responses.go
Normal file
2345
core/schemas/responses.go
Normal file
File diff suppressed because it is too large
Load Diff
1356
core/schemas/serialization_test.go
Normal file
1356
core/schemas/serialization_test.go
Normal file
File diff suppressed because it is too large
Load Diff
169
core/schemas/speech.go
Normal file
169
core/schemas/speech.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
type BifrostSpeechRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Input *SpeechInput `json:"input,omitempty"`
|
||||
Params *SpeechParameters `json:"params,omitempty"`
|
||||
Fallbacks []Fallback `json:"fallbacks,omitempty"`
|
||||
RawRequestBody []byte `json:"-"` // set bifrost-use-raw-request-body to true in ctx to use the raw request body. Bifrost will directly send this to the downstream provider.
|
||||
}
|
||||
|
||||
func (r *BifrostSpeechRequest) GetRawRequestBody() []byte {
|
||||
return r.RawRequestBody
|
||||
}
|
||||
|
||||
type BifrostSpeechResponse struct {
|
||||
Audio []byte `json:"audio"`
|
||||
Usage *SpeechUsage `json:"usage"`
|
||||
Alignment *SpeechAlignment `json:"alignment,omitempty"` // Character-level timing information
|
||||
NormalizedAlignment *SpeechAlignment `json:"normalized_alignment,omitempty"` // Character-level timing information for normalized text
|
||||
AudioBase64 *string `json:"audio_base64,omitempty"` // Base64-encoded audio (when timestamps are requested)
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
func (r *BifrostSpeechResponse) BackfillParams(request *BifrostSpeechRequest) {
|
||||
if r == nil || request == nil || request.Input == nil {
|
||||
return
|
||||
}
|
||||
if r.Usage == nil {
|
||||
r.Usage = &SpeechUsage{}
|
||||
}
|
||||
r.Usage.InputChars = utf8.RuneCountInString(request.Input.Input)
|
||||
}
|
||||
|
||||
// SpeechAlignment represents character-level timing information for audio-text synchronization
|
||||
type SpeechAlignment struct {
|
||||
CharStartTimesMs []float64 `json:"char_start_times_ms"` // Start time in milliseconds for each character
|
||||
CharEndTimesMs []float64 `json:"char_end_times_ms"` // End time in milliseconds for each character
|
||||
Characters []string `json:"characters"` // Characters corresponding to timing info
|
||||
}
|
||||
|
||||
// SpeechInput represents the input for a speech request.
|
||||
type SpeechInput struct {
|
||||
Input string `json:"input"`
|
||||
}
|
||||
|
||||
type SpeechParameters struct {
|
||||
VoiceConfig *SpeechVoiceInput `json:"voice"`
|
||||
Instructions string `json:"instructions,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"` // Default is "mp3"
|
||||
Speed *float64 `json:"speed,omitempty"`
|
||||
|
||||
LanguageCode *string `json:"language_code,omitempty"`
|
||||
PronunciationDictionaryLocators []SpeechPronunciationDictionaryLocator `json:"pronunciation_dictionary_locators,omitempty"`
|
||||
EnableLogging *bool `json:"enable_logging,omitempty"`
|
||||
OptimizeStreamingLatency *bool `json:"optimize_streaming_latency,omitempty"`
|
||||
WithTimestamps *bool `json:"with_timestamps,omitempty"` // Returns character-level timing information
|
||||
|
||||
// Dynamic parameters that can be provider-specific, they are directly
|
||||
// added to the request as is.
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
type SpeechPronunciationDictionaryLocator struct {
|
||||
PronunciationDictionaryID string `json:"pronunciation_dictionary_id"`
|
||||
VersionID *string `json:"version_id,omitempty"`
|
||||
}
|
||||
|
||||
type SpeechVoiceInput struct {
|
||||
Voice *string
|
||||
MultiVoiceConfig []VoiceConfig
|
||||
}
|
||||
|
||||
type VoiceConfig struct {
|
||||
Speaker string `json:"speaker"`
|
||||
Voice string `json:"voice"`
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshalling for SpeechVoiceInput.
|
||||
// It marshals either Voice or MultiVoiceConfig directly without wrapping.
|
||||
func (vi *SpeechVoiceInput) MarshalJSON() ([]byte, error) {
|
||||
// Validation: ensure only one field is set at a time
|
||||
if vi.Voice != nil && len(vi.MultiVoiceConfig) > 0 {
|
||||
return nil, fmt.Errorf("both Voice and MultiVoiceConfig are set; only one should be non-nil")
|
||||
}
|
||||
|
||||
if vi.Voice != nil {
|
||||
return MarshalSorted(*vi.Voice)
|
||||
}
|
||||
if len(vi.MultiVoiceConfig) > 0 {
|
||||
return MarshalSorted(vi.MultiVoiceConfig)
|
||||
}
|
||||
// If both are nil, return null
|
||||
return MarshalSorted(nil)
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshalling for SpeechVoiceInput.
|
||||
// It determines whether "voice" is a string or a VoiceConfig object/array and assigns to the appropriate field.
|
||||
// It also handles direct string/array content without a wrapper object.
|
||||
func (vi *SpeechVoiceInput) UnmarshalJSON(data []byte) error {
|
||||
// Reset receiver state before attempting any decode to avoid stale data
|
||||
vi.Voice = nil
|
||||
vi.MultiVoiceConfig = nil
|
||||
|
||||
// First, try to unmarshal as a direct string
|
||||
var stringContent string
|
||||
if err := Unmarshal(data, &stringContent); err == nil {
|
||||
vi.Voice = &stringContent
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as an array of VoiceConfig objects
|
||||
var voiceConfigs []VoiceConfig
|
||||
if err := Unmarshal(data, &voiceConfigs); err == nil {
|
||||
// Validate each VoiceConfig and build a new slice deterministically
|
||||
validConfigs := make([]VoiceConfig, 0, len(voiceConfigs))
|
||||
for _, config := range voiceConfigs {
|
||||
if config.Voice == "" {
|
||||
return fmt.Errorf("voice config has empty voice field")
|
||||
}
|
||||
validConfigs = append(validConfigs, config)
|
||||
}
|
||||
vi.MultiVoiceConfig = validConfigs
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("voice field is neither a string, nor an array of VoiceConfig objects")
|
||||
}
|
||||
|
||||
type SpeechStreamResponseType string
|
||||
|
||||
const (
|
||||
SpeechStreamResponseTypeDelta SpeechStreamResponseType = "speech.audio.delta"
|
||||
SpeechStreamResponseTypeDone SpeechStreamResponseType = "speech.audio.done"
|
||||
)
|
||||
|
||||
type BifrostSpeechStreamResponse struct {
|
||||
Type SpeechStreamResponseType `json:"type"`
|
||||
Audio []byte `json:"audio"`
|
||||
Usage *SpeechUsage `json:"usage"`
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
func (r *BifrostSpeechStreamResponse) BackfillParams(request *BifrostSpeechRequest) {
|
||||
if r == nil || request == nil || request.Input == nil {
|
||||
return
|
||||
}
|
||||
if r.Usage == nil {
|
||||
r.Usage = &SpeechUsage{}
|
||||
}
|
||||
r.Usage.InputChars = utf8.RuneCountInString(request.Input.Input)
|
||||
}
|
||||
|
||||
type SpeechUsageInputTokenDetails struct {
|
||||
TextTokens int `json:"text_tokens,omitempty"`
|
||||
AudioTokens int `json:"audio_tokens,omitempty"`
|
||||
}
|
||||
type SpeechUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
InputChars int `json:"input_chars,omitempty"`
|
||||
InputTokenDetails *SpeechUsageInputTokenDetails `json:"input_token_details,omitempty"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
146
core/schemas/textcompletions.go
Normal file
146
core/schemas/textcompletions.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// BifrostTextCompletionRequest is the request struct for text completion requests
|
||||
type BifrostTextCompletionRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Input *TextCompletionInput `json:"input,omitempty"`
|
||||
Params *TextCompletionParameters `json:"params,omitempty"`
|
||||
Fallbacks []Fallback `json:"fallbacks,omitempty"`
|
||||
RawRequestBody []byte `json:"-"` // set bifrost-use-raw-request-body to true in ctx to use the raw request body. Bifrost will directly send this to the downstream provider.
|
||||
}
|
||||
|
||||
func (r *BifrostTextCompletionRequest) GetRawRequestBody() []byte {
|
||||
return r.RawRequestBody
|
||||
}
|
||||
|
||||
// ToBifrostChatRequest converts a Bifrost text completion request to a Bifrost chat completion request
|
||||
// This method is discouraged to use, but is useful for litellm fallback flows
|
||||
func (r *BifrostTextCompletionRequest) ToBifrostChatRequest() *BifrostChatRequest {
|
||||
if r == nil || r.Input == nil {
|
||||
return nil
|
||||
}
|
||||
message := ChatMessage{Role: ChatMessageRoleUser}
|
||||
if r.Input.PromptStr != nil {
|
||||
message.Content = &ChatMessageContent{
|
||||
ContentStr: r.Input.PromptStr,
|
||||
}
|
||||
} else if len(r.Input.PromptArray) > 0 {
|
||||
blocks := make([]ChatContentBlock, 0, len(r.Input.PromptArray))
|
||||
for _, prompt := range r.Input.PromptArray {
|
||||
blocks = append(blocks, ChatContentBlock{
|
||||
Type: ChatContentBlockTypeText,
|
||||
Text: &prompt,
|
||||
})
|
||||
}
|
||||
message.Content = &ChatMessageContent{
|
||||
ContentBlocks: blocks,
|
||||
}
|
||||
}
|
||||
params := ChatParameters{}
|
||||
if r.Params != nil {
|
||||
params.MaxCompletionTokens = r.Params.MaxTokens
|
||||
params.Temperature = r.Params.Temperature
|
||||
params.TopP = r.Params.TopP
|
||||
params.Stop = r.Params.Stop
|
||||
params.ExtraParams = r.Params.ExtraParams
|
||||
params.StreamOptions = r.Params.StreamOptions
|
||||
params.User = r.Params.User
|
||||
params.FrequencyPenalty = r.Params.FrequencyPenalty
|
||||
params.LogitBias = r.Params.LogitBias
|
||||
params.PresencePenalty = r.Params.PresencePenalty
|
||||
params.Seed = r.Params.Seed
|
||||
}
|
||||
return &BifrostChatRequest{
|
||||
Provider: r.Provider,
|
||||
Model: r.Model,
|
||||
Fallbacks: r.Fallbacks,
|
||||
Input: []ChatMessage{message},
|
||||
Params: ¶ms,
|
||||
}
|
||||
}
|
||||
|
||||
type BifrostTextCompletionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Choices []BifrostResponseChoice `json:"choices"`
|
||||
Model string `json:"model"`
|
||||
Object string `json:"object"` // "text_completion" (same for text completion stream)
|
||||
SystemFingerprint string `json:"system_fingerprint"`
|
||||
Usage *BifrostLLMUsage `json:"usage"`
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
type TextCompletionInput struct {
|
||||
PromptStr *string
|
||||
PromptArray []string
|
||||
}
|
||||
|
||||
func (t *TextCompletionInput) MarshalJSON() ([]byte, error) {
|
||||
set := 0
|
||||
if t.PromptStr != nil {
|
||||
set++
|
||||
}
|
||||
if t.PromptArray != nil {
|
||||
set++
|
||||
}
|
||||
if set == 0 {
|
||||
return nil, fmt.Errorf("text completion input is empty")
|
||||
}
|
||||
if set > 1 {
|
||||
return nil, fmt.Errorf("text completion input must set exactly one of: prompt_str or prompt_array")
|
||||
}
|
||||
if t.PromptStr != nil {
|
||||
return MarshalSorted(*t.PromptStr)
|
||||
}
|
||||
return MarshalSorted(t.PromptArray)
|
||||
}
|
||||
|
||||
func (t *TextCompletionInput) UnmarshalJSON(data []byte) error {
|
||||
var prompt string
|
||||
if err := Unmarshal(data, &prompt); err == nil {
|
||||
t.PromptStr = &prompt
|
||||
t.PromptArray = nil
|
||||
return nil
|
||||
}
|
||||
var promptArray []string
|
||||
if err := Unmarshal(data, &promptArray); err == nil {
|
||||
t.PromptStr = nil
|
||||
t.PromptArray = promptArray
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("invalid text completion input")
|
||||
}
|
||||
|
||||
type TextCompletionParameters struct {
|
||||
BestOf *int `json:"best_of,omitempty"`
|
||||
Echo *bool `json:"echo,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
LogitBias *map[string]float64 `json:"logit_bias,omitempty"`
|
||||
LogProbs *int `json:"logprobs,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
N *int `json:"n,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
Suffix *string `json:"suffix,omitempty"`
|
||||
StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
User *string `json:"user,omitempty"`
|
||||
|
||||
// Dynamic parameters that can be provider-specific, they are directly
|
||||
// added to the request as is.
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// TextCompletionLogProb represents log probability information for text completion.
|
||||
type TextCompletionLogProb struct {
|
||||
TextOffset []int `json:"text_offset"`
|
||||
TokenLogProbs []float64 `json:"token_logprobs"`
|
||||
Tokens []string `json:"tokens"`
|
||||
TopLogProbs []map[string]float64 `json:"top_logprobs"`
|
||||
}
|
||||
390
core/schemas/trace.go
Normal file
390
core/schemas/trace.go
Normal file
@@ -0,0 +1,390 @@
|
||||
// Package schemas defines the core schemas and types used by the Bifrost system.
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Trace represents a distributed trace that captures the full lifecycle of a request
|
||||
type Trace struct {
|
||||
RequestID string // Request ID for the trace
|
||||
TraceID string // Unique identifier for this trace
|
||||
ParentID string // Parent trace ID from incoming W3C traceparent header
|
||||
RootSpan *Span // The root span of this trace
|
||||
Spans []*Span // All spans in this trace
|
||||
StartTime time.Time // When the trace started
|
||||
EndTime time.Time // When the trace completed
|
||||
Attributes map[string]any // Additional attributes for the trace
|
||||
PluginLogs []PluginLogEntry // Plugin log entries accumulated during request processing
|
||||
mu sync.Mutex // Mutex for thread-safe span operations
|
||||
}
|
||||
|
||||
// AddSpan adds a span to the trace in a thread-safe manner
|
||||
func (t *Trace) AddSpan(span *Span) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.Spans = append(t.Spans, span)
|
||||
}
|
||||
|
||||
// GetSpan retrieves a span by ID
|
||||
func (t *Trace) GetSpan(spanID string) *Span {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
for _, span := range t.Spans {
|
||||
if span.SpanID == spanID {
|
||||
return span
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRequestID retrieves the request ID from the trace
|
||||
func (t *Trace) GetRequestID() string {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.RequestID
|
||||
}
|
||||
|
||||
// SetRequestID sets the request ID for the trace
|
||||
func (t *Trace) SetRequestID(requestID string) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.RequestID = requestID
|
||||
}
|
||||
|
||||
// Reset clears the trace for reuse from pool
|
||||
func (t *Trace) Reset() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.RequestID = ""
|
||||
t.TraceID = ""
|
||||
t.ParentID = ""
|
||||
t.RootSpan = nil
|
||||
for i := range t.Spans {
|
||||
t.Spans[i] = nil
|
||||
}
|
||||
t.Spans = t.Spans[:0]
|
||||
t.StartTime = time.Time{}
|
||||
t.EndTime = time.Time{}
|
||||
t.Attributes = nil
|
||||
for i := range t.PluginLogs {
|
||||
t.PluginLogs[i] = PluginLogEntry{}
|
||||
}
|
||||
t.PluginLogs = t.PluginLogs[:0]
|
||||
}
|
||||
|
||||
// AppendPluginLogs appends plugin log entries to the trace in a thread-safe manner.
|
||||
func (t *Trace) AppendPluginLogs(logs []PluginLogEntry) {
|
||||
if len(logs) == 0 {
|
||||
return
|
||||
}
|
||||
t.mu.Lock()
|
||||
t.PluginLogs = append(t.PluginLogs, logs...)
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// Span represents a single operation within a trace
|
||||
type Span struct {
|
||||
SpanID string // Unique identifier for this span
|
||||
ParentID string // Parent span ID (empty for root span)
|
||||
TraceID string // The trace this span belongs to
|
||||
Name string // Name of the operation
|
||||
Kind SpanKind // Type of span (LLM call, plugin, etc.)
|
||||
StartTime time.Time // When the span started
|
||||
EndTime time.Time // When the span completed
|
||||
Status SpanStatus // Status of the operation
|
||||
StatusMsg string // Optional status message (for errors)
|
||||
Attributes map[string]any // Additional attributes for the span
|
||||
Events []SpanEvent // Events that occurred during the span
|
||||
mu sync.Mutex // Mutex for thread-safe attribute operations
|
||||
}
|
||||
|
||||
// SetAttribute sets an attribute on the span in a thread-safe manner
|
||||
func (s *Span) SetAttribute(key string, value any) {
|
||||
if value == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.Attributes == nil {
|
||||
s.Attributes = make(map[string]any)
|
||||
}
|
||||
s.Attributes[key] = value
|
||||
}
|
||||
|
||||
// AddEvent adds an event to the span in a thread-safe manner
|
||||
func (s *Span) AddEvent(event SpanEvent) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.Events = append(s.Events, event)
|
||||
}
|
||||
|
||||
// End marks the span as complete with the given status
|
||||
func (s *Span) End(status SpanStatus, statusMsg string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.EndTime = time.Now()
|
||||
s.Status = status
|
||||
s.StatusMsg = statusMsg
|
||||
}
|
||||
|
||||
// Reset clears the span for reuse from pool
|
||||
func (s *Span) Reset() {
|
||||
s.SpanID = ""
|
||||
s.ParentID = ""
|
||||
s.TraceID = ""
|
||||
s.Name = ""
|
||||
s.Kind = SpanKindUnspecified
|
||||
s.StartTime = time.Time{}
|
||||
s.EndTime = time.Time{}
|
||||
s.Status = SpanStatusUnset
|
||||
s.StatusMsg = ""
|
||||
s.Attributes = nil
|
||||
s.Events = s.Events[:0]
|
||||
}
|
||||
|
||||
// SpanEvent represents a time-stamped event within a span
|
||||
type SpanEvent struct {
|
||||
Name string // Name of the event
|
||||
Timestamp time.Time // When the event occurred
|
||||
Attributes map[string]any // Additional attributes for the event
|
||||
}
|
||||
|
||||
// SpanKind represents the type of operation a span represents
|
||||
// These are LLM-specific kinds designed for AI gateway observability
|
||||
type SpanKind string
|
||||
|
||||
const (
|
||||
// SpanKindUnspecified is the default span kind
|
||||
SpanKindUnspecified SpanKind = ""
|
||||
// SpanKindLLMCall represents a call to an LLM provider
|
||||
SpanKindLLMCall SpanKind = "llm.call"
|
||||
// SpanKindPlugin represents plugin execution (PreLLMHook/PostLLMHook)
|
||||
SpanKindPlugin SpanKind = "plugin"
|
||||
// SpanKindMCPTool represents an MCP tool invocation
|
||||
SpanKindMCPTool SpanKind = "mcp.tool"
|
||||
// SpanKindRetry represents a retry attempt
|
||||
SpanKindRetry SpanKind = "retry"
|
||||
// SpanKindFallback represents a fallback to another provider
|
||||
SpanKindFallback SpanKind = "fallback"
|
||||
// SpanKindHTTPRequest represents the root HTTP request span
|
||||
SpanKindHTTPRequest SpanKind = "http.request"
|
||||
// SpanKindEmbedding represents an embedding request
|
||||
SpanKindEmbedding SpanKind = "embedding"
|
||||
// SpanKindSpeech represents a text-to-speech request
|
||||
SpanKindSpeech SpanKind = "speech"
|
||||
// SpanKindTranscription represents a speech-to-text request
|
||||
SpanKindTranscription SpanKind = "transcription"
|
||||
// SpanKindInternal represents internal operations (key selection, etc.)
|
||||
SpanKindInternal SpanKind = "internal"
|
||||
)
|
||||
|
||||
// SpanStatus represents the status of a span's operation
|
||||
type SpanStatus string
|
||||
|
||||
const (
|
||||
// SpanStatusUnset indicates status has not been set
|
||||
SpanStatusUnset SpanStatus = "unset"
|
||||
// SpanStatusOk indicates the operation completed successfully
|
||||
SpanStatusOk SpanStatus = "ok"
|
||||
// SpanStatusError indicates the operation failed
|
||||
SpanStatusError SpanStatus = "error"
|
||||
)
|
||||
|
||||
// LLM Attribute Keys (gen_ai.* namespace)
|
||||
// These follow the OpenTelemetry semantic conventions for GenAI
|
||||
// and are compatible with both OTEL and Datadog backends.
|
||||
const (
|
||||
// Provider and Model Attributes
|
||||
AttrProviderName = "gen_ai.provider.name"
|
||||
AttrRequestModel = "gen_ai.request.model"
|
||||
|
||||
// Request Parameter Attributes
|
||||
AttrMaxTokens = "gen_ai.request.max_tokens"
|
||||
AttrTemperature = "gen_ai.request.temperature"
|
||||
AttrTopP = "gen_ai.request.top_p"
|
||||
AttrStopSequences = "gen_ai.request.stop_sequences"
|
||||
AttrPresencePenalty = "gen_ai.request.presence_penalty"
|
||||
AttrFrequencyPenalty = "gen_ai.request.frequency_penalty"
|
||||
AttrParallelToolCall = "gen_ai.request.parallel_tool_calls"
|
||||
AttrRequestUser = "gen_ai.request.user"
|
||||
AttrBestOf = "gen_ai.request.best_of"
|
||||
AttrEcho = "gen_ai.request.echo"
|
||||
AttrLogitBias = "gen_ai.request.logit_bias"
|
||||
AttrLogProbs = "gen_ai.request.logprobs"
|
||||
AttrN = "gen_ai.request.n"
|
||||
AttrSeed = "gen_ai.request.seed"
|
||||
AttrSuffix = "gen_ai.request.suffix"
|
||||
AttrDimensions = "gen_ai.request.dimensions"
|
||||
AttrEncodingFormat = "gen_ai.request.encoding_format"
|
||||
AttrLanguage = "gen_ai.request.language"
|
||||
AttrPrompt = "gen_ai.request.prompt"
|
||||
AttrResponseFormat = "gen_ai.request.response_format"
|
||||
AttrFormat = "gen_ai.request.format"
|
||||
AttrVoice = "gen_ai.request.voice"
|
||||
AttrMultiVoiceConfig = "gen_ai.request.multi_voice_config"
|
||||
AttrInstructions = "gen_ai.request.instructions"
|
||||
AttrSpeed = "gen_ai.request.speed"
|
||||
AttrMessageCount = "gen_ai.request.message_count"
|
||||
|
||||
// Response Attributes
|
||||
AttrResponseID = "gen_ai.response.id"
|
||||
AttrResponseModel = "gen_ai.response.model"
|
||||
AttrFinishReason = "gen_ai.response.finish_reason"
|
||||
AttrSystemFprint = "gen_ai.response.system_fingerprint"
|
||||
AttrServiceTier = "gen_ai.response.service_tier"
|
||||
AttrCreated = "gen_ai.response.created"
|
||||
AttrObject = "gen_ai.response.object"
|
||||
AttrTimeToFirstToken = "gen_ai.response.time_to_first_token"
|
||||
AttrTotalChunks = "gen_ai.response.total_chunks"
|
||||
|
||||
// Plugin Attributes (for aggregated streaming post-hook spans)
|
||||
AttrPluginInvocations = "plugin.invocation_count"
|
||||
AttrPluginAvgDurationMs = "plugin.avg_duration_ms"
|
||||
AttrPluginTotalDurationMs = "plugin.total_duration_ms"
|
||||
AttrPluginErrorCount = "plugin.error_count"
|
||||
|
||||
// Usage Attributes
|
||||
AttrPromptTokens = "gen_ai.usage.prompt_tokens"
|
||||
AttrCompletionTokens = "gen_ai.usage.completion_tokens"
|
||||
AttrTotalTokens = "gen_ai.usage.total_tokens"
|
||||
AttrInputTokens = "gen_ai.usage.input_tokens"
|
||||
AttrOutputTokens = "gen_ai.usage.output_tokens"
|
||||
AttrUsageCost = "gen_ai.usage.cost"
|
||||
// Chat completion usage detail attributes
|
||||
AttrPromptTokenDetailsText = "gen_ai.usage.prompt_token_details.text_tokens"
|
||||
AttrPromptTokenDetailsAudio = "gen_ai.usage.prompt_token_details.audio_tokens"
|
||||
AttrPromptTokenDetailsImage = "gen_ai.usage.prompt_token_details.image_tokens"
|
||||
AttrPromptTokenDetailsCachedRead = "gen_ai.usage.prompt_token_details.cached_read_tokens"
|
||||
AttrPromptTokenDetailsCachedWrite = "gen_ai.usage.prompt_token_details.cached_write_tokens"
|
||||
AttrCompletionTokenDetailsText = "gen_ai.usage.completion_token_details.text_tokens"
|
||||
AttrCompletionTokenDetailsAudio = "gen_ai.usage.completion_token_details.audio_tokens"
|
||||
AttrCompletionTokenDetailsImage = "gen_ai.usage.completion_token_details.image_tokens"
|
||||
AttrCompletionTokenDetailsReason = "gen_ai.usage.completion_token_details.reasoning_tokens"
|
||||
AttrCompletionTokenDetailsAccept = "gen_ai.usage.completion_token_details.accepted_prediction_tokens"
|
||||
AttrCompletionTokenDetailsReject = "gen_ai.usage.completion_token_details.rejected_prediction_tokens"
|
||||
AttrCompletionTokenDetailsCite = "gen_ai.usage.completion_token_details.citation_tokens"
|
||||
AttrCompletionTokenDetailsSearch = "gen_ai.usage.completion_token_details.num_search_queries"
|
||||
|
||||
// Error Attributes
|
||||
AttrError = "gen_ai.error"
|
||||
AttrErrorType = "gen_ai.error.type"
|
||||
AttrErrorCode = "gen_ai.error.code"
|
||||
|
||||
// Input/Output Attributes
|
||||
AttrInputText = "gen_ai.input.text"
|
||||
AttrInputMessages = "gen_ai.input.messages"
|
||||
AttrInputSpeech = "gen_ai.input.speech"
|
||||
AttrInputEmbedding = "gen_ai.input.embedding"
|
||||
AttrOutputMessages = "gen_ai.output.messages"
|
||||
|
||||
// Bifrost Context Attributes
|
||||
AttrVirtualKeyID = "gen_ai.virtual_key_id"
|
||||
AttrVirtualKeyName = "gen_ai.virtual_key_name"
|
||||
AttrSelectedKeyID = "gen_ai.selected_key_id"
|
||||
AttrSelectedKeyName = "gen_ai.selected_key_name"
|
||||
AttrRoutingRuleID = "gen_ai.routing_rule_id"
|
||||
AttrRoutingRuleName = "gen_ai.routing_rule_name"
|
||||
AttrTeamID = "gen_ai.team_id"
|
||||
AttrTeamName = "gen_ai.team_name"
|
||||
AttrCustomerID = "gen_ai.customer_id"
|
||||
AttrCustomerName = "gen_ai.customer_name"
|
||||
AttrNumberOfRetries = "gen_ai.number_of_retries"
|
||||
AttrFallbackIndex = "gen_ai.fallback_index"
|
||||
|
||||
// Responses API Request Attributes
|
||||
AttrPromptCacheKey = "gen_ai.request.prompt_cache_key"
|
||||
AttrReasoningEffort = "gen_ai.request.reasoning_effort"
|
||||
AttrReasoningSummary = "gen_ai.request.reasoning_summary"
|
||||
AttrReasoningGenSummary = "gen_ai.request.reasoning_generate_summary"
|
||||
AttrSafetyIdentifier = "gen_ai.request.safety_identifier"
|
||||
AttrStore = "gen_ai.request.store"
|
||||
AttrTextVerbosity = "gen_ai.request.text_verbosity"
|
||||
AttrTextFormatType = "gen_ai.request.text_format_type"
|
||||
AttrTopLogProbs = "gen_ai.request.top_logprobs"
|
||||
AttrToolChoiceType = "gen_ai.request.tool_choice_type"
|
||||
AttrToolChoiceName = "gen_ai.request.tool_choice_name"
|
||||
AttrTools = "gen_ai.request.tools"
|
||||
AttrTruncation = "gen_ai.request.truncation"
|
||||
|
||||
// Responses API Response Attributes
|
||||
AttrRespInclude = "gen_ai.responses.include"
|
||||
AttrRespMaxOutputTokens = "gen_ai.responses.max_output_tokens"
|
||||
AttrRespMaxToolCalls = "gen_ai.responses.max_tool_calls"
|
||||
AttrRespMetadata = "gen_ai.responses.metadata"
|
||||
AttrRespPreviousRespID = "gen_ai.responses.previous_response_id"
|
||||
AttrRespPromptCacheKey = "gen_ai.responses.prompt_cache_key"
|
||||
AttrRespReasoningText = "gen_ai.responses.reasoning"
|
||||
AttrRespReasoningEffort = "gen_ai.responses.reasoning_effort"
|
||||
AttrRespReasoningGenSum = "gen_ai.responses.reasoning_generate_summary"
|
||||
AttrRespSafetyIdentifier = "gen_ai.responses.safety_identifier"
|
||||
AttrRespStore = "gen_ai.responses.store"
|
||||
AttrRespTemperature = "gen_ai.responses.temperature"
|
||||
AttrRespTextVerbosity = "gen_ai.responses.text_verbosity"
|
||||
AttrRespTextFormatType = "gen_ai.responses.text_format_type"
|
||||
AttrRespTopLogProbs = "gen_ai.responses.top_logprobs"
|
||||
AttrRespTopP = "gen_ai.responses.top_p"
|
||||
AttrRespToolChoiceType = "gen_ai.responses.tool_choice_type"
|
||||
AttrRespToolChoiceName = "gen_ai.responses.tool_choice_name"
|
||||
AttrRespTruncation = "gen_ai.responses.truncation"
|
||||
AttrRespTools = "gen_ai.responses.tools"
|
||||
|
||||
// Batch Operation Attributes
|
||||
AttrBatchID = "gen_ai.batch.id"
|
||||
AttrBatchStatus = "gen_ai.batch.status"
|
||||
AttrBatchObject = "gen_ai.batch.object"
|
||||
AttrBatchEndpoint = "gen_ai.batch.endpoint"
|
||||
AttrBatchInputFileID = "gen_ai.batch.input_file_id"
|
||||
AttrBatchOutputFileID = "gen_ai.batch.output_file_id"
|
||||
AttrBatchErrorFileID = "gen_ai.batch.error_file_id"
|
||||
AttrBatchCompletionWin = "gen_ai.batch.completion_window"
|
||||
AttrBatchCreatedAt = "gen_ai.batch.created_at"
|
||||
AttrBatchExpiresAt = "gen_ai.batch.expires_at"
|
||||
AttrBatchRequestsCount = "gen_ai.batch.requests_count"
|
||||
AttrBatchDataCount = "gen_ai.batch.data_count"
|
||||
AttrBatchResultsCount = "gen_ai.batch.results_count"
|
||||
AttrBatchHasMore = "gen_ai.batch.has_more"
|
||||
AttrBatchMetadata = "gen_ai.batch.metadata"
|
||||
AttrBatchLimit = "gen_ai.batch.limit"
|
||||
AttrBatchAfter = "gen_ai.batch.after"
|
||||
AttrBatchBeforeID = "gen_ai.batch.before_id"
|
||||
AttrBatchAfterID = "gen_ai.batch.after_id"
|
||||
AttrBatchPageToken = "gen_ai.batch.page_token"
|
||||
AttrBatchPageSize = "gen_ai.batch.page_size"
|
||||
AttrBatchCountTotal = "gen_ai.batch.request_counts.total"
|
||||
AttrBatchCountCompleted = "gen_ai.batch.request_counts.completed"
|
||||
AttrBatchCountFailed = "gen_ai.batch.request_counts.failed"
|
||||
AttrBatchFirstID = "gen_ai.batch.first_id"
|
||||
AttrBatchLastID = "gen_ai.batch.last_id"
|
||||
AttrBatchInProgressAt = "gen_ai.batch.in_progress_at"
|
||||
AttrBatchFinalizingAt = "gen_ai.batch.finalizing_at"
|
||||
AttrBatchCompletedAt = "gen_ai.batch.completed_at"
|
||||
AttrBatchFailedAt = "gen_ai.batch.failed_at"
|
||||
AttrBatchExpiredAt = "gen_ai.batch.expired_at"
|
||||
AttrBatchCancellingAt = "gen_ai.batch.cancelling_at"
|
||||
AttrBatchCancelledAt = "gen_ai.batch.cancelled_at"
|
||||
AttrBatchNextCursor = "gen_ai.batch.next_cursor"
|
||||
|
||||
// Transcription Response Attributes
|
||||
AttrInputTokenDetailsText = "gen_ai.usage.input_token_details.text_tokens"
|
||||
AttrInputTokenDetailsAudio = "gen_ai.usage.input_token_details.audio_tokens"
|
||||
|
||||
// File Operation Attributes
|
||||
AttrFileID = "gen_ai.file.id"
|
||||
AttrFileObject = "gen_ai.file.object"
|
||||
AttrFileFilename = "gen_ai.file.filename"
|
||||
AttrFilePurpose = "gen_ai.file.purpose"
|
||||
AttrFileBytes = "gen_ai.file.bytes"
|
||||
AttrFileCreatedAt = "gen_ai.file.created_at"
|
||||
AttrFileStatus = "gen_ai.file.status"
|
||||
AttrFileStorageBackend = "gen_ai.file.storage_backend"
|
||||
AttrFileDataCount = "gen_ai.file.data_count"
|
||||
AttrFileHasMore = "gen_ai.file.has_more"
|
||||
AttrFileDeleted = "gen_ai.file.deleted"
|
||||
AttrFileContentType = "gen_ai.file.content_type"
|
||||
AttrFileContentBytes = "gen_ai.file.content_bytes"
|
||||
AttrFileLimit = "gen_ai.file.limit"
|
||||
AttrFileAfter = "gen_ai.file.after"
|
||||
AttrFileOrder = "gen_ai.file.order"
|
||||
)
|
||||
204
core/schemas/tracer.go
Normal file
204
core/schemas/tracer.go
Normal file
@@ -0,0 +1,204 @@
|
||||
// Package schemas defines the core schemas and types used by the Bifrost system.
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SpanHandle is an opaque handle to a span, implementation-specific.
|
||||
// Different Tracer implementations can use their own concrete types.
|
||||
type SpanHandle interface{}
|
||||
|
||||
// StreamAccumulatorResult contains the accumulated data from streaming chunks.
|
||||
// This is the return type for tracer's streaming accumulation methods.
|
||||
type StreamAccumulatorResult struct {
|
||||
RequestID string // Request ID
|
||||
RequestedModel string // Original model requested by the caller
|
||||
ResolvedModel string // Actual model used by the provider (equals RequestedModel when no alias mapping exists)
|
||||
Provider ModelProvider // Provider used
|
||||
Status string // Status of the stream
|
||||
Latency int64 // Latency in milliseconds
|
||||
TimeToFirstToken int64 // Time to first token in milliseconds
|
||||
OutputMessage *ChatMessage // Accumulated output message
|
||||
OutputMessages []ResponsesMessage // For responses API
|
||||
TokenUsage *BifrostLLMUsage // Token usage
|
||||
Cost *float64 // Cost in dollars
|
||||
ErrorDetails *BifrostError // Error details if any
|
||||
AudioOutput *BifrostSpeechResponse // For speech streaming
|
||||
TranscriptionOutput *BifrostTranscriptionResponse // For transcription streaming
|
||||
ImageGenerationOutput *BifrostImageGenerationResponse // For image generation streaming
|
||||
FinishReason *string // Finish reason
|
||||
RawResponse *string // Raw response
|
||||
RawRequest interface{} // Raw request
|
||||
}
|
||||
|
||||
// Tracer defines the interface for distributed tracing in Bifrost.
|
||||
// Implementations can be injected via BifrostConfig to enable automatic instrumentation.
|
||||
// The interface is designed to be minimal and implementation-agnostic.
|
||||
type Tracer interface {
|
||||
// CreateTrace creates a new trace with optional parent ID and returns the trace ID.
|
||||
// The parentID can be extracted from W3C traceparent headers for distributed tracing.
|
||||
// The requestID is optional and can be used to identify the request.
|
||||
CreateTrace(parentID string, requestID ...string) string
|
||||
|
||||
// EndTrace completes a trace and returns the trace data for observation/export.
|
||||
// After this call, the trace is removed from active tracking and returned for cleanup.
|
||||
// Returns nil if trace not found.
|
||||
EndTrace(traceID string) *Trace
|
||||
|
||||
// StartSpan creates a new span as a child of the current span in context.
|
||||
// Returns updated context with new span and a handle for the span.
|
||||
// The context should be used for subsequent operations to maintain span hierarchy.
|
||||
StartSpan(ctx context.Context, name string, kind SpanKind) (context.Context, SpanHandle)
|
||||
|
||||
// EndSpan completes a span with status and optional message.
|
||||
// Should be called when the operation represented by the span is complete.
|
||||
EndSpan(handle SpanHandle, status SpanStatus, statusMsg string)
|
||||
|
||||
// SetAttribute sets an attribute on the span.
|
||||
// Attributes provide additional context about the operation.
|
||||
SetAttribute(handle SpanHandle, key string, value any)
|
||||
|
||||
// AddEvent adds a timestamped event to the span.
|
||||
// Events represent discrete occurrences during the span's lifetime.
|
||||
AddEvent(handle SpanHandle, name string, attrs map[string]any)
|
||||
|
||||
// PopulateLLMRequestAttributes populates all LLM-specific request attributes on the span.
|
||||
// This includes model parameters, input messages, temperature, max tokens, etc.
|
||||
PopulateLLMRequestAttributes(handle SpanHandle, req *BifrostRequest)
|
||||
|
||||
// PopulateLLMResponseAttributes populates all LLM-specific response attributes on the span.
|
||||
// This includes output messages, tokens, usage stats, and error information if present.
|
||||
PopulateLLMResponseAttributes(ctx *BifrostContext, handle SpanHandle, resp *BifrostResponse, err *BifrostError)
|
||||
|
||||
// StoreDeferredSpan stores a span handle for later completion (used for streaming requests).
|
||||
// The span handle is stored keyed by trace ID so it can be retrieved when the stream completes.
|
||||
StoreDeferredSpan(traceID string, handle SpanHandle)
|
||||
|
||||
// GetDeferredSpanHandle retrieves a deferred span handle by trace ID.
|
||||
// Returns nil if no deferred span exists for the given trace ID.
|
||||
GetDeferredSpanHandle(traceID string) SpanHandle
|
||||
|
||||
// ClearDeferredSpan removes the deferred span handle for a trace ID.
|
||||
// Should be called after the deferred span has been completed.
|
||||
ClearDeferredSpan(traceID string)
|
||||
|
||||
// GetDeferredSpanID returns the span ID for the deferred span.
|
||||
// Returns empty string if no deferred span exists.
|
||||
GetDeferredSpanID(traceID string) string
|
||||
|
||||
// AddStreamingChunk accumulates a streaming chunk for the deferred span.
|
||||
// Pass the full BifrostResponse to capture content, tool calls, reasoning, etc.
|
||||
// This is called for each streaming chunk to build up the complete response.
|
||||
AddStreamingChunk(traceID string, response *BifrostResponse)
|
||||
|
||||
// GetAccumulatedChunks returns the accumulated response, TTFT, and chunk count for a deferred span.
|
||||
// The response is built from the streaming accumulator during the final ProcessStreamingChunk call.
|
||||
// Returns nil response if no plugin has called ProcessStreamingChunk (callers should nil-check).
|
||||
// Returns nil, 0, 0 if no accumulated data exists.
|
||||
GetAccumulatedChunks(traceID string) (response *BifrostResponse, ttftNs int64, chunkCount int)
|
||||
|
||||
// CreateStreamAccumulator creates a new stream accumulator for the given trace ID.
|
||||
// This should be called at the start of a streaming request.
|
||||
CreateStreamAccumulator(traceID string, startTime time.Time)
|
||||
|
||||
// CleanupStreamAccumulator removes the stream accumulator for the given trace ID.
|
||||
// This should be called after the streaming request is complete.
|
||||
CleanupStreamAccumulator(traceID string)
|
||||
|
||||
// ProcessStreamingChunk processes a streaming chunk and accumulates it.
|
||||
// Returns the accumulated result. IsFinal will be true when the stream is complete.
|
||||
// This method is used by plugins to access accumulated streaming data.
|
||||
// The ctx parameter must contain the stream end indicator for proper final chunk detection.
|
||||
ProcessStreamingChunk(traceID string, isFinalChunk bool, result *BifrostResponse, err *BifrostError) *StreamAccumulatorResult
|
||||
|
||||
// AttachPluginLogs appends plugin log entries to the trace identified by traceID.
|
||||
// Thread-safe. Should be called after plugin hooks complete, before trace completion.
|
||||
AttachPluginLogs(traceID string, logs []PluginLogEntry)
|
||||
|
||||
// CompleteAndFlushTrace ends a trace, exports it to observability plugins, and
|
||||
// releases the trace resources. Used by transports that bypass normal HTTP trace completion.
|
||||
CompleteAndFlushTrace(traceID string)
|
||||
|
||||
// Stop releases resources associated with the tracer.
|
||||
// Should be called during shutdown to stop background goroutines.
|
||||
Stop()
|
||||
}
|
||||
|
||||
// NoOpTracer is a tracer that does nothing (default when tracing disabled).
|
||||
// It satisfies the Tracer interface but performs no actual tracing operations.
|
||||
type NoOpTracer struct{}
|
||||
|
||||
// CreateTrace returns an empty string (no trace created).
|
||||
func (n *NoOpTracer) CreateTrace(_ string, _ ...string) string { return "" }
|
||||
|
||||
// EndTrace returns nil (no trace to end).
|
||||
func (n *NoOpTracer) EndTrace(_ string) *Trace { return nil }
|
||||
|
||||
// StartSpan returns the context unchanged and a nil handle.
|
||||
func (n *NoOpTracer) StartSpan(ctx context.Context, _ string, _ SpanKind) (context.Context, SpanHandle) {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
// EndSpan does nothing.
|
||||
func (n *NoOpTracer) EndSpan(_ SpanHandle, _ SpanStatus, _ string) {}
|
||||
|
||||
// SetAttribute does nothing.
|
||||
func (n *NoOpTracer) SetAttribute(_ SpanHandle, _ string, _ any) {}
|
||||
|
||||
// AddEvent does nothing.
|
||||
func (n *NoOpTracer) AddEvent(_ SpanHandle, _ string, _ map[string]any) {}
|
||||
|
||||
// PopulateLLMRequestAttributes does nothing.
|
||||
func (n *NoOpTracer) PopulateLLMRequestAttributes(_ SpanHandle, _ *BifrostRequest) {}
|
||||
|
||||
// PopulateLLMResponseAttributes does nothing.
|
||||
func (n *NoOpTracer) PopulateLLMResponseAttributes(_ *BifrostContext, _ SpanHandle, _ *BifrostResponse, _ *BifrostError) {
|
||||
}
|
||||
|
||||
// StoreDeferredSpan does nothing.
|
||||
func (n *NoOpTracer) StoreDeferredSpan(_ string, _ SpanHandle) {}
|
||||
|
||||
// GetDeferredSpanHandle returns nil.
|
||||
func (n *NoOpTracer) GetDeferredSpanHandle(_ string) SpanHandle { return nil }
|
||||
|
||||
// ClearDeferredSpan does nothing.
|
||||
func (n *NoOpTracer) ClearDeferredSpan(_ string) {}
|
||||
|
||||
// GetDeferredSpanID returns empty string.
|
||||
func (n *NoOpTracer) GetDeferredSpanID(_ string) string { return "" }
|
||||
|
||||
// AddStreamingChunk does nothing.
|
||||
func (n *NoOpTracer) AddStreamingChunk(_ string, _ *BifrostResponse) {}
|
||||
|
||||
// GetAccumulatedChunks returns nil, 0, 0.
|
||||
func (n *NoOpTracer) GetAccumulatedChunks(_ string) (*BifrostResponse, int64, int) { return nil, 0, 0 }
|
||||
|
||||
// CreateStreamAccumulator does nothing.
|
||||
func (n *NoOpTracer) CreateStreamAccumulator(_ string, _ time.Time) {}
|
||||
|
||||
// CleanupStreamAccumulator does nothing.
|
||||
func (n *NoOpTracer) CleanupStreamAccumulator(_ string) {}
|
||||
|
||||
// ProcessStreamingChunk returns nil.
|
||||
func (n *NoOpTracer) ProcessStreamingChunk(_ string, _ bool, _ *BifrostResponse, _ *BifrostError) *StreamAccumulatorResult {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AttachPluginLogs does nothing.
|
||||
func (n *NoOpTracer) AttachPluginLogs(_ string, _ []PluginLogEntry) {}
|
||||
|
||||
// CompleteAndFlushTrace does nothing.
|
||||
func (n *NoOpTracer) CompleteAndFlushTrace(_ string) {}
|
||||
|
||||
// Stop does nothing.
|
||||
func (n *NoOpTracer) Stop() {}
|
||||
|
||||
// DefaultTracer returns a no-op tracer for use when tracing is disabled.
|
||||
func DefaultTracer() Tracer {
|
||||
return &NoOpTracer{}
|
||||
}
|
||||
|
||||
// Ensure NoOpTracer implements Tracer at compile time
|
||||
var _ Tracer = (*NoOpTracer)(nil)
|
||||
156
core/schemas/transcriptions.go
Normal file
156
core/schemas/transcriptions.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package schemas
|
||||
|
||||
type BifrostTranscriptionRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Input *TranscriptionInput `json:"input,omitempty"`
|
||||
Params *TranscriptionParameters `json:"params,omitempty"`
|
||||
Fallbacks []Fallback `json:"fallbacks,omitempty"`
|
||||
RawRequestBody []byte `json:"-"` // set bifrost-use-raw-request-body to true in ctx to use the raw request body. Bifrost will directly send this to the downstream provider.
|
||||
}
|
||||
|
||||
func (r *BifrostTranscriptionRequest) GetRawRequestBody() []byte {
|
||||
return r.RawRequestBody
|
||||
}
|
||||
|
||||
type BifrostTranscriptionResponse struct {
|
||||
Duration *float64 `json:"duration,omitempty"` // Duration in seconds
|
||||
Language *string `json:"language,omitempty"` // e.g., "english"
|
||||
LogProbs []TranscriptionLogProb `json:"logprobs,omitempty"`
|
||||
Segments []TranscriptionSegment `json:"segments,omitempty"`
|
||||
Task *string `json:"task,omitempty"` // e.g., "transcribe"
|
||||
Text string `json:"text"`
|
||||
Usage *TranscriptionUsage `json:"usage,omitempty"`
|
||||
Words []TranscriptionWord `json:"words,omitempty"`
|
||||
ResponseFormat *string `json:"-"` // Set by provider for non-JSON formats (text, srt, vtt); used by integration response converters
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
func (r *BifrostTranscriptionResponse) BackfillParams(req *BifrostTranscriptionRequest) {
|
||||
if r == nil || req == nil || req.Params == nil || req.Params.ResponseFormat == nil {
|
||||
return
|
||||
}
|
||||
r.ResponseFormat = req.Params.ResponseFormat
|
||||
}
|
||||
|
||||
// IsPlainTextTranscriptionFormat returns true if the given response format
|
||||
// produces a plain-text response body (not JSON).
|
||||
func IsPlainTextTranscriptionFormat(format *string) bool {
|
||||
if format == nil {
|
||||
return false
|
||||
}
|
||||
switch *format {
|
||||
case "text", "srt", "vtt":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
type TranscriptionInput struct {
|
||||
File []byte `json:"file"`
|
||||
Filename string `json:"filename,omitempty"` // Original filename, used to preserve file format extension
|
||||
}
|
||||
|
||||
type TranscriptionParameters struct {
|
||||
Language *string `json:"language,omitempty"`
|
||||
Prompt *string `json:"prompt,omitempty"`
|
||||
ResponseFormat *string `json:"response_format,omitempty"` // Default is "json"
|
||||
Temperature *float64 `json:"temperature,omitempty"` // Sampling temperature (0.0-1.0)
|
||||
TimestampGranularities []string `json:"timestamp_granularities,omitempty"` // "word" and/or "segment"; requires response_format=verbose_json
|
||||
Include []string `json:"include,omitempty"` // Additional response info (e.g., logprobs)
|
||||
Format *string `json:"file_format,omitempty"` // Type of file, not required in openai, but required in gemini
|
||||
MaxLength *int `json:"max_length,omitempty"` // Maximum length of the transcription used by HuggingFace
|
||||
MinLength *int `json:"min_length,omitempty"` // Minimum length of the transcription used by HuggingFace
|
||||
MaxNewTokens *int `json:"max_new_tokens,omitempty"` // Maximum new tokens to generate used by HuggingFace
|
||||
MinNewTokens *int `json:"min_new_tokens,omitempty"` // Minimum new tokens to generate used by HuggingFace
|
||||
|
||||
// Elevenlabs-specific fields
|
||||
AdditionalFormats []TranscriptionAdditionalFormat `json:"additional_formats,omitempty"`
|
||||
WebhookMetadata interface{} `json:"webhook_metadata,omitempty"`
|
||||
|
||||
// Dynamic parameters that can be provider-specific, they are directly
|
||||
// added to the request as is.
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
type TranscriptionAdditionalFormat struct {
|
||||
Format TranscriptionExportOptions `json:"format"`
|
||||
IncludeSpeakers *bool `json:"include_speakers,omitempty"`
|
||||
IncludeTimestamps *bool `json:"include_timestamps,omitempty"`
|
||||
SegmentOnSilenceLongerThanS *float64 `json:"segment_on_silence_longer_than_s,omitempty"`
|
||||
MaxSegmentDurationS *float64 `json:"max_segment_duration_s,omitempty"`
|
||||
MaxSegmentChars *int `json:"max_segment_chars,omitempty"`
|
||||
MaxCharactersPerLine *int `json:"max_characters_per_line,omitempty"`
|
||||
}
|
||||
|
||||
type TranscriptionExportOptions string
|
||||
|
||||
const (
|
||||
TranscriptionExportOptionsSegmentedJson TranscriptionExportOptions = "segmented_json"
|
||||
TranscriptionExportOptionsDocx TranscriptionExportOptions = "docx"
|
||||
TranscriptionExportOptionsPdf TranscriptionExportOptions = "pdf"
|
||||
TranscriptionExportOptionsTxt TranscriptionExportOptions = "txt"
|
||||
TranscriptionExportOptionsHtml TranscriptionExportOptions = "html"
|
||||
TranscriptionExportOptionsSrt TranscriptionExportOptions = "srt"
|
||||
)
|
||||
|
||||
// TranscriptionLogProb represents log probability information for transcription
|
||||
type TranscriptionLogProb struct {
|
||||
Token string `json:"token"`
|
||||
LogProb float64 `json:"logprob"`
|
||||
Bytes []int `json:"bytes"`
|
||||
}
|
||||
|
||||
// TranscriptionWord represents word-level timing information
|
||||
type TranscriptionWord struct {
|
||||
Word string `json:"word"`
|
||||
Start float64 `json:"start"`
|
||||
End float64 `json:"end"`
|
||||
}
|
||||
|
||||
// TranscriptionSegment represents segment-level transcription information
|
||||
type TranscriptionSegment struct {
|
||||
ID int `json:"id"`
|
||||
Seek int `json:"seek"`
|
||||
Start float64 `json:"start"`
|
||||
End float64 `json:"end"`
|
||||
Text string `json:"text"`
|
||||
Tokens []int `json:"tokens"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
AvgLogProb float64 `json:"avg_logprob"`
|
||||
CompressionRatio float64 `json:"compression_ratio"`
|
||||
NoSpeechProb float64 `json:"no_speech_prob"`
|
||||
}
|
||||
|
||||
// TranscriptionUsage represents usage information for transcription
|
||||
type TranscriptionUsage struct {
|
||||
Type string `json:"type"` // "tokens" or "duration"
|
||||
InputTokens *int `json:"input_tokens,omitempty"`
|
||||
InputTokenDetails *TranscriptionUsageInputTokenDetails `json:"input_token_details,omitempty"`
|
||||
OutputTokens *int `json:"output_tokens,omitempty"`
|
||||
TotalTokens *int `json:"total_tokens,omitempty"`
|
||||
Seconds *int `json:"seconds,omitempty"` // For duration-based usage
|
||||
}
|
||||
|
||||
type TranscriptionUsageInputTokenDetails struct {
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
}
|
||||
|
||||
type TranscriptionStreamResponseType string
|
||||
|
||||
const (
|
||||
TranscriptionStreamResponseTypeDelta TranscriptionStreamResponseType = "transcript.text.delta"
|
||||
TranscriptionStreamResponseTypeDone TranscriptionStreamResponseType = "transcript.text.done"
|
||||
)
|
||||
|
||||
// BifrostTranscriptionStreamResponse represents streaming specific fields only
|
||||
type BifrostTranscriptionStreamResponse struct {
|
||||
Delta *string `json:"delta,omitempty"` // For delta events
|
||||
LogProbs []TranscriptionLogProb `json:"logprobs,omitempty"`
|
||||
Text string `json:"text"`
|
||||
Type TranscriptionStreamResponseType `json:"type"`
|
||||
Usage *TranscriptionUsage `json:"usage,omitempty"`
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
72
core/schemas/useragents.go
Normal file
72
core/schemas/useragents.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package schemas
|
||||
|
||||
import "strings"
|
||||
|
||||
// UserAgentIdentifiers lists substrings that may appear in User-Agent for a given integration.
|
||||
// Versions of the same client may use different strings; Matches checks any of them.
|
||||
type UserAgentIdentifiers []string
|
||||
|
||||
var (
|
||||
// ClaudeCLI — Anthropic Claude Code / Claude CLI (identifiers vary by release).
|
||||
ClaudeCLI = UserAgentIdentifiers{"claude-cli", "claude-code", "claude-vscode"}
|
||||
GeminiCLI = UserAgentIdentifiers{"geminicli"}
|
||||
CodexCLI = UserAgentIdentifiers{"codex-tui"}
|
||||
QwenCodeCLI = UserAgentIdentifiers{"qwencode"}
|
||||
OpenCode = UserAgentIdentifiers{"opencode"}
|
||||
Cursor = UserAgentIdentifiers{"cursor"}
|
||||
)
|
||||
|
||||
// integrationUserAgents is the set of known client User-Agent patterns we persist on the context.
|
||||
var integrationUserAgents = []UserAgentIdentifiers{
|
||||
ClaudeCLI, GeminiCLI, CodexCLI, QwenCodeCLI, OpenCode, Cursor,
|
||||
}
|
||||
|
||||
// Matches reports whether userAgent contains any identifier (case-insensitive substring match).
|
||||
func (ids UserAgentIdentifiers) Matches(userAgent string) bool {
|
||||
if len(ids) == 0 || userAgent == "" {
|
||||
return false
|
||||
}
|
||||
ua := strings.ToLower(userAgent)
|
||||
for _, id := range ids {
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(ua, strings.ToLower(id)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// String returns the first identifier for logging and tests that need a canonical sample value.
|
||||
func (ids UserAgentIdentifiers) String() string {
|
||||
if len(ids) == 0 {
|
||||
return ""
|
||||
}
|
||||
return ids[0]
|
||||
}
|
||||
|
||||
func ExtractAndSetUserAgentFromHeaders(headers map[string][]string, bifrostCtx *BifrostContext) {
|
||||
if len(headers) == 0 {
|
||||
return
|
||||
}
|
||||
if bifrostCtx == nil {
|
||||
return
|
||||
}
|
||||
var userAgent []string
|
||||
for key, value := range headers {
|
||||
if strings.EqualFold(key, "user-agent") {
|
||||
userAgent = value
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(userAgent) > 0 {
|
||||
ua := userAgent[0]
|
||||
for _, ids := range integrationUserAgents {
|
||||
if ids.Matches(ua) {
|
||||
bifrostCtx.SetValue(BifrostContextKeyUserAgent, ua)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
1398
core/schemas/utils.go
Normal file
1398
core/schemas/utils.go
Normal file
File diff suppressed because it is too large
Load Diff
247
core/schemas/videos.go
Normal file
247
core/schemas/videos.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package schemas
|
||||
|
||||
// VideoStatus is the lifecycle status of a video job.
|
||||
type VideoStatus string
|
||||
|
||||
const (
|
||||
VideoStatusQueued VideoStatus = "queued"
|
||||
VideoStatusInProgress VideoStatus = "in_progress"
|
||||
VideoStatusCompleted VideoStatus = "completed"
|
||||
VideoStatusFailed VideoStatus = "failed"
|
||||
)
|
||||
|
||||
type VideoOutputType string
|
||||
|
||||
const (
|
||||
VideoOutputTypeBase64 VideoOutputType = "base64"
|
||||
VideoOutputTypeURL VideoOutputType = "url"
|
||||
)
|
||||
|
||||
// VideoCreateError is the error payload when video generation fails.
|
||||
type VideoCreateError struct {
|
||||
Code string `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// ContentFilterInfo contains information about content that was filtered due to safety policies.
|
||||
// This is a provider-agnostic structure for representing content filtering results.
|
||||
type ContentFilterInfo struct {
|
||||
FilteredCount int `json:"filtered_count,omitempty"` // Number of items filtered
|
||||
Reasons []string `json:"reasons,omitempty"` // Human-readable reasons for filtering
|
||||
}
|
||||
|
||||
type VideoOutput struct {
|
||||
Type VideoOutputType `json:"type"` // "url" | "base64"
|
||||
URL *string `json:"url,omitempty"`
|
||||
Base64Data *string `json:"base64,omitempty"`
|
||||
ContentType string `json:"content_type"`
|
||||
}
|
||||
|
||||
// VideoReferenceInput represents a reference image for video generation
|
||||
type VideoReferenceInput struct {
|
||||
Image []byte `json:"image"` // Image bytes
|
||||
ReferenceType string `json:"reference_type,omitempty"` // "style" or "asset" (Gemini: "REFERENCE_TYPE_STYLE" or "REFERENCE_TYPE_ASSET")
|
||||
}
|
||||
|
||||
type VideoObject struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"` // always "video"
|
||||
Model string `json:"model"`
|
||||
Status VideoStatus `json:"status"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
CompletedAt *int64 `json:"completed_at,omitempty"`
|
||||
ExpiresAt *int64 `json:"expires_at,omitempty"`
|
||||
Progress *float64 `json:"progress,omitempty"`
|
||||
Prompt string `json:"prompt"`
|
||||
RemixedFromVideoID *string `json:"remixed_from_video_id,omitempty"`
|
||||
Seconds *string `json:"seconds"`
|
||||
Size string `json:"size"`
|
||||
Error *VideoCreateError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// --- Video Generation ---
|
||||
|
||||
type BifrostVideoGenerationRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Input *VideoGenerationInput `json:"input"`
|
||||
Params *VideoGenerationParameters `json:"params,omitempty"`
|
||||
Fallbacks []Fallback `json:"fallbacks,omitempty"`
|
||||
RawRequestBody []byte `json:"-"`
|
||||
}
|
||||
|
||||
func (b *BifrostVideoGenerationRequest) GetRawRequestBody() []byte {
|
||||
return b.RawRequestBody
|
||||
}
|
||||
|
||||
func (b *BifrostVideoGenerationRequest) GetExtraParams() map[string]interface{} {
|
||||
if b == nil || b.Params == nil {
|
||||
return nil
|
||||
}
|
||||
return b.Params.ExtraParams
|
||||
}
|
||||
|
||||
type VideoGenerationInput struct {
|
||||
Prompt string `json:"prompt"`
|
||||
InputReference *string `json:"input_reference,omitempty"` // Primary image for image-to-video (OpenAI-compatible)
|
||||
}
|
||||
|
||||
type VideoGenerationParameters struct {
|
||||
Seconds *string `json:"seconds,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
|
||||
NegativePrompt *string `json:"negative_prompt,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
VideoURI *string `json:"video_uri,omitempty"` // for video to video generation
|
||||
Audio *bool `json:"audio,omitempty"`
|
||||
ExtraParams map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// DefaultVideoDuration is the default video duration in seconds for Gemini/Vertex when not specified.
|
||||
const DefaultVideoDuration = "8"
|
||||
|
||||
// BifrostVideoGenerationResponse represents the video generation job response in bifrost format.
|
||||
type BifrostVideoGenerationResponse struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
CompletedAt *int64 `json:"completed_at,omitempty"` // Unix timestamp (seconds) when the job completed
|
||||
CreatedAt int64 `json:"created_at,omitempty"` // Unix timestamp (seconds) when the job was created
|
||||
Error *VideoCreateError `json:"error,omitempty"` // Error payload if generation failed
|
||||
ExpiresAt *int64 `json:"expires_at,omitempty"` // Unix timestamp (seconds) when downloadable assets expire
|
||||
Model string `json:"model,omitempty"` // Video generation model that produced the job
|
||||
Object string `json:"object,omitempty"` // Object type, always "video"
|
||||
Progress *float64 `json:"progress,omitempty"` // Approximate completion percentage (0-100)
|
||||
Prompt string `json:"prompt,omitempty"` // Prompt used to generate the video
|
||||
RemixedFromVideoID *string `json:"remixed_from_video_id,omitempty"` // Source video ID if this is a remix
|
||||
Seconds *string `json:"seconds,omitempty"` // Duration of the generated clip in seconds
|
||||
Size string `json:"size,omitempty"` // Resolution of the generated video
|
||||
Status VideoStatus `json:"status,omitempty"` // Current lifecycle status of the video job
|
||||
Videos []VideoOutput `json:"videos,omitempty"` // Generated videos (supports multiple videos)
|
||||
ContentFilter *ContentFilterInfo `json:"content_filter,omitempty"` // Information about content filtering (if applicable)
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields,omitempty"`
|
||||
}
|
||||
|
||||
// getSecondsFromVideoRequest extracts Seconds from video-related requests.
|
||||
func getSecondsFromVideoRequest(req *BifrostRequest) *string {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
useDefaultForSeconds := func(p ModelProvider) bool {
|
||||
return p == Gemini || p == Vertex
|
||||
}
|
||||
if req.VideoGenerationRequest != nil {
|
||||
var seconds *string
|
||||
if req.VideoGenerationRequest.Params != nil {
|
||||
seconds = req.VideoGenerationRequest.Params.Seconds
|
||||
}
|
||||
if seconds == nil && useDefaultForSeconds(req.VideoGenerationRequest.Provider) {
|
||||
seconds = Ptr(DefaultVideoDuration)
|
||||
}
|
||||
return seconds
|
||||
}
|
||||
if req.VideoRemixRequest != nil && useDefaultForSeconds(req.VideoRemixRequest.Provider) {
|
||||
return Ptr(DefaultVideoDuration)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BackfillParams populates response fields from the original request that are needed
|
||||
// for cost calculation but may not be returned by the provider.
|
||||
// - Seconds (duration from request params or default)
|
||||
func (r *BifrostVideoGenerationResponse) BackfillParams(req *BifrostRequest) {
|
||||
if r == nil || req == nil {
|
||||
return
|
||||
}
|
||||
seconds := getSecondsFromVideoRequest(req)
|
||||
if seconds != nil {
|
||||
r.Seconds = seconds
|
||||
}
|
||||
if r.Model == "" && req.VideoGenerationRequest != nil {
|
||||
r.Model = req.VideoGenerationRequest.Model
|
||||
}
|
||||
}
|
||||
|
||||
// --- Video Remix ---
|
||||
|
||||
type BifrostVideoRemixRequest struct {
|
||||
ID string `json:"id"`
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Input *VideoGenerationInput `json:"input"`
|
||||
ExtraParams map[string]any `json:"-"`
|
||||
RawRequestBody []byte `json:"-"`
|
||||
}
|
||||
|
||||
func (b *BifrostVideoRemixRequest) GetRawRequestBody() []byte {
|
||||
return b.RawRequestBody
|
||||
}
|
||||
|
||||
func (b *BifrostVideoRemixRequest) GetExtraParams() map[string]interface{} {
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
return b.ExtraParams
|
||||
}
|
||||
|
||||
// --- Video List ---
|
||||
|
||||
type BifrostVideoListRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
After *string `json:"after,omitempty"`
|
||||
Limit *int `json:"limit,omitempty"`
|
||||
Order *string `json:"order,omitempty"`
|
||||
}
|
||||
|
||||
type BifrostVideoListResponse struct {
|
||||
Object string `json:"object"` // "list"
|
||||
Data []VideoObject `json:"data"`
|
||||
FirstID *string `json:"first_id,omitempty"`
|
||||
HasMore *bool `json:"has_more,omitempty"`
|
||||
LastID *string `json:"last_id,omitempty"`
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// --- Video Retrieve / Delete ---
|
||||
|
||||
type BifrostVideoReferenceRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
type BifrostVideoDeleteRequest = BifrostVideoReferenceRequest
|
||||
type BifrostVideoRetrieveRequest = BifrostVideoReferenceRequest
|
||||
|
||||
type BifrostVideoDeleteResponse struct {
|
||||
ID string `json:"id"`
|
||||
Deleted bool `json:"deleted"`
|
||||
Object string `json:"object,omitempty"` // "video.deleted"
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
// --- Video Download ---
|
||||
|
||||
type BifrostVideoDownloadRequest struct {
|
||||
Provider ModelProvider `json:"provider"`
|
||||
ID string `json:"id"`
|
||||
Variant *VideoDownloadVariant `json:"variant,omitempty"`
|
||||
ExtraParams map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
type VideoDownloadVariant string
|
||||
|
||||
const (
|
||||
VideoDownloadVariantVideo VideoDownloadVariant = "video"
|
||||
VideoDownloadVariantThumbnail VideoDownloadVariant = "thumbnail"
|
||||
VideoDownloadVariantSpriteSheet VideoDownloadVariant = "sprite_sheet"
|
||||
)
|
||||
|
||||
type BifrostVideoDownloadResponse struct {
|
||||
VideoID string `json:"video_id"`
|
||||
Content []byte `json:"-"` // Raw video content (not serialized)
|
||||
ContentType string `json:"content_type,omitempty"` // MIME type (e.g., "video/mp4", "image/png" for thumbnails)
|
||||
|
||||
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
||||
}
|
||||
|
||||
type VideoLogParams struct {
|
||||
VideoID string `json:"video_id"`
|
||||
}
|
||||
103
core/schemas/websocket.go
Normal file
103
core/schemas/websocket.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package schemas
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// WebSocketEventType represents event types in the Responses API WebSocket protocol.
|
||||
type WebSocketEventType string
|
||||
|
||||
const (
|
||||
WSEventResponseCreate WebSocketEventType = "response.create"
|
||||
WSEventError WebSocketEventType = "error"
|
||||
)
|
||||
|
||||
// WebSocketResponsesEvent represents a client-sent event over the Responses WebSocket connection.
|
||||
// The payload mirrors the Responses API create body, with transport-specific fields
|
||||
// (stream, background) omitted since they're implicit in the WebSocket context.
|
||||
type WebSocketResponsesEvent struct {
|
||||
Type WebSocketEventType `json:"type"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Store *bool `json:"store,omitempty"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
Instructions string `json:"instructions,omitempty"`
|
||||
PreviousResponseID string `json:"previous_response_id,omitempty"`
|
||||
Generate *bool `json:"generate,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
||||
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
Text json.RawMessage `json:"text,omitempty"`
|
||||
Truncation string `json:"truncation,omitempty"`
|
||||
}
|
||||
|
||||
// WebSocketErrorEvent represents a server-sent error event over WebSocket.
|
||||
type WebSocketErrorEvent struct {
|
||||
Type WebSocketEventType `json:"type"`
|
||||
Status int `json:"status,omitempty"`
|
||||
Error *WebSocketErrorBody `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// WebSocketErrorBody is the error detail within a WebSocketErrorEvent.
|
||||
type WebSocketErrorBody struct {
|
||||
Code string `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Param string `json:"param,omitempty"`
|
||||
}
|
||||
|
||||
// WebSocketConfig provides optional tuning for WebSocket gateway features.
|
||||
// WebSocket is always enabled. These fields allow overriding the high defaults.
|
||||
type WebSocketConfig struct {
|
||||
MaxConnections int `json:"max_connections_per_user"`
|
||||
TranscriptBufferSize int `json:"transcript_buffer_size"`
|
||||
Pool *WSPoolConfig `json:"pool,omitempty"`
|
||||
}
|
||||
|
||||
// WSPoolConfig configures the upstream WebSocket connection pool.
|
||||
type WSPoolConfig struct {
|
||||
MaxIdlePerKey int `json:"max_idle_per_key"`
|
||||
MaxTotalConnections int `json:"max_total_connections"`
|
||||
IdleTimeoutSeconds int `json:"idle_timeout_seconds"`
|
||||
MaxConnectionLifetimeSeconds int `json:"max_connection_lifetime_seconds"`
|
||||
}
|
||||
|
||||
// Default pool configuration values (set high for production workloads)
|
||||
const (
|
||||
DefaultWSMaxIdlePerKey = 50
|
||||
DefaultWSMaxTotalConnections = 1000
|
||||
DefaultWSIdleTimeoutSeconds = 600
|
||||
DefaultWSMaxConnectionLifetimeSeconds = 7200
|
||||
DefaultWSMaxConnections = 100
|
||||
DefaultWSTranscriptBufferSize = 100
|
||||
)
|
||||
|
||||
// CheckAndSetDefaults fills in default values for WebSocketConfig.
|
||||
func (c *WebSocketConfig) CheckAndSetDefaults() {
|
||||
if c.MaxConnections <= 0 {
|
||||
c.MaxConnections = DefaultWSMaxConnections
|
||||
}
|
||||
if c.TranscriptBufferSize <= 0 {
|
||||
c.TranscriptBufferSize = DefaultWSTranscriptBufferSize
|
||||
}
|
||||
if c.Pool == nil {
|
||||
c.Pool = &WSPoolConfig{}
|
||||
}
|
||||
c.Pool.CheckAndSetDefaults()
|
||||
}
|
||||
|
||||
// CheckAndSetDefaults fills in default values for WSPoolConfig.
|
||||
func (c *WSPoolConfig) CheckAndSetDefaults() {
|
||||
if c.MaxIdlePerKey <= 0 {
|
||||
c.MaxIdlePerKey = DefaultWSMaxIdlePerKey
|
||||
}
|
||||
if c.MaxTotalConnections <= 0 {
|
||||
c.MaxTotalConnections = DefaultWSMaxTotalConnections
|
||||
}
|
||||
if c.IdleTimeoutSeconds <= 0 {
|
||||
c.IdleTimeoutSeconds = DefaultWSIdleTimeoutSeconds
|
||||
}
|
||||
if c.MaxConnectionLifetimeSeconds <= 0 {
|
||||
c.MaxConnectionLifetimeSeconds = DefaultWSMaxConnectionLifetimeSeconds
|
||||
}
|
||||
}
|
||||
59
core/schemas/websocket_test.go
Normal file
59
core/schemas/websocket_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWebSocketConfig_CheckAndSetDefaults(t *testing.T) {
|
||||
config := &WebSocketConfig{}
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
if config.MaxConnections != DefaultWSMaxConnections {
|
||||
t.Errorf("expected MaxConnections=%d, got %d", DefaultWSMaxConnections, config.MaxConnections)
|
||||
}
|
||||
if config.TranscriptBufferSize != DefaultWSTranscriptBufferSize {
|
||||
t.Errorf("expected TranscriptBufferSize=%d, got %d", DefaultWSTranscriptBufferSize, config.TranscriptBufferSize)
|
||||
}
|
||||
if config.Pool == nil {
|
||||
t.Fatal("expected Pool to be initialized")
|
||||
}
|
||||
if config.Pool.MaxIdlePerKey != DefaultWSMaxIdlePerKey {
|
||||
t.Errorf("expected Pool.MaxIdlePerKey=%d, got %d", DefaultWSMaxIdlePerKey, config.Pool.MaxIdlePerKey)
|
||||
}
|
||||
if config.Pool.MaxTotalConnections != DefaultWSMaxTotalConnections {
|
||||
t.Errorf("expected Pool.MaxTotalConnections=%d, got %d", DefaultWSMaxTotalConnections, config.Pool.MaxTotalConnections)
|
||||
}
|
||||
if config.Pool.IdleTimeoutSeconds != DefaultWSIdleTimeoutSeconds {
|
||||
t.Errorf("expected Pool.IdleTimeoutSeconds=%d, got %d", DefaultWSIdleTimeoutSeconds, config.Pool.IdleTimeoutSeconds)
|
||||
}
|
||||
if config.Pool.MaxConnectionLifetimeSeconds != DefaultWSMaxConnectionLifetimeSeconds {
|
||||
t.Errorf("expected Pool.MaxConnectionLifetimeSeconds=%d, got %d", DefaultWSMaxConnectionLifetimeSeconds, config.Pool.MaxConnectionLifetimeSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketConfig_PreservesExistingValues(t *testing.T) {
|
||||
config := &WebSocketConfig{
|
||||
MaxConnections: 20,
|
||||
TranscriptBufferSize: 123,
|
||||
Pool: &WSPoolConfig{
|
||||
MaxIdlePerKey: 5,
|
||||
MaxTotalConnections: 50,
|
||||
IdleTimeoutSeconds: 60,
|
||||
MaxConnectionLifetimeSeconds: 1800,
|
||||
},
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
if config.MaxConnections != 20 {
|
||||
t.Errorf("expected MaxConnections=20, got %d", config.MaxConnections)
|
||||
}
|
||||
if config.TranscriptBufferSize != 123 {
|
||||
t.Errorf("expected TranscriptBufferSize=123, got %d", config.TranscriptBufferSize)
|
||||
}
|
||||
if config.Pool.MaxIdlePerKey != 5 {
|
||||
t.Errorf("expected Pool.MaxIdlePerKey=5, got %d", config.Pool.MaxIdlePerKey)
|
||||
}
|
||||
if config.Pool.MaxTotalConnections != 50 {
|
||||
t.Errorf("expected Pool.MaxTotalConnections=50, got %d", config.Pool.MaxTotalConnections)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user