first commit
This commit is contained in:
627
core/utils.go
Normal file
627
core/utils.go
Normal file
@@ -0,0 +1,627 @@
|
||||
package bifrost
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/mcp"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// Define a set of retryable status codes
|
||||
var retryableStatusCodes = map[int]bool{
|
||||
500: true, // Internal Server Error
|
||||
502: true, // Bad Gateway
|
||||
503: true, // Service Unavailable
|
||||
504: true, // Gateway Timeout
|
||||
429: true, // Too Many Requests
|
||||
}
|
||||
|
||||
// Define rate limit error message patterns (case-insensitive)
|
||||
var rateLimitPatterns = []string{
|
||||
"rate limit",
|
||||
"rate_limit",
|
||||
"ratelimit",
|
||||
"too many requests",
|
||||
"quota exceeded",
|
||||
"quota_exceeded",
|
||||
"request limit",
|
||||
"throttled",
|
||||
"throttling",
|
||||
"rate exceeded",
|
||||
"limit exceeded",
|
||||
"requests per",
|
||||
"rpm exceeded",
|
||||
"tpm exceeded",
|
||||
"tokens per minute",
|
||||
"requests per minute",
|
||||
"requests per second",
|
||||
"api rate limit",
|
||||
"usage limit",
|
||||
"concurrent requests limit",
|
||||
"burst_rate",
|
||||
"rate increased",
|
||||
}
|
||||
|
||||
// dynamicallyConfigurableProviders is the list of providers that can be dynamically configured.
|
||||
// Excluding providers that require extra configuration (e.g. Ollama, SGL, vLLM).
|
||||
var dynamicallyConfigurableProviders = []schemas.ModelProvider{
|
||||
schemas.Anthropic,
|
||||
schemas.Azure,
|
||||
schemas.Bedrock,
|
||||
schemas.Cerebras,
|
||||
schemas.Cohere,
|
||||
schemas.Elevenlabs,
|
||||
schemas.Gemini,
|
||||
schemas.Groq,
|
||||
schemas.HuggingFace,
|
||||
schemas.Mistral,
|
||||
schemas.Nebius,
|
||||
schemas.OpenAI,
|
||||
schemas.OpenRouter,
|
||||
schemas.Parasail,
|
||||
schemas.Perplexity,
|
||||
schemas.Vertex,
|
||||
schemas.XAI,
|
||||
}
|
||||
|
||||
// isModelRequired returns true if the request type requires a model
|
||||
func isModelRequired(reqType schemas.RequestType) bool {
|
||||
return reqType == schemas.TextCompletionRequest || reqType == schemas.TextCompletionStreamRequest || reqType == schemas.ChatCompletionRequest || reqType == schemas.ChatCompletionStreamRequest || reqType == schemas.ResponsesRequest || reqType == schemas.ResponsesStreamRequest || reqType == schemas.SpeechRequest || reqType == schemas.SpeechStreamRequest || reqType == schemas.TranscriptionRequest || reqType == schemas.TranscriptionStreamRequest || reqType == schemas.EmbeddingRequest || reqType == schemas.ImageGenerationRequest || reqType == schemas.ImageGenerationStreamRequest || reqType == schemas.VideoGenerationRequest
|
||||
}
|
||||
|
||||
// Ptr returns a pointer to the given value.
|
||||
func Ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
// providerRequiresKey returns true if the given provider requires an API key for authentication.
|
||||
func providerRequiresKey(customConfig *schemas.CustomProviderConfig) bool {
|
||||
// Keyless custom providers are not allowed for Bedrock.
|
||||
if customConfig != nil && customConfig.IsKeyLess && customConfig.BaseProviderType != schemas.Bedrock {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// CanProviderKeyValueBeEmpty returns true if the given provider allows the API key to be empty.
|
||||
// Some providers like Vertex and Bedrock have their credentials in additional key configs.
|
||||
// Ollama and SGL are keyless (API Key is optional) but use per-key server URLs.
|
||||
func CanProviderKeyValueBeEmpty(providerKey schemas.ModelProvider) bool {
|
||||
return providerKey == schemas.Vertex || providerKey == schemas.Bedrock || providerKey == schemas.VLLM || providerKey == schemas.Azure || providerKey == schemas.Ollama || providerKey == schemas.SGL
|
||||
}
|
||||
|
||||
func isKeySkippingAllowed(providerKey schemas.ModelProvider) bool {
|
||||
return providerKey != schemas.Azure && providerKey != schemas.Bedrock && providerKey != schemas.Vertex
|
||||
}
|
||||
|
||||
// calculateBackoff implements exponential backoff with jitter for retry attempts.
|
||||
func calculateBackoff(attempt int, config *schemas.ProviderConfig) time.Duration {
|
||||
// Calculate an exponential backoff: initial * 2^attempt
|
||||
backoff := min(config.NetworkConfig.RetryBackoffInitial*time.Duration(1<<uint(attempt)), config.NetworkConfig.RetryBackoffMax)
|
||||
// Add jitter (20%)
|
||||
jitter := float64(backoff) * (0.8 + 0.4*rand.Float64())
|
||||
result := time.Duration(jitter)
|
||||
// Ensure we never exceed the configured maximum
|
||||
return min(result, config.NetworkConfig.RetryBackoffMax)
|
||||
}
|
||||
|
||||
// validateRequest validates the given request.
|
||||
func validateRequest(req *schemas.BifrostRequest) *schemas.BifrostError {
|
||||
if req == nil {
|
||||
return newBifrostErrorFromMsg("bifrost request cannot be nil")
|
||||
}
|
||||
provider, model, _ := req.GetRequestFields()
|
||||
if provider == "" {
|
||||
return newBifrostErrorFromMsg("provider is required")
|
||||
}
|
||||
if isModelRequired(req.RequestType) && model == "" {
|
||||
return newBifrostErrorFromMsg("model is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateKey validates the given key.
|
||||
func validateKey(providerKey schemas.ModelProvider, key *schemas.Key) error {
|
||||
// Validate the key for the provider
|
||||
switch providerKey {
|
||||
case schemas.Azure:
|
||||
if key.AzureKeyConfig == nil {
|
||||
return fmt.Errorf("azure_key_config is required")
|
||||
}
|
||||
if key.AzureKeyConfig.Endpoint.GetValue() == "" {
|
||||
return fmt.Errorf("azure_key_config.endpoint is required")
|
||||
}
|
||||
case schemas.Bedrock:
|
||||
// BedrockKeyConfig is optional — an empty config is valid for IRSA / ambient credential auth.
|
||||
if key.BedrockKeyConfig == nil {
|
||||
key.BedrockKeyConfig = &schemas.BedrockKeyConfig{}
|
||||
}
|
||||
case schemas.Vertex:
|
||||
if key.VertexKeyConfig == nil {
|
||||
return fmt.Errorf("vertex_key_config is required")
|
||||
}
|
||||
case schemas.VLLM:
|
||||
if key.VLLMKeyConfig == nil {
|
||||
return fmt.Errorf("vllm_key_config is required")
|
||||
}
|
||||
if key.VLLMKeyConfig.URL.GetValue() == "" {
|
||||
return fmt.Errorf("vllm_key_config.url is required")
|
||||
}
|
||||
case schemas.Ollama:
|
||||
if key.OllamaKeyConfig == nil {
|
||||
return fmt.Errorf("ollama_key_config is required")
|
||||
}
|
||||
if key.OllamaKeyConfig.URL.GetValue() == "" {
|
||||
return fmt.Errorf("ollama_key_config.url is required")
|
||||
}
|
||||
case schemas.SGL:
|
||||
if key.SGLKeyConfig == nil {
|
||||
return fmt.Errorf("sgl_key_config is required")
|
||||
}
|
||||
if key.SGLKeyConfig.URL.GetValue() == "" {
|
||||
return fmt.Errorf("sgl_key_config.url is required")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsRateLimitErrorMessage checks if an error message indicates a rate limit issue
|
||||
func IsRateLimitErrorMessage(errorMessage string) bool {
|
||||
if errorMessage == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Convert to lowercase for case-insensitive matching
|
||||
lowerMessage := strings.ToLower(errorMessage)
|
||||
|
||||
// Check if any rate limit pattern is found in the error message
|
||||
for _, pattern := range rateLimitPatterns {
|
||||
if strings.Contains(lowerMessage, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// newBifrostError wraps a standard error into a BifrostError with IsBifrostError set to false.
|
||||
// This helper function reduces code duplication when handling non-Bifrost errors.
|
||||
func newBifrostError(err error) *schemas.BifrostError {
|
||||
return &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: err.Error(),
|
||||
Error: err,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// newBifrostErrorFromMsg creates a BifrostError with a custom message.
|
||||
// This helper function is used for static error messages.
|
||||
func newBifrostErrorFromMsg(message string) *schemas.BifrostError {
|
||||
return &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// newBifrostCtxDoneError creates a BifrostError from a cancelled/expired context.
|
||||
// It distinguishes DeadlineExceeded (504 RequestTimedOut) from Canceled (499 RequestCancelled).
|
||||
func newBifrostCtxDoneError(ctx *schemas.BifrostContext, stage string) *schemas.BifrostError {
|
||||
var statusCode int
|
||||
var errorType string
|
||||
var message string
|
||||
|
||||
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||
statusCode = 504
|
||||
errorType = schemas.RequestTimedOut
|
||||
message = fmt.Sprintf("request timed out %s: %v", stage, ctx.Err())
|
||||
} else {
|
||||
statusCode = 499
|
||||
errorType = schemas.RequestCancelled
|
||||
message = fmt.Sprintf("request cancelled %s: %v", stage, ctx.Err())
|
||||
}
|
||||
|
||||
return &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
StatusCode: &statusCode,
|
||||
AllowFallbacks: new(false),
|
||||
Error: &schemas.ErrorField{
|
||||
Type: &errorType,
|
||||
Message: message,
|
||||
Error: ctx.Err(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// newBifrostMessageChan creates a channel that sends a bifrost response.
|
||||
// It is used to send a bifrost response to the client.
|
||||
func newBifrostMessageChan(message *schemas.BifrostResponse) chan *schemas.BifrostStreamChunk {
|
||||
ch := make(chan *schemas.BifrostStreamChunk)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
ch <- &schemas.BifrostStreamChunk{
|
||||
BifrostTextCompletionResponse: message.TextCompletionResponse,
|
||||
BifrostChatResponse: message.ChatResponse,
|
||||
BifrostResponsesStreamResponse: message.ResponsesStreamResponse,
|
||||
BifrostSpeechStreamResponse: message.SpeechStreamResponse,
|
||||
BifrostTranscriptionStreamResponse: message.TranscriptionStreamResponse,
|
||||
}
|
||||
}()
|
||||
|
||||
return ch
|
||||
}
|
||||
|
||||
// clearCtxForFallback clears the ctx values which are not applicable for fallback requests.
|
||||
func clearCtxForFallback(ctx *schemas.BifrostContext) {
|
||||
ctx.ClearValue(schemas.BifrostContextKeyAPIKeyID)
|
||||
ctx.ClearValue(schemas.BifrostContextKeyAPIKeyName)
|
||||
ctx.ClearValue(schemas.BifrostContextKeyGovernanceIncludeOnlyKeys)
|
||||
ctx.ClearValue(schemas.BifrostContextKeyChangeRequestType)
|
||||
ctx.ClearValue(schemas.BifrostContextKeyAttemptTrail)
|
||||
ctx.ClearValue(schemas.BifrostContextKeyStreamEndIndicator)
|
||||
}
|
||||
|
||||
var supportedBaseProvidersSet = func() map[schemas.ModelProvider]struct{} {
|
||||
m := make(map[schemas.ModelProvider]struct{}, len(schemas.SupportedBaseProviders))
|
||||
for _, p := range schemas.SupportedBaseProviders {
|
||||
m[p] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}()
|
||||
|
||||
// IsSupportedBaseProvider reports whether providerKey is allowed as a base provider
|
||||
// for custom providers.
|
||||
func IsSupportedBaseProvider(providerKey schemas.ModelProvider) bool {
|
||||
_, ok := supportedBaseProvidersSet[providerKey]
|
||||
return ok
|
||||
}
|
||||
|
||||
var standardProvidersSet = func() map[schemas.ModelProvider]struct{} {
|
||||
m := make(map[schemas.ModelProvider]struct{}, len(schemas.StandardProviders))
|
||||
for _, p := range schemas.StandardProviders {
|
||||
m[p] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}()
|
||||
|
||||
// IsStandardProvider reports whether providerKey is a built-in (non-custom) provider.
|
||||
func IsStandardProvider(providerKey schemas.ModelProvider) bool {
|
||||
_, ok := standardProvidersSet[providerKey]
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsStreamRequestType returns true if the given request type is a stream request.
|
||||
func IsStreamRequestType(reqType schemas.RequestType) bool {
|
||||
return reqType == schemas.TextCompletionStreamRequest || reqType == schemas.ChatCompletionStreamRequest || reqType == schemas.ResponsesStreamRequest || reqType == schemas.SpeechStreamRequest || reqType == schemas.TranscriptionStreamRequest || reqType == schemas.ImageGenerationStreamRequest || reqType == schemas.ImageEditStreamRequest || reqType == schemas.PassthroughStreamRequest || reqType == schemas.WebSocketResponsesRequest || reqType == schemas.RealtimeRequest
|
||||
}
|
||||
|
||||
func GetTracerFromContext(ctx *schemas.BifrostContext) (schemas.Tracer, string, error) {
|
||||
tracer, ok := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer)
|
||||
if !ok || tracer == nil {
|
||||
return nil, "", fmt.Errorf("tracer not found in context")
|
||||
}
|
||||
traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string)
|
||||
if !ok || traceID == "" {
|
||||
return nil, "", fmt.Errorf("traceID not found in context")
|
||||
}
|
||||
return tracer, traceID, nil
|
||||
}
|
||||
|
||||
// isBatchRequestType returns true if the given request type is a batch API operation.
|
||||
func isBatchRequestType(reqType schemas.RequestType) bool {
|
||||
return reqType == schemas.BatchCreateRequest || reqType == schemas.BatchListRequest || reqType == schemas.BatchRetrieveRequest || reqType == schemas.BatchCancelRequest || reqType == schemas.BatchDeleteRequest || reqType == schemas.BatchResultsRequest
|
||||
}
|
||||
|
||||
// isFileRequestType returns true if the given request type is a file API operation.
|
||||
func isFileRequestType(reqType schemas.RequestType) bool {
|
||||
return reqType == schemas.FileUploadRequest || reqType == schemas.FileListRequest || reqType == schemas.FileRetrieveRequest || reqType == schemas.FileDeleteRequest || reqType == schemas.FileContentRequest
|
||||
}
|
||||
|
||||
// isContainerRequestType returns true if the given request type is a container API operation.
|
||||
func isContainerRequestType(reqType schemas.RequestType) bool {
|
||||
return reqType == schemas.ContainerCreateRequest || reqType == schemas.ContainerListRequest ||
|
||||
reqType == schemas.ContainerRetrieveRequest || reqType == schemas.ContainerDeleteRequest ||
|
||||
reqType == schemas.ContainerFileCreateRequest || reqType == schemas.ContainerFileListRequest ||
|
||||
reqType == schemas.ContainerFileRetrieveRequest || reqType == schemas.ContainerFileContentRequest ||
|
||||
reqType == schemas.ContainerFileDeleteRequest
|
||||
}
|
||||
|
||||
// isModellessVideoRequestType returns true if the given request type is a video request that does not require a model.
|
||||
func isModellessVideoRequestType(reqType schemas.RequestType) bool {
|
||||
switch reqType {
|
||||
case schemas.VideoRetrieveRequest, schemas.VideoDownloadRequest, schemas.VideoListRequest,
|
||||
schemas.VideoDeleteRequest, schemas.VideoRemixRequest:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// isPassthroughRequestType returns true if the given request type is a passthrough request.
|
||||
func isPassthroughRequestType(reqType schemas.RequestType) bool {
|
||||
return reqType == schemas.PassthroughRequest || reqType == schemas.PassthroughStreamRequest
|
||||
}
|
||||
|
||||
// IsFinalChunk returns true if the given context is a final chunk.
|
||||
func IsFinalChunk(ctx *schemas.BifrostContext) bool {
|
||||
if ctx == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
isStreamEndIndicator := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator)
|
||||
if isStreamEndIndicator == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if f, ok := isStreamEndIndicator.(bool); ok {
|
||||
return f
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetResponseFields extracts the request type, provider, original model, and resolved model from the result or error.
|
||||
func GetResponseFields(result *schemas.BifrostResponse, err *schemas.BifrostError) (requestType schemas.RequestType, provider schemas.ModelProvider, originalModel string, resolvedModel string) {
|
||||
if result != nil {
|
||||
extraFields := result.GetExtraFields()
|
||||
return extraFields.RequestType, extraFields.Provider, extraFields.OriginalModelRequested, extraFields.ResolvedModelUsed
|
||||
}
|
||||
if err != nil {
|
||||
return err.ExtraFields.RequestType, err.ExtraFields.Provider, err.ExtraFields.OriginalModelRequested, err.ExtraFields.ResolvedModelUsed
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalUnsafe marshals the given value to a JSON string without escaping HTML characters.
|
||||
// Returns empty string if marshaling fails.
|
||||
func MarshalUnsafe(v any) string {
|
||||
var buf bytes.Buffer
|
||||
encoder := json.NewEncoder(&buf)
|
||||
encoder.SetEscapeHTML(false)
|
||||
err := encoder.Encode(v)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
// Encode adds a trailing newline, trim it
|
||||
return strings.TrimSpace(buf.String())
|
||||
}
|
||||
|
||||
func GetErrorMessage(err *schemas.BifrostError) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
if err.Error != nil && err.Error.Message != "" {
|
||||
return err.Error.Message
|
||||
} else if err.StatusCode != nil {
|
||||
switch *err.StatusCode {
|
||||
case 401:
|
||||
return "unauthorized"
|
||||
case 403:
|
||||
return "forbidden"
|
||||
case 404:
|
||||
return "endpoint not found"
|
||||
case 405:
|
||||
return "method not allowed"
|
||||
case 429:
|
||||
return "rate limit exceeded"
|
||||
case 500:
|
||||
return "internal server error"
|
||||
case 502:
|
||||
return "bad gateway"
|
||||
case 503:
|
||||
return "service unavailable"
|
||||
case 504:
|
||||
return "gateway timeout"
|
||||
default:
|
||||
if err.Error != nil && err.Error.Message != "" {
|
||||
return err.Error.Message
|
||||
}
|
||||
return fmt.Sprintf("HTTP %d error", *err.StatusCode)
|
||||
}
|
||||
} else if err.Type != nil {
|
||||
return *err.Type
|
||||
} else {
|
||||
return "unknown error"
|
||||
}
|
||||
}
|
||||
|
||||
// GetStringFromContext safely extracts a string value from context
|
||||
func GetStringFromContext(ctx context.Context, key any) string {
|
||||
if value := ctx.Value(key); value != nil {
|
||||
if str, ok := value.(string); ok {
|
||||
return str
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetIntFromContext safely extracts an int value from context
|
||||
func GetIntFromContext(ctx context.Context, key any) int {
|
||||
if value := ctx.Value(key); value != nil {
|
||||
if intValue, ok := value.(int); ok {
|
||||
return intValue
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetBoolFromContext safely extracts a bool value from context
|
||||
func GetBoolFromContext(ctx context.Context, key any) bool {
|
||||
if value := ctx.Value(key); value != nil {
|
||||
if boolValue, ok := value.(bool); ok {
|
||||
return boolValue
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RedactSensitiveString redacts sensitive information in a string
|
||||
func RedactSensitiveString(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
// Show first 4 and last 4 characters for identification, rest is [REDACTED]
|
||||
if len(s) <= 8 {
|
||||
return "[REDACTED]"
|
||||
}
|
||||
return s[:4] + "[REDACTED]" + s[len(s)-4:]
|
||||
}
|
||||
|
||||
// ValidateExternalURL validates a URL for security concerns (SSRF protection)
|
||||
func ValidateExternalURL(urlStr string) error {
|
||||
if urlStr == "" {
|
||||
return fmt.Errorf("URL cannot be empty")
|
||||
}
|
||||
// Parse the URL
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL format: %w", err)
|
||||
}
|
||||
// Only allow HTTPS scheme (or HTTP for localhost in development)
|
||||
if parsedURL.Scheme != "https" && parsedURL.Scheme != "http" {
|
||||
return fmt.Errorf("only https and http schemes are allowed, got: %s", parsedURL.Scheme)
|
||||
}
|
||||
// Extract hostname
|
||||
hostname := parsedURL.Hostname()
|
||||
if hostname == "" {
|
||||
return fmt.Errorf("URL must have a hostname")
|
||||
}
|
||||
// Block localhost and loopback addresses
|
||||
if isLocalhost(hostname) {
|
||||
return fmt.Errorf("localhost and loopback addresses are not allowed")
|
||||
}
|
||||
// Resolve hostname to IP addresses
|
||||
ips, err := net.LookupIP(hostname)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve hostname: %w", err)
|
||||
}
|
||||
// Check if any resolved IP is private
|
||||
for _, ip := range ips {
|
||||
if isPrivateIP(ip) {
|
||||
return fmt.Errorf("private IP addresses are not allowed")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isLocalhost checks if a hostname is localhost or a loopback address
|
||||
func isLocalhost(hostname string) bool {
|
||||
return hostname == "localhost" ||
|
||||
hostname == "127.0.0.1" ||
|
||||
hostname == "::1" ||
|
||||
hostname == "0.0.0.0" ||
|
||||
hostname == "::"
|
||||
}
|
||||
|
||||
// isPrivateIP checks if an IP address is in a private range
|
||||
func isPrivateIP(ip net.IP) bool {
|
||||
// Private IPv4 ranges
|
||||
privateRanges := []string{
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"169.254.0.0/16", // Link-local
|
||||
"127.0.0.0/8", // Loopback
|
||||
}
|
||||
for _, cidr := range privateRanges {
|
||||
_, subnet, _ := net.ParseCIDR(cidr)
|
||||
if subnet.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Check for private IPv6
|
||||
if ip.To4() == nil {
|
||||
// Check for IPv6 loopback and link-local
|
||||
if ip.IsLoopback() || ip.IsLinkLocalUnicast() {
|
||||
return true
|
||||
}
|
||||
// Check for IPv6 unique local addresses (fc00::/7)
|
||||
if len(ip) == 16 && (ip[0]&0xfe) == 0xfc {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// sanitizeSpanName sanitizes a span name to remove capital letters and spaces to make it a valid span name
|
||||
func sanitizeSpanName(name string) string {
|
||||
return strings.ToLower(strings.ReplaceAll(name, " ", "-"))
|
||||
}
|
||||
|
||||
// IsCodemodeTool returns true if the given tool name is a codemode tool.
|
||||
func IsCodemodeTool(toolName string) bool {
|
||||
return mcp.IsCodeModeTool(toolName)
|
||||
}
|
||||
|
||||
// hashSHA256 returns a deterministic hex-encoded SHA-256 hash of the input.
|
||||
func hashSHA256(value string) string {
|
||||
h := sha256.Sum256([]byte(value))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
func buildSessionKey(providerKey schemas.ModelProvider, sessionID string, model string) string {
|
||||
// Hash session ID to prevent PII leakage and ensure bounded key size
|
||||
hashedSessionID := hashSHA256(sessionID)
|
||||
discriminator := model
|
||||
if discriminator == "" {
|
||||
discriminator = "__modelless__"
|
||||
}
|
||||
return "session:" + string(providerKey) + ":" + hashedSessionID + ":" + hashSHA256(discriminator)
|
||||
}
|
||||
|
||||
// isPromptOptionalImageEditType returns true for edit task types that do not require a text prompt.
|
||||
// It normalises hyphenated variants (e.g. "erase-object") to underscore form before matching.
|
||||
func isPromptOptionalImageEditType(t *string) bool {
|
||||
if t == nil {
|
||||
return false
|
||||
}
|
||||
normalized := strings.ToLower(strings.TrimSpace(*t))
|
||||
normalized = strings.ReplaceAll(normalized, "-", "_")
|
||||
return slices.Contains(
|
||||
[]string{"background_removal", "remove_background", "remove_bg", "erase_object", "upscale_fast"},
|
||||
normalized,
|
||||
)
|
||||
}
|
||||
|
||||
// wrapConvertedStreamPostHookRunner wraps a PostHookRunner so that streaming
|
||||
// responses produced by a type-converted request are converted back to the
|
||||
// caller's original type before the post-hook runs.
|
||||
func wrapConvertedStreamPostHookRunner(postHookRunner schemas.PostHookRunner, targetType schemas.RequestType) schemas.PostHookRunner {
|
||||
return func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
|
||||
if result != nil {
|
||||
switch targetType {
|
||||
case schemas.ChatCompletionRequest:
|
||||
// text→chat: convert chat stream chunk back to text completion
|
||||
if result.ChatResponse != nil {
|
||||
if converted := result.ChatResponse.ToBifrostTextCompletionResponse(); converted != nil {
|
||||
result = &schemas.BifrostResponse{TextCompletionResponse: converted}
|
||||
}
|
||||
}
|
||||
case schemas.ResponsesRequest:
|
||||
// chat→responses: convert responses stream chunk back to chat
|
||||
if result.ResponsesStreamResponse != nil {
|
||||
if converted := result.ResponsesStreamResponse.ToBifrostChatResponse(); converted != nil {
|
||||
result = &schemas.BifrostResponse{ChatResponse: converted}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return postHookRunner(ctx, result, bifrostErr)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user