226 lines
7.5 KiB
Go
226 lines
7.5 KiB
Go
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
|
|
// This file contains common utility functions used across all handlers.
|
|
package handlers
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
// pluginDisabledKey is the unexported type of the context key that marks a plugin
// as disabled rather than removed. A dedicated named type is used instead of a raw
// string, as is conventional for Go context keys.
type pluginDisabledKey struct{}

// PluginDisabledKey is the context key used to indicate a plugin is being disabled.
var PluginDisabledKey = pluginDisabledKey{}
|
|
|
|
// badRequestError marks an error as caused by invalid client input. Outer handlers
// can detect the wrapper to answer with HTTP 400 instead of treating the failure as
// an internal server error.
type badRequestError struct {
	err error // underlying validation error; must be non-nil for Error to be callable
}

// Error returns the message of the wrapped error.
func (e *badRequestError) Error() string {
	return e.err.Error()
}

// Unwrap returns the wrapped error so that errors.Is and errors.As can inspect it.
func (e *badRequestError) Unwrap() error {
	return e.err
}
|
|
|
|
// SendJSON sends a JSON response with 200 OK status
|
|
func SendJSON(ctx *fasthttp.RequestCtx, data interface{}) {
|
|
ctx.SetContentType("application/json")
|
|
if err := json.NewEncoder(ctx).Encode(data); err != nil {
|
|
logger.Warn(fmt.Sprintf("Failed to encode JSON response: %v", err))
|
|
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to encode response: %v", err))
|
|
}
|
|
}
|
|
|
|
// SendJSONWithStatus sends a JSON response with a custom status code
|
|
func SendJSONWithStatus(ctx *fasthttp.RequestCtx, data interface{}, statusCode int) {
|
|
ctx.SetContentType("application/json")
|
|
ctx.SetStatusCode(statusCode)
|
|
if err := json.NewEncoder(ctx).Encode(data); err != nil {
|
|
logger.Warn(fmt.Sprintf("Failed to encode JSON response: %v", err))
|
|
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to encode response: %v", err))
|
|
}
|
|
}
|
|
|
|
// SendError sends a BifrostError response
|
|
func SendError(ctx *fasthttp.RequestCtx, statusCode int, message string) {
|
|
bifrostErr := &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
StatusCode: &statusCode,
|
|
Error: &schemas.ErrorField{
|
|
Message: message,
|
|
},
|
|
}
|
|
SendBifrostError(ctx, bifrostErr)
|
|
}
|
|
|
|
// SendBifrostError sends a BifrostError response
|
|
func SendBifrostError(ctx *fasthttp.RequestCtx, bifrostErr *schemas.BifrostError) {
|
|
if bifrostErr.StatusCode != nil {
|
|
ctx.SetStatusCode(*bifrostErr.StatusCode)
|
|
} else if !bifrostErr.IsBifrostError {
|
|
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
|
} else {
|
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
|
}
|
|
|
|
ctx.SetContentType("application/json")
|
|
if encodeErr := json.NewEncoder(ctx).Encode(bifrostErr); encodeErr != nil {
|
|
logger.Warn(fmt.Sprintf("Failed to encode error response: %v", encodeErr))
|
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
|
ctx.SetBodyString(fmt.Sprintf("Failed to encode error response: %v", encodeErr))
|
|
}
|
|
}
|
|
|
|
// streamLargeResponseIfActive checks if large response mode was activated by the provider
|
|
// and streams the response directly to the client. Returns true if the response was handled
|
|
// (caller should return), false if normal response handling should continue.
|
|
func streamLargeResponseIfActive(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext) bool {
|
|
isLargeResponse, ok := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool)
|
|
if !ok || !isLargeResponse {
|
|
return false
|
|
}
|
|
if !lib.StreamLargeResponseBody(ctx, bifrostCtx) {
|
|
SendError(ctx, fasthttp.StatusInternalServerError, "Large response reader not available")
|
|
}
|
|
return true
|
|
}
|
|
|
|
// SendSSEError sends an error in Server-Sent Events format
|
|
func SendSSEError(ctx *fasthttp.RequestCtx, bifrostErr *schemas.BifrostError) {
|
|
errorJSON, err := json.Marshal(map[string]interface{}{
|
|
"error": bifrostErr,
|
|
})
|
|
if err != nil {
|
|
logger.Error("failed to marshal error for SSE: %v", err)
|
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if _, err := fmt.Fprintf(ctx, "data: %s\n\n", errorJSON); err != nil {
|
|
logger.Warn(fmt.Sprintf("Failed to write SSE error: %v", err))
|
|
}
|
|
}
|
|
|
|
// IsOriginAllowed checks if the given origin is allowed based on localhost rules and configured allowed origins.
|
|
// Localhost origins are always allowed. Additional origins can be configured in allowedOrigins.
|
|
// Supports wildcard patterns like *.example.com to match any subdomain.
|
|
func IsOriginAllowed(origin string, allowedOrigins []string) bool {
|
|
// Always allow localhost origins
|
|
if isLocalhostOrigin(origin) {
|
|
return true
|
|
}
|
|
|
|
// Check configured allowed origins
|
|
for _, allowedOrigin := range allowedOrigins {
|
|
// Check for exact match first
|
|
if allowedOrigin == origin {
|
|
return true
|
|
}
|
|
|
|
if allowedOrigin == "*" {
|
|
return true
|
|
}
|
|
|
|
// Check for wildcard pattern
|
|
if strings.Contains(allowedOrigin, "*") {
|
|
if matchesWildcardPattern(origin, allowedOrigin) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// isLocalhostOrigin reports whether origin is an http or https origin whose host is
// localhost, 127.0.0.1 or 0.0.0.0.
//
// The host may stand alone (default port, e.g. "http://localhost") or be followed
// by ":" and a port (e.g. "https://127.0.0.1:8443"). A host that merely starts with
// one of these names, such as "localhost.example.com", is not a localhost origin.
func isLocalhostOrigin(origin string) bool {
	var hostPort string
	switch {
	case strings.HasPrefix(origin, "http://"):
		hostPort = origin[len("http://"):]
	case strings.HasPrefix(origin, "https://"):
		hostPort = origin[len("https://"):]
	default:
		return false
	}

	for _, host := range [...]string{"localhost", "127.0.0.1", "0.0.0.0"} {
		// Require either the bare host or host + ":" so that look-alike hosts
		// ("localhost.evil.com", "127.0.0.10") are not accepted.
		if hostPort == host || strings.HasPrefix(hostPort, host+":") {
			return true
		}
	}
	return false
}
|
|
|
|
// matchesWildcardPattern reports whether origin matches pattern, where every "*" in
// pattern stands for one or more characters that are neither "." nor "/". All other
// characters of pattern are matched literally, and the whole origin must match.
//
// A wildcard therefore covers exactly one host label: https://*.example.com matches
// https://app.example.com but neither https://a.b.example.com nor
// https://example.com. Because "/" is excluded, a wildcard cannot span the "://"
// separator, so a scheme-less pattern such as *.example.com only matches origins
// that are themselves given without a scheme.
func matchesWildcardPattern(origin string, pattern string) bool {
	const wildcardClass = `[^/.]+`

	// Quote each literal piece between wildcards, then stitch the pieces back
	// together with the character class standing in for every "*".
	literals := strings.Split(pattern, "*")
	for i, literal := range literals {
		literals[i] = regexp.QuoteMeta(literal)
	}
	expr := "^" + strings.Join(literals, wildcardClass) + "$"

	re, err := regexp.Compile(expr)
	if err != nil {
		return false
	}
	return re.MatchString(origin)
}
|
|
|
|
// ParseModel splits a model identifier of the form "provider/model" at its first
// slash. Everything after that slash is the model name, so "provider/nested/model"
// yields the provider "provider" and the model "nested/model".
//
// Surrounding whitespace is trimmed from the input and from both parts. An error is
// returned when the input is empty, contains no slash, or has an empty provider or
// model part.
func ParseModel(model string) (string, string, error) {
	trimmed := strings.TrimSpace(model)
	if trimmed == "" {
		return "", "", fmt.Errorf("model cannot be empty")
	}

	rawProvider, rawName, found := strings.Cut(trimmed, "/")
	if !found {
		return "", "", fmt.Errorf("model must be in the format 'provider/model'")
	}

	provider, name := strings.TrimSpace(rawProvider), strings.TrimSpace(rawName)
	if provider == "" || name == "" {
		return "", "", fmt.Errorf("model must be in the format 'provider/model' with non-empty provider and model")
	}
	return provider, name, nil
}
|
|
|
|
// ClampPaginationParams normalises limit and offset so that the handler response
// matches the values the store actually uses.
//
// A non-positive limit becomes the default of 25 and a limit above 100 is capped at
// 100. A negative offset becomes 0. The results are returned as (limit, offset).
func ClampPaginationParams(limit, offset int) (int, int) {
	const (
		defaultLimit = 25
		maxLimit     = 100
	)

	switch {
	case limit <= 0:
		limit = defaultLimit
	case limit > maxLimit:
		limit = maxLimit
	}

	if offset < 0 {
		offset = 0
	}
	return limit, offset
}
|
|
|
|
// fuzzyMatch reports whether every character of query occurs in text in the same
// order, ignoring case; the characters need not be adjacent. An empty query matches
// any text.
// Example: "gpt4" matches "gpt-4", "gpt-4-turbo", etc.
func fuzzyMatch(text, query string) bool {
	if query == "" {
		return true
	}

	// pending holds the query characters still to be found, in order.
	pending := []rune(strings.ToLower(query))
	for _, r := range strings.ToLower(text) {
		if r != pending[0] {
			continue
		}
		pending = pending[1:]
		if len(pending) == 0 {
			return true
		}
	}
	return false
}
|