628 lines
21 KiB
Go
628 lines
21 KiB
Go
package bifrost
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"math/rand"
|
|
"net"
|
|
"net/url"
|
|
"slices"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/maximhq/bifrost/core/mcp"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
)
|
|
|
|
// retryableStatusCodes is the set of HTTP status codes flagged as retryable.
// Codes that are absent from the map read as false.
var retryableStatusCodes = map[int]bool{
	429: true, // Too Many Requests
	500: true, // Internal Server Error
	502: true, // Bad Gateway
	503: true, // Service Unavailable
	504: true, // Gateway Timeout
}
|
|
|
|
// rateLimitPatterns holds the substrings that mark an error message as a
// rate-limit error. IsRateLimitErrorMessage lower-cases the message before
// matching, so every entry here must be lower-case.
var rateLimitPatterns = []string{
	// Generic rate-limit wording.
	"rate limit",
	"rate_limit",
	"ratelimit",
	"too many requests",
	"quota exceeded",
	"quota_exceeded",
	"request limit",
	"throttled",
	"throttling",
	"rate exceeded",
	"limit exceeded",
	// Per-interval request and token limits.
	"requests per",
	"rpm exceeded",
	"tpm exceeded",
	"tokens per minute",
	"requests per minute",
	"requests per second",
	// Miscellaneous provider phrasings.
	"api rate limit",
	"usage limit",
	"concurrent requests limit",
	"burst_rate",
	"rate increased",
}
|
|
|
|
// dynamicallyConfigurableProviders is the list of providers that can be dynamically configured.
|
|
// Excluding providers that require extra configuration (e.g. Ollama, SGL, vLLM).
|
|
var dynamicallyConfigurableProviders = []schemas.ModelProvider{
|
|
schemas.Anthropic,
|
|
schemas.Azure,
|
|
schemas.Bedrock,
|
|
schemas.Cerebras,
|
|
schemas.Cohere,
|
|
schemas.Elevenlabs,
|
|
schemas.Gemini,
|
|
schemas.Groq,
|
|
schemas.HuggingFace,
|
|
schemas.Mistral,
|
|
schemas.Nebius,
|
|
schemas.OpenAI,
|
|
schemas.OpenRouter,
|
|
schemas.Parasail,
|
|
schemas.Perplexity,
|
|
schemas.Vertex,
|
|
schemas.XAI,
|
|
}
|
|
|
|
// isModelRequired returns true if the request type requires a model
|
|
func isModelRequired(reqType schemas.RequestType) bool {
|
|
return reqType == schemas.TextCompletionRequest || reqType == schemas.TextCompletionStreamRequest || reqType == schemas.ChatCompletionRequest || reqType == schemas.ChatCompletionStreamRequest || reqType == schemas.ResponsesRequest || reqType == schemas.ResponsesStreamRequest || reqType == schemas.SpeechRequest || reqType == schemas.SpeechStreamRequest || reqType == schemas.TranscriptionRequest || reqType == schemas.TranscriptionStreamRequest || reqType == schemas.EmbeddingRequest || reqType == schemas.ImageGenerationRequest || reqType == schemas.ImageGenerationStreamRequest || reqType == schemas.VideoGenerationRequest
|
|
}
|
|
|
|
// Ptr returns a pointer to a fresh copy of v. Each call yields a distinct
// pointer, so mutating one result never affects another.
func Ptr[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
|
|
|
|
// providerRequiresKey returns true if the given provider requires an API key for authentication.
|
|
func providerRequiresKey(customConfig *schemas.CustomProviderConfig) bool {
|
|
// Keyless custom providers are not allowed for Bedrock.
|
|
if customConfig != nil && customConfig.IsKeyLess && customConfig.BaseProviderType != schemas.Bedrock {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
// CanProviderKeyValueBeEmpty returns true if the given provider allows the API key to be empty.
|
|
// Some providers like Vertex and Bedrock have their credentials in additional key configs.
|
|
// Ollama and SGL are keyless (API Key is optional) but use per-key server URLs.
|
|
func CanProviderKeyValueBeEmpty(providerKey schemas.ModelProvider) bool {
|
|
return providerKey == schemas.Vertex || providerKey == schemas.Bedrock || providerKey == schemas.VLLM || providerKey == schemas.Azure || providerKey == schemas.Ollama || providerKey == schemas.SGL
|
|
}
|
|
|
|
func isKeySkippingAllowed(providerKey schemas.ModelProvider) bool {
|
|
return providerKey != schemas.Azure && providerKey != schemas.Bedrock && providerKey != schemas.Vertex
|
|
}
|
|
|
|
// calculateBackoff implements exponential backoff with jitter for retry attempts.
|
|
func calculateBackoff(attempt int, config *schemas.ProviderConfig) time.Duration {
|
|
// Calculate an exponential backoff: initial * 2^attempt
|
|
backoff := min(config.NetworkConfig.RetryBackoffInitial*time.Duration(1<<uint(attempt)), config.NetworkConfig.RetryBackoffMax)
|
|
// Add jitter (20%)
|
|
jitter := float64(backoff) * (0.8 + 0.4*rand.Float64())
|
|
result := time.Duration(jitter)
|
|
// Ensure we never exceed the configured maximum
|
|
return min(result, config.NetworkConfig.RetryBackoffMax)
|
|
}
|
|
|
|
// validateRequest validates the given request.
|
|
func validateRequest(req *schemas.BifrostRequest) *schemas.BifrostError {
|
|
if req == nil {
|
|
return newBifrostErrorFromMsg("bifrost request cannot be nil")
|
|
}
|
|
provider, model, _ := req.GetRequestFields()
|
|
if provider == "" {
|
|
return newBifrostErrorFromMsg("provider is required")
|
|
}
|
|
if isModelRequired(req.RequestType) && model == "" {
|
|
return newBifrostErrorFromMsg("model is required")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// validateKey validates the given key.
|
|
func validateKey(providerKey schemas.ModelProvider, key *schemas.Key) error {
|
|
// Validate the key for the provider
|
|
switch providerKey {
|
|
case schemas.Azure:
|
|
if key.AzureKeyConfig == nil {
|
|
return fmt.Errorf("azure_key_config is required")
|
|
}
|
|
if key.AzureKeyConfig.Endpoint.GetValue() == "" {
|
|
return fmt.Errorf("azure_key_config.endpoint is required")
|
|
}
|
|
case schemas.Bedrock:
|
|
// BedrockKeyConfig is optional — an empty config is valid for IRSA / ambient credential auth.
|
|
if key.BedrockKeyConfig == nil {
|
|
key.BedrockKeyConfig = &schemas.BedrockKeyConfig{}
|
|
}
|
|
case schemas.Vertex:
|
|
if key.VertexKeyConfig == nil {
|
|
return fmt.Errorf("vertex_key_config is required")
|
|
}
|
|
case schemas.VLLM:
|
|
if key.VLLMKeyConfig == nil {
|
|
return fmt.Errorf("vllm_key_config is required")
|
|
}
|
|
if key.VLLMKeyConfig.URL.GetValue() == "" {
|
|
return fmt.Errorf("vllm_key_config.url is required")
|
|
}
|
|
case schemas.Ollama:
|
|
if key.OllamaKeyConfig == nil {
|
|
return fmt.Errorf("ollama_key_config is required")
|
|
}
|
|
if key.OllamaKeyConfig.URL.GetValue() == "" {
|
|
return fmt.Errorf("ollama_key_config.url is required")
|
|
}
|
|
case schemas.SGL:
|
|
if key.SGLKeyConfig == nil {
|
|
return fmt.Errorf("sgl_key_config is required")
|
|
}
|
|
if key.SGLKeyConfig.URL.GetValue() == "" {
|
|
return fmt.Errorf("sgl_key_config.url is required")
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// IsRateLimitErrorMessage checks if an error message indicates a rate limit issue
|
|
func IsRateLimitErrorMessage(errorMessage string) bool {
|
|
if errorMessage == "" {
|
|
return false
|
|
}
|
|
|
|
// Convert to lowercase for case-insensitive matching
|
|
lowerMessage := strings.ToLower(errorMessage)
|
|
|
|
// Check if any rate limit pattern is found in the error message
|
|
for _, pattern := range rateLimitPatterns {
|
|
if strings.Contains(lowerMessage, pattern) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// newBifrostError wraps a standard error into a BifrostError with IsBifrostError set to false.
|
|
// This helper function reduces code duplication when handling non-Bifrost errors.
|
|
func newBifrostError(err error) *schemas.BifrostError {
|
|
return &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{
|
|
Message: err.Error(),
|
|
Error: err,
|
|
},
|
|
}
|
|
}
|
|
|
|
// newBifrostErrorFromMsg creates a BifrostError with a custom message.
|
|
// This helper function is used for static error messages.
|
|
func newBifrostErrorFromMsg(message string) *schemas.BifrostError {
|
|
return &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{
|
|
Message: message,
|
|
},
|
|
}
|
|
}
|
|
|
|
// newBifrostCtxDoneError creates a BifrostError from a cancelled/expired context.
|
|
// It distinguishes DeadlineExceeded (504 RequestTimedOut) from Canceled (499 RequestCancelled).
|
|
func newBifrostCtxDoneError(ctx *schemas.BifrostContext, stage string) *schemas.BifrostError {
|
|
var statusCode int
|
|
var errorType string
|
|
var message string
|
|
|
|
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
|
statusCode = 504
|
|
errorType = schemas.RequestTimedOut
|
|
message = fmt.Sprintf("request timed out %s: %v", stage, ctx.Err())
|
|
} else {
|
|
statusCode = 499
|
|
errorType = schemas.RequestCancelled
|
|
message = fmt.Sprintf("request cancelled %s: %v", stage, ctx.Err())
|
|
}
|
|
|
|
return &schemas.BifrostError{
|
|
IsBifrostError: true,
|
|
StatusCode: &statusCode,
|
|
AllowFallbacks: new(false),
|
|
Error: &schemas.ErrorField{
|
|
Type: &errorType,
|
|
Message: message,
|
|
Error: ctx.Err(),
|
|
},
|
|
}
|
|
}
|
|
|
|
// newBifrostMessageChan creates a channel that sends a bifrost response.
|
|
// It is used to send a bifrost response to the client.
|
|
func newBifrostMessageChan(message *schemas.BifrostResponse) chan *schemas.BifrostStreamChunk {
|
|
ch := make(chan *schemas.BifrostStreamChunk)
|
|
|
|
go func() {
|
|
defer close(ch)
|
|
ch <- &schemas.BifrostStreamChunk{
|
|
BifrostTextCompletionResponse: message.TextCompletionResponse,
|
|
BifrostChatResponse: message.ChatResponse,
|
|
BifrostResponsesStreamResponse: message.ResponsesStreamResponse,
|
|
BifrostSpeechStreamResponse: message.SpeechStreamResponse,
|
|
BifrostTranscriptionStreamResponse: message.TranscriptionStreamResponse,
|
|
}
|
|
}()
|
|
|
|
return ch
|
|
}
|
|
|
|
// clearCtxForFallback clears the ctx values which are not applicable for fallback requests.
|
|
func clearCtxForFallback(ctx *schemas.BifrostContext) {
|
|
ctx.ClearValue(schemas.BifrostContextKeyAPIKeyID)
|
|
ctx.ClearValue(schemas.BifrostContextKeyAPIKeyName)
|
|
ctx.ClearValue(schemas.BifrostContextKeyGovernanceIncludeOnlyKeys)
|
|
ctx.ClearValue(schemas.BifrostContextKeyChangeRequestType)
|
|
ctx.ClearValue(schemas.BifrostContextKeyAttemptTrail)
|
|
ctx.ClearValue(schemas.BifrostContextKeyStreamEndIndicator)
|
|
}
|
|
|
|
var supportedBaseProvidersSet = func() map[schemas.ModelProvider]struct{} {
|
|
m := make(map[schemas.ModelProvider]struct{}, len(schemas.SupportedBaseProviders))
|
|
for _, p := range schemas.SupportedBaseProviders {
|
|
m[p] = struct{}{}
|
|
}
|
|
return m
|
|
}()
|
|
|
|
// IsSupportedBaseProvider reports whether providerKey is allowed as a base provider
|
|
// for custom providers.
|
|
func IsSupportedBaseProvider(providerKey schemas.ModelProvider) bool {
|
|
_, ok := supportedBaseProvidersSet[providerKey]
|
|
return ok
|
|
}
|
|
|
|
var standardProvidersSet = func() map[schemas.ModelProvider]struct{} {
|
|
m := make(map[schemas.ModelProvider]struct{}, len(schemas.StandardProviders))
|
|
for _, p := range schemas.StandardProviders {
|
|
m[p] = struct{}{}
|
|
}
|
|
return m
|
|
}()
|
|
|
|
// IsStandardProvider reports whether providerKey is a built-in (non-custom) provider.
|
|
func IsStandardProvider(providerKey schemas.ModelProvider) bool {
|
|
_, ok := standardProvidersSet[providerKey]
|
|
return ok
|
|
}
|
|
|
|
// IsStreamRequestType returns true if the given request type is a stream request.
|
|
func IsStreamRequestType(reqType schemas.RequestType) bool {
|
|
return reqType == schemas.TextCompletionStreamRequest || reqType == schemas.ChatCompletionStreamRequest || reqType == schemas.ResponsesStreamRequest || reqType == schemas.SpeechStreamRequest || reqType == schemas.TranscriptionStreamRequest || reqType == schemas.ImageGenerationStreamRequest || reqType == schemas.ImageEditStreamRequest || reqType == schemas.PassthroughStreamRequest || reqType == schemas.WebSocketResponsesRequest || reqType == schemas.RealtimeRequest
|
|
}
|
|
|
|
func GetTracerFromContext(ctx *schemas.BifrostContext) (schemas.Tracer, string, error) {
|
|
tracer, ok := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer)
|
|
if !ok || tracer == nil {
|
|
return nil, "", fmt.Errorf("tracer not found in context")
|
|
}
|
|
traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string)
|
|
if !ok || traceID == "" {
|
|
return nil, "", fmt.Errorf("traceID not found in context")
|
|
}
|
|
return tracer, traceID, nil
|
|
}
|
|
|
|
// isBatchRequestType returns true if the given request type is a batch API operation.
|
|
func isBatchRequestType(reqType schemas.RequestType) bool {
|
|
return reqType == schemas.BatchCreateRequest || reqType == schemas.BatchListRequest || reqType == schemas.BatchRetrieveRequest || reqType == schemas.BatchCancelRequest || reqType == schemas.BatchDeleteRequest || reqType == schemas.BatchResultsRequest
|
|
}
|
|
|
|
// isFileRequestType returns true if the given request type is a file API operation.
|
|
func isFileRequestType(reqType schemas.RequestType) bool {
|
|
return reqType == schemas.FileUploadRequest || reqType == schemas.FileListRequest || reqType == schemas.FileRetrieveRequest || reqType == schemas.FileDeleteRequest || reqType == schemas.FileContentRequest
|
|
}
|
|
|
|
// isContainerRequestType returns true if the given request type is a container API operation.
|
|
func isContainerRequestType(reqType schemas.RequestType) bool {
|
|
return reqType == schemas.ContainerCreateRequest || reqType == schemas.ContainerListRequest ||
|
|
reqType == schemas.ContainerRetrieveRequest || reqType == schemas.ContainerDeleteRequest ||
|
|
reqType == schemas.ContainerFileCreateRequest || reqType == schemas.ContainerFileListRequest ||
|
|
reqType == schemas.ContainerFileRetrieveRequest || reqType == schemas.ContainerFileContentRequest ||
|
|
reqType == schemas.ContainerFileDeleteRequest
|
|
}
|
|
|
|
// isModellessVideoRequestType returns true if the given request type is a video request that does not require a model.
|
|
func isModellessVideoRequestType(reqType schemas.RequestType) bool {
|
|
switch reqType {
|
|
case schemas.VideoRetrieveRequest, schemas.VideoDownloadRequest, schemas.VideoListRequest,
|
|
schemas.VideoDeleteRequest, schemas.VideoRemixRequest:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// isPassthroughRequestType returns true if the given request type is a passthrough request.
|
|
func isPassthroughRequestType(reqType schemas.RequestType) bool {
|
|
return reqType == schemas.PassthroughRequest || reqType == schemas.PassthroughStreamRequest
|
|
}
|
|
|
|
// IsFinalChunk returns true if the given context is a final chunk.
|
|
func IsFinalChunk(ctx *schemas.BifrostContext) bool {
|
|
if ctx == nil {
|
|
return false
|
|
}
|
|
|
|
isStreamEndIndicator := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator)
|
|
if isStreamEndIndicator == nil {
|
|
return false
|
|
}
|
|
|
|
if f, ok := isStreamEndIndicator.(bool); ok {
|
|
return f
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// GetResponseFields extracts the request type, provider, original model, and resolved model from the result or error.
|
|
func GetResponseFields(result *schemas.BifrostResponse, err *schemas.BifrostError) (requestType schemas.RequestType, provider schemas.ModelProvider, originalModel string, resolvedModel string) {
|
|
if result != nil {
|
|
extraFields := result.GetExtraFields()
|
|
return extraFields.RequestType, extraFields.Provider, extraFields.OriginalModelRequested, extraFields.ResolvedModelUsed
|
|
}
|
|
if err != nil {
|
|
return err.ExtraFields.RequestType, err.ExtraFields.Provider, err.ExtraFields.OriginalModelRequested, err.ExtraFields.ResolvedModelUsed
|
|
}
|
|
return
|
|
}
|
|
|
|
// MarshalUnsafe encodes v as JSON with HTML escaping disabled, so characters
// such as <, > and & are emitted literally. It returns the empty string if
// encoding fails.
func MarshalUnsafe(v any) string {
	var out bytes.Buffer

	enc := json.NewEncoder(&out)
	enc.SetEscapeHTML(false)
	if encodeErr := enc.Encode(v); encodeErr != nil {
		return ""
	}

	// Encode terminates its output with a newline; strip it.
	encoded := out.String()
	return strings.TrimSpace(encoded)
}
|
|
|
|
func GetErrorMessage(err *schemas.BifrostError) string {
|
|
if err == nil {
|
|
return ""
|
|
}
|
|
if err.Error != nil && err.Error.Message != "" {
|
|
return err.Error.Message
|
|
} else if err.StatusCode != nil {
|
|
switch *err.StatusCode {
|
|
case 401:
|
|
return "unauthorized"
|
|
case 403:
|
|
return "forbidden"
|
|
case 404:
|
|
return "endpoint not found"
|
|
case 405:
|
|
return "method not allowed"
|
|
case 429:
|
|
return "rate limit exceeded"
|
|
case 500:
|
|
return "internal server error"
|
|
case 502:
|
|
return "bad gateway"
|
|
case 503:
|
|
return "service unavailable"
|
|
case 504:
|
|
return "gateway timeout"
|
|
default:
|
|
if err.Error != nil && err.Error.Message != "" {
|
|
return err.Error.Message
|
|
}
|
|
return fmt.Sprintf("HTTP %d error", *err.StatusCode)
|
|
}
|
|
} else if err.Type != nil {
|
|
return *err.Type
|
|
} else {
|
|
return "unknown error"
|
|
}
|
|
}
|
|
|
|
// GetStringFromContext returns the string stored in ctx under key, or "" when
// the key is absent or its value is not a string.
func GetStringFromContext(ctx context.Context, key any) string {
	// A failed comma-ok assertion (nil or non-string value) yields "".
	str, _ := ctx.Value(key).(string)
	return str
}
|
|
|
|
// GetIntFromContext returns the int stored in ctx under key, or 0 when the
// key is absent or its value is not exactly of type int.
func GetIntFromContext(ctx context.Context, key any) int {
	// A failed comma-ok assertion (nil or non-int value) yields 0.
	n, _ := ctx.Value(key).(int)
	return n
}
|
|
|
|
// GetBoolFromContext returns the bool stored in ctx under key, or false when
// the key is absent or its value is not a bool.
func GetBoolFromContext(ctx context.Context, key any) bool {
	// A failed comma-ok assertion (nil or non-bool value) yields false.
	flag, _ := ctx.Value(key).(bool)
	return flag
}
|
|
|
|
// RedactSensitiveString masks a sensitive value for display.
//
// An empty input yields "". Inputs of at most 8 characters are replaced
// entirely by "[REDACTED]". Longer inputs keep their first 4 and last 4
// characters for identification, with "[REDACTED]" in between.
//
// Lengths and slicing are measured in runes rather than bytes, so multi-byte
// UTF-8 characters are never cut in half. Output for ASCII input is the same
// as with byte-based slicing.
func RedactSensitiveString(s string) string {
	if s == "" {
		return ""
	}

	const (
		visible = 4
		mask    = "[REDACTED]"
	)

	runes := []rune(s)
	if len(runes) <= 2*visible {
		// Too short to reveal anything without exposing most of the value.
		return mask
	}
	return string(runes[:visible]) + mask + string(runes[len(runes)-visible:])
}
|
|
|
|
// ValidateExternalURL validates a URL for security concerns (SSRF protection)
|
|
func ValidateExternalURL(urlStr string) error {
|
|
if urlStr == "" {
|
|
return fmt.Errorf("URL cannot be empty")
|
|
}
|
|
// Parse the URL
|
|
parsedURL, err := url.Parse(urlStr)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid URL format: %w", err)
|
|
}
|
|
// Only allow HTTPS scheme (or HTTP for localhost in development)
|
|
if parsedURL.Scheme != "https" && parsedURL.Scheme != "http" {
|
|
return fmt.Errorf("only https and http schemes are allowed, got: %s", parsedURL.Scheme)
|
|
}
|
|
// Extract hostname
|
|
hostname := parsedURL.Hostname()
|
|
if hostname == "" {
|
|
return fmt.Errorf("URL must have a hostname")
|
|
}
|
|
// Block localhost and loopback addresses
|
|
if isLocalhost(hostname) {
|
|
return fmt.Errorf("localhost and loopback addresses are not allowed")
|
|
}
|
|
// Resolve hostname to IP addresses
|
|
ips, err := net.LookupIP(hostname)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to resolve hostname: %w", err)
|
|
}
|
|
// Check if any resolved IP is private
|
|
for _, ip := range ips {
|
|
if isPrivateIP(ip) {
|
|
return fmt.Errorf("private IP addresses are not allowed")
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// isLocalhost reports whether hostname denotes the local machine.
//
// It matches the name "localhost" case-insensitively, with or without a
// trailing root dot, and any IP literal that is a loopback address
// (127.0.0.0/8, ::1 in any textual form) or the unspecified address
// (0.0.0.0, ::). Names that are not IP literals are not resolved here.
func isLocalhost(hostname string) bool {
	// "localhost." is the fully-qualified spelling of the same name.
	host := strings.TrimSuffix(hostname, ".")

	// Host names are case-insensitive.
	if strings.EqualFold(host, "localhost") {
		return true
	}

	// Parsing the literal covers every spelling of a loopback/unspecified
	// address rather than a fixed list of strings.
	if ip := net.ParseIP(host); ip != nil {
		return ip.IsLoopback() || ip.IsUnspecified()
	}
	return false
}
|
|
|
|
// isPrivateIP reports whether ip must be treated as a non-public address.
//
// It returns true for:
//   - private ranges: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 and IPv6
//     unique local addresses (fc00::/7);
//   - loopback: 127.0.0.0/8 and ::1;
//   - link-local unicast: 169.254.0.0/16 and fe80::/10;
//   - the unspecified addresses 0.0.0.0 and ::.
//
// IPv4-mapped IPv6 addresses are classified by their embedded IPv4 address.
// A nil ip yields false.
func isPrivateIP(ip net.IP) bool {
	return ip.IsPrivate() ||
		ip.IsLoopback() ||
		ip.IsLinkLocalUnicast() ||
		ip.IsUnspecified()
}
|
|
|
|
// sanitizeSpanName normalizes a span name by lower-casing it and replacing
// every space with a hyphen.
func sanitizeSpanName(name string) string {
	lowered := strings.ToLower(name)
	return strings.ReplaceAll(lowered, " ", "-")
}
|
|
|
|
// IsCodemodeTool returns true if the given tool name is a codemode tool.
|
|
func IsCodemodeTool(toolName string) bool {
|
|
return mcp.IsCodeModeTool(toolName)
|
|
}
|
|
|
|
// hashSHA256 returns the SHA-256 digest of value as a lower-case hex string.
// The same input always produces the same output.
func hashSHA256(value string) string {
	hasher := sha256.New()
	hasher.Write([]byte(value)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil))
}
|
|
|
|
func buildSessionKey(providerKey schemas.ModelProvider, sessionID string, model string) string {
|
|
// Hash session ID to prevent PII leakage and ensure bounded key size
|
|
hashedSessionID := hashSHA256(sessionID)
|
|
discriminator := model
|
|
if discriminator == "" {
|
|
discriminator = "__modelless__"
|
|
}
|
|
return "session:" + string(providerKey) + ":" + hashedSessionID + ":" + hashSHA256(discriminator)
|
|
}
|
|
|
|
// isPromptOptionalImageEditType reports whether the image-edit task type t is
// one that does not need a text prompt. A nil t yields false.
// The value is trimmed, lower-cased and has hyphens turned into underscores
// before matching, so "erase-object" and " Erase_Object " are equivalent.
func isPromptOptionalImageEditType(t *string) bool {
	if t == nil {
		return false
	}

	promptOptional := []string{
		"background_removal",
		"remove_background",
		"remove_bg",
		"erase_object",
		"upscale_fast",
	}

	taskType := strings.ReplaceAll(*t, "-", "_")
	taskType = strings.TrimSpace(taskType)
	taskType = strings.ToLower(taskType)

	return slices.Contains(promptOptional, taskType)
}
|
|
|
|
// wrapConvertedStreamPostHookRunner wraps a PostHookRunner so that streaming
|
|
// responses produced by a type-converted request are converted back to the
|
|
// caller's original type before the post-hook runs.
|
|
func wrapConvertedStreamPostHookRunner(postHookRunner schemas.PostHookRunner, targetType schemas.RequestType) schemas.PostHookRunner {
|
|
return func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
|
|
if result != nil {
|
|
switch targetType {
|
|
case schemas.ChatCompletionRequest:
|
|
// text→chat: convert chat stream chunk back to text completion
|
|
if result.ChatResponse != nil {
|
|
if converted := result.ChatResponse.ToBifrostTextCompletionResponse(); converted != nil {
|
|
result = &schemas.BifrostResponse{TextCompletionResponse: converted}
|
|
}
|
|
}
|
|
case schemas.ResponsesRequest:
|
|
// chat→responses: convert responses stream chunk back to chat
|
|
if result.ResponsesStreamResponse != nil {
|
|
if converted := result.ResponsesStreamResponse.ToBifrostChatResponse(); converted != nil {
|
|
result = &schemas.BifrostResponse{ChatResponse: converted}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return postHookRunner(ctx, result, bifrostErr)
|
|
}
|
|
}
|