first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,548 @@
package handlers
import (
"fmt"
"strconv"
"github.com/fasthttp/router"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/logstore"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// --- HTTP Handler ---
// AsyncHandler handles async job HTTP endpoints.
// It serves the /v1/async/* submission routes and the matching
// ".../{job_id}" retrieval routes registered by RegisterRoutes.
type AsyncHandler struct {
// client is the Bifrost client the submitted jobs call into.
client *bifrost.Bifrost
// executor submits and retrieves async jobs. RegisterRoutes registers
// nothing when it is nil.
executor *logstore.AsyncJobExecutor
// handlerStore supplies ShouldAllowDirectKeys when the Bifrost context is built.
handlerStore lib.HandlerStore
// config supplies the header matcher, the MCP header allowlist and the
// default async job result TTL.
config *lib.Config
}
// AsyncPathToTypeMapping maps exact async submission paths to request types
// (only for non-parameterized paths). RegisterAsyncRequestTypeMiddleware looks
// the request path up in this table.
// Parameterized paths (".../{job_id}") are not listed here; their request type
// is set per-route in RegisterRoutes.
var AsyncPathToTypeMapping = map[string]schemas.RequestType{
"/v1/async/completions": schemas.TextCompletionRequest,
"/v1/async/chat/completions": schemas.ChatCompletionRequest,
"/v1/async/responses": schemas.ResponsesRequest,
"/v1/async/embeddings": schemas.EmbeddingRequest,
"/v1/async/audio/speech": schemas.SpeechRequest,
"/v1/async/audio/transcriptions": schemas.TranscriptionRequest,
"/v1/async/images/generations": schemas.ImageGenerationRequest,
"/v1/async/images/edits": schemas.ImageEditRequest,
"/v1/async/images/variations": schemas.ImageVariationRequest,
"/v1/async/rerank": schemas.RerankRequest,
"/v1/async/ocr": schemas.OCRRequest,
}
// RegisterAsyncRequestTypeMiddleware tags async submission requests with their
// request type (exact path matching, non-parameterized routes only).
//
// When the request path is a key of AsyncPathToTypeMapping, the mapped
// schemas.RequestType is stored on the request context under
// schemas.BifrostContextKeyHTTPRequestType. Requests for any other path are
// left untouched. next is invoked in both cases.
func RegisterAsyncRequestTypeMiddleware(next fasthttp.RequestHandler) fasthttp.RequestHandler {
	return func(ctx *fasthttp.RequestCtx) {
		// Converting inside the index expression lets the compiler do the
		// lookup without allocating a temporary string for every request.
		if requestType, ok := AsyncPathToTypeMapping[string(ctx.Path())]; ok {
			ctx.SetUserValue(schemas.BifrostContextKeyHTTPRequestType, requestType)
		}
		next(ctx)
	}
}
// NewAsyncHandler builds an AsyncHandler around the given Bifrost client and config.
// The executor comes from config.GetAsyncJobExecutor(). When that is nil
// (e.g., LogsStore or governance plugin not configured) the handler is still
// returned, and RegisterRoutes will skip async route registration.
func NewAsyncHandler(client *bifrost.Bifrost, config *lib.Config) *AsyncHandler {
	handler := &AsyncHandler{client: client, config: config}
	// The same config value also serves as the handler store.
	handler.handlerStore = config
	handler.executor = config.GetAsyncJobExecutor()
	return handler
}
// RegisterRoutes wires the async job endpoints onto r.
//
// Every supported operation gets a POST submission route and a matching
// GET ".../{job_id}" retrieval route. Submission routes run
// RegisterAsyncRequestTypeMiddleware ahead of the supplied middlewares;
// retrieval routes use the supplied middlewares only. Nothing is registered
// when the handler has no executor.
func (h *AsyncHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	if h.executor == nil {
		// LogStore not configured, skip async routes
		return
	}

	// Submission chain: request-type tagging first, then the caller's middlewares.
	submitChain := make([]schemas.BifrostHTTPMiddleware, 0, len(middlewares)+1)
	submitChain = append(submitChain, RegisterAsyncRequestTypeMiddleware)
	submitChain = append(submitChain, middlewares...)

	// Async submission endpoints (non-parameterized, request type set via AsyncPathToTypeMapping)
	r.POST("/v1/async/completions", lib.ChainMiddlewares(h.asyncTextCompletion, submitChain...))
	r.POST("/v1/async/chat/completions", lib.ChainMiddlewares(h.asyncChatCompletion, submitChain...))
	r.POST("/v1/async/responses", lib.ChainMiddlewares(h.asyncResponses, submitChain...))
	r.POST("/v1/async/embeddings", lib.ChainMiddlewares(h.asyncEmbeddings, submitChain...))
	r.POST("/v1/async/audio/speech", lib.ChainMiddlewares(h.asyncSpeech, submitChain...))
	r.POST("/v1/async/audio/transcriptions", lib.ChainMiddlewares(h.asyncTranscription, submitChain...))
	r.POST("/v1/async/images/generations", lib.ChainMiddlewares(h.asyncImageGeneration, submitChain...))
	r.POST("/v1/async/images/edits", lib.ChainMiddlewares(h.asyncImageEdit, submitChain...))
	r.POST("/v1/async/images/variations", lib.ChainMiddlewares(h.asyncImageVariation, submitChain...))
	r.POST("/v1/async/rerank", lib.ChainMiddlewares(h.asyncRerank, submitChain...))
	r.POST("/v1/async/ocr", lib.ChainMiddlewares(h.asyncOCR, submitChain...))

	// Async job retrieval endpoints: one getJob handler per operation type.
	retrievals := []struct {
		path        string
		requestType schemas.RequestType
	}{
		{"/v1/async/completions/{job_id}", schemas.TextCompletionRequest},
		{"/v1/async/chat/completions/{job_id}", schemas.ChatCompletionRequest},
		{"/v1/async/responses/{job_id}", schemas.ResponsesRequest},
		{"/v1/async/embeddings/{job_id}", schemas.EmbeddingRequest},
		{"/v1/async/audio/speech/{job_id}", schemas.SpeechRequest},
		{"/v1/async/audio/transcriptions/{job_id}", schemas.TranscriptionRequest},
		{"/v1/async/images/generations/{job_id}", schemas.ImageGenerationRequest},
		{"/v1/async/images/edits/{job_id}", schemas.ImageEditRequest},
		{"/v1/async/images/variations/{job_id}", schemas.ImageVariationRequest},
		{"/v1/async/rerank/{job_id}", schemas.RerankRequest},
		{"/v1/async/ocr/{job_id}", schemas.OCRRequest},
	}
	for _, route := range retrievals {
		r.GET(route.path, lib.ChainMiddlewares(h.getJob(route.requestType), middlewares...))
	}
}
// --- Async submission handlers ---
// asyncTextCompletion handles POST /v1/async/completions.
//
// The request is prepared with prepareTextCompletionRequest and rejected with
// 400 when preparation fails or streaming is requested. The text completion is
// then submitted to the async job executor; the job response is returned with
// 202 Accepted, and a submission failure with 500.
func (h *AsyncHandler) asyncTextCompletion(ctx *fasthttp.RequestCtx) {
	httpReq, textReq, prepErr := prepareTextCompletionRequest(ctx)
	if prepErr != nil {
		SendError(ctx, fasthttp.StatusBadRequest, prepErr.Error())
		return
	}
	if wantsStream := httpReq.Stream; wantsStream != nil && *wantsStream {
		SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async text completions")
		return
	}

	bfCtx, cancel := lib.ConvertToBifrostContext(
		ctx,
		h.handlerStore.ShouldAllowDirectKeys(),
		h.config.GetHeaderMatcher(),
		h.config.GetMCPHeaderCombinedAllowlist(),
	)
	if bfCtx == nil {
		SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
		return
	}
	defer cancel()

	ttl := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
	// The job body receives its own context (bgCtx) as a parameter rather than capturing bfCtx.
	run := func(bgCtx *schemas.BifrostContext) (any, *schemas.BifrostError) {
		return h.client.TextCompletionRequest(bgCtx, textReq)
	}
	job, submitErr := h.executor.SubmitJob(bfCtx, ttl, run, schemas.TextCompletionRequest)
	if submitErr != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, submitErr.Error())
		return
	}
	SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
}
// asyncChatCompletion handles POST /v1/async/chat/completions.
//
// The request is prepared with prepareChatCompletionRequest and rejected with
// 400 when preparation fails or streaming is requested. The chat completion is
// then submitted to the async job executor and the job response is returned
// with 202 Accepted.
//
// A failure to submit the job is reported as 500 Internal Server Error, in
// line with asyncTextCompletion, asyncRerank and asyncOCR: by that point the
// request has already been accepted, so the failure is not a client error.
func (h *AsyncHandler) asyncChatCompletion(ctx *fasthttp.RequestCtx) {
	req, bifrostChatReq, err := prepareChatCompletionRequest(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusBadRequest, err.Error())
		return
	}
	if req.Stream != nil && *req.Stream {
		SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async chat completions")
		return
	}

	bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
	if bifrostCtx == nil {
		SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
		return
	}
	defer cancel()

	resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
	job, err := h.executor.SubmitJob(
		bifrostCtx,
		resultTTL,
		// The job body receives its own context (bgCtx) as a parameter rather than capturing bifrostCtx.
		func(bgCtx *schemas.BifrostContext) (any, *schemas.BifrostError) {
			return h.client.ChatCompletionRequest(bgCtx, bifrostChatReq)
		},
		schemas.ChatCompletionRequest,
	)
	if err != nil {
		// Submission failed on the server side; do not blame the client.
		SendError(ctx, fasthttp.StatusInternalServerError, err.Error())
		return
	}
	SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
}
// asyncResponses handles POST /v1/async/responses.
//
// The request is prepared with prepareResponsesRequest and rejected with 400
// when preparation fails or streaming is requested. The responses request is
// then submitted to the async job executor and the job response is returned
// with 202 Accepted.
//
// A failure to submit the job is reported as 500 Internal Server Error, in
// line with asyncTextCompletion, asyncRerank and asyncOCR: by that point the
// request has already been accepted, so the failure is not a client error.
func (h *AsyncHandler) asyncResponses(ctx *fasthttp.RequestCtx) {
	req, bifrostResponsesReq, err := prepareResponsesRequest(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusBadRequest, err.Error())
		return
	}
	if req.Stream != nil && *req.Stream {
		SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async responses")
		return
	}

	bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
	if bifrostCtx == nil {
		SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
		return
	}
	defer cancel()

	resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
	job, err := h.executor.SubmitJob(
		bifrostCtx,
		resultTTL,
		// The job body receives its own context (bgCtx) as a parameter rather than capturing bifrostCtx.
		func(bgCtx *schemas.BifrostContext) (any, *schemas.BifrostError) {
			return h.client.ResponsesRequest(bgCtx, bifrostResponsesReq)
		},
		schemas.ResponsesRequest,
	)
	if err != nil {
		// Submission failed on the server side; do not blame the client.
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create async job: %v", err))
		return
	}
	SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
}
// asyncEmbeddings handles POST /v1/async/embeddings.
//
// The request is prepared with prepareEmbeddingRequest (400 on failure), then
// submitted to the async job executor; the job response is returned with
// 202 Accepted.
//
// A failure to submit the job is reported as 500 Internal Server Error, in
// line with asyncTextCompletion, asyncRerank and asyncOCR: by that point the
// request has already been accepted, so the failure is not a client error.
func (h *AsyncHandler) asyncEmbeddings(ctx *fasthttp.RequestCtx) {
	_, bifrostEmbeddingReq, err := prepareEmbeddingRequest(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusBadRequest, err.Error())
		return
	}

	bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
	if bifrostCtx == nil {
		SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
		return
	}
	defer cancel()

	resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
	job, err := h.executor.SubmitJob(
		bifrostCtx,
		resultTTL,
		// The job body receives its own context (bgCtx) as a parameter rather than capturing bifrostCtx.
		func(bgCtx *schemas.BifrostContext) (any, *schemas.BifrostError) {
			return h.client.EmbeddingRequest(bgCtx, bifrostEmbeddingReq)
		},
		schemas.EmbeddingRequest,
	)
	if err != nil {
		// Submission failed on the server side; do not blame the client.
		SendError(ctx, fasthttp.StatusInternalServerError, err.Error())
		return
	}
	SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
}
// asyncSpeech handles POST /v1/async/audio/speech.
//
// The request is prepared with prepareSpeechRequest and rejected with 400
// when preparation fails or the stream format is "sse". The speech request is
// then submitted to the async job executor and the job response is returned
// with 202 Accepted.
//
// A failure to submit the job is reported as 500 Internal Server Error, in
// line with asyncTextCompletion, asyncRerank and asyncOCR: by that point the
// request has already been accepted, so the failure is not a client error.
func (h *AsyncHandler) asyncSpeech(ctx *fasthttp.RequestCtx) {
	req, bifrostSpeechReq, err := prepareSpeechRequest(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusBadRequest, err.Error())
		return
	}
	if req.StreamFormat != nil && *req.StreamFormat == "sse" {
		SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async speech")
		return
	}

	bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
	if bifrostCtx == nil {
		SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
		return
	}
	defer cancel()

	resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
	job, err := h.executor.SubmitJob(
		bifrostCtx,
		resultTTL,
		// The job body receives its own context (bgCtx) as a parameter rather than capturing bifrostCtx.
		func(bgCtx *schemas.BifrostContext) (any, *schemas.BifrostError) {
			return h.client.SpeechRequest(bgCtx, bifrostSpeechReq)
		},
		schemas.SpeechRequest,
	)
	if err != nil {
		// Submission failed on the server side; do not blame the client.
		SendError(ctx, fasthttp.StatusInternalServerError, err.Error())
		return
	}
	SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
}
// asyncTranscription handles POST /v1/async/audio/transcriptions.
//
// The request is prepared with prepareTranscriptionRequest and rejected with
// 400 when preparation fails or streaming is requested. The transcription is
// then submitted to the async job executor and the job response is returned
// with 202 Accepted.
//
// A failure to submit the job is reported as 500 Internal Server Error, in
// line with asyncTextCompletion, asyncRerank and asyncOCR: by that point the
// request has already been accepted, so the failure is not a client error.
func (h *AsyncHandler) asyncTranscription(ctx *fasthttp.RequestCtx) {
	bifrostTranscriptionReq, stream, err := prepareTranscriptionRequest(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusBadRequest, err.Error())
		return
	}
	if stream {
		SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async transcriptions")
		return
	}

	bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
	if bifrostCtx == nil {
		SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
		return
	}
	defer cancel()

	resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
	job, err := h.executor.SubmitJob(
		bifrostCtx,
		resultTTL,
		// The job body receives its own context (bgCtx) as a parameter rather than capturing bifrostCtx.
		func(bgCtx *schemas.BifrostContext) (any, *schemas.BifrostError) {
			return h.client.TranscriptionRequest(bgCtx, bifrostTranscriptionReq)
		},
		schemas.TranscriptionRequest,
	)
	if err != nil {
		// Submission failed on the server side; do not blame the client.
		SendError(ctx, fasthttp.StatusInternalServerError, err.Error())
		return
	}
	SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
}
// asyncImageGeneration handles POST /v1/async/images/generations.
//
// The request is prepared with prepareImageGenerationRequest and rejected
// with 400 when preparation fails or streaming is requested. The image
// generation is then submitted to the async job executor and the job response
// is returned with 202 Accepted.
//
// A failure to submit the job is reported as 500 Internal Server Error, in
// line with asyncTextCompletion, asyncRerank and asyncOCR: by that point the
// request has already been accepted, so the failure is not a client error.
func (h *AsyncHandler) asyncImageGeneration(ctx *fasthttp.RequestCtx) {
	req, bifrostReq, err := prepareImageGenerationRequest(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusBadRequest, err.Error())
		return
	}
	if req.BifrostParams.Stream != nil && *req.BifrostParams.Stream {
		SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async image generations")
		return
	}

	bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
	if bifrostCtx == nil {
		SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
		return
	}
	defer cancel()

	resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
	job, err := h.executor.SubmitJob(
		bifrostCtx,
		resultTTL,
		// The job body receives its own context (bgCtx) as a parameter rather than capturing bifrostCtx.
		func(bgCtx *schemas.BifrostContext) (any, *schemas.BifrostError) {
			return h.client.ImageGenerationRequest(bgCtx, bifrostReq)
		},
		schemas.ImageGenerationRequest,
	)
	if err != nil {
		// Submission failed on the server side; do not blame the client.
		SendError(ctx, fasthttp.StatusInternalServerError, err.Error())
		return
	}
	SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
}
// asyncImageEdit handles POST /v1/async/images/edits.
//
// The request is prepared with prepareImageEditRequest and rejected with 400
// when preparation fails or streaming is requested. The image edit is then
// submitted to the async job executor and the job response is returned with
// 202 Accepted.
//
// A failure to submit the job is reported as 500 Internal Server Error, in
// line with asyncTextCompletion, asyncRerank and asyncOCR: by that point the
// request has already been accepted, so the failure is not a client error.
func (h *AsyncHandler) asyncImageEdit(ctx *fasthttp.RequestCtx) {
	req, bifrostReq, err := prepareImageEditRequest(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusBadRequest, err.Error())
		return
	}
	if req.Stream != nil && *req.Stream {
		SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async image edits")
		return
	}

	bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
	if bifrostCtx == nil {
		SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
		return
	}
	defer cancel()

	resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
	job, err := h.executor.SubmitJob(
		bifrostCtx,
		resultTTL,
		// The job body receives its own context (bgCtx) as a parameter rather than capturing bifrostCtx.
		func(bgCtx *schemas.BifrostContext) (any, *schemas.BifrostError) {
			return h.client.ImageEditRequest(bgCtx, bifrostReq)
		},
		schemas.ImageEditRequest,
	)
	if err != nil {
		// Submission failed on the server side; do not blame the client.
		SendError(ctx, fasthttp.StatusInternalServerError, err.Error())
		return
	}
	SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
}
// asyncImageVariation handles POST /v1/async/images/variations.
//
// The request is prepared with prepareImageVariationRequest (400 on failure),
// then submitted to the async job executor; the job response is returned with
// 202 Accepted.
//
// A failure to submit the job is reported as 500 Internal Server Error, in
// line with asyncTextCompletion, asyncRerank and asyncOCR: by that point the
// request has already been accepted, so the failure is not a client error.
func (h *AsyncHandler) asyncImageVariation(ctx *fasthttp.RequestCtx) {
	bifrostReq, err := prepareImageVariationRequest(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusBadRequest, err.Error())
		return
	}

	bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
	if bifrostCtx == nil {
		SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
		return
	}
	defer cancel()

	resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
	job, err := h.executor.SubmitJob(
		bifrostCtx,
		resultTTL,
		// The job body receives its own context (bgCtx) as a parameter rather than capturing bifrostCtx.
		func(bgCtx *schemas.BifrostContext) (any, *schemas.BifrostError) {
			return h.client.ImageVariationRequest(bgCtx, bifrostReq)
		},
		schemas.ImageVariationRequest,
	)
	if err != nil {
		// Submission failed on the server side; do not blame the client.
		SendError(ctx, fasthttp.StatusInternalServerError, err.Error())
		return
	}
	SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
}
// asyncRerank handles POST /v1/async/rerank.
//
// The request is prepared with prepareRerankRequest (400 on failure) and
// submitted to the async job executor. The job response is returned with
// 202 Accepted; a context conversion or submission failure yields 500.
func (h *AsyncHandler) asyncRerank(ctx *fasthttp.RequestCtx) {
	_, rerankReq, prepErr := prepareRerankRequest(ctx)
	if prepErr != nil {
		SendError(ctx, fasthttp.StatusBadRequest, prepErr.Error())
		return
	}

	bfCtx, cancel := lib.ConvertToBifrostContext(
		ctx,
		h.handlerStore.ShouldAllowDirectKeys(),
		h.config.GetHeaderMatcher(),
		h.config.GetMCPHeaderCombinedAllowlist(),
	)
	if bfCtx == nil {
		SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context")
		return
	}
	defer cancel()

	ttl := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
	// The job body receives its own context (bgCtx) as a parameter rather than capturing bfCtx.
	run := func(bgCtx *schemas.BifrostContext) (any, *schemas.BifrostError) {
		return h.client.RerankRequest(bgCtx, rerankReq)
	}
	job, submitErr := h.executor.SubmitJob(bfCtx, ttl, run, schemas.RerankRequest)
	if submitErr != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, submitErr.Error())
		return
	}
	SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
}
// asyncOCR handles POST /v1/async/ocr.
//
// The request is prepared with prepareOCRRequest (400 on failure) and
// submitted to the async job executor. The job response is returned with
// 202 Accepted; a context conversion or submission failure yields 500.
func (h *AsyncHandler) asyncOCR(ctx *fasthttp.RequestCtx) {
	_, ocrReq, prepErr := prepareOCRRequest(ctx)
	if prepErr != nil {
		SendError(ctx, fasthttp.StatusBadRequest, prepErr.Error())
		return
	}

	bfCtx, cancel := lib.ConvertToBifrostContext(
		ctx,
		h.handlerStore.ShouldAllowDirectKeys(),
		h.config.GetHeaderMatcher(),
		h.config.GetMCPHeaderCombinedAllowlist(),
	)
	if bfCtx == nil {
		SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context")
		return
	}
	defer cancel()

	ttl := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
	// The job body receives its own context (bgCtx) as a parameter rather than capturing bfCtx.
	run := func(bgCtx *schemas.BifrostContext) (any, *schemas.BifrostError) {
		return h.client.OCRRequest(bgCtx, ocrReq)
	}
	job, submitErr := h.executor.SubmitJob(bfCtx, ttl, run, schemas.OCRRequest)
	if submitErr != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, submitErr.Error())
		return
	}
	SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
}
// --- Job retrieval handler ---
// getJob returns the handler for GET /v1/async/{type}/{job_id}.
//
// operationType is the request type of the route the handler is mounted on; it
// is passed to the executor together with the job ID and the caller's virtual
// key. The handler answers 400 for a missing job ID or a failed context
// conversion, 404 when the executor returns an error, 202 while the job is
// pending or processing, and 200 otherwise.
func (h *AsyncHandler) getJob(operationType schemas.RequestType) fasthttp.RequestHandler {
	return func(ctx *fasthttp.RequestCtx) {
		jobID, isString := ctx.UserValue("job_id").(string)
		if !isString || jobID == "" {
			SendError(ctx, fasthttp.StatusBadRequest, "job_id is required")
			return
		}

		// The Bifrost context carries the requesting user's VK for the auth check.
		bfCtx, cancel := lib.ConvertToBifrostContext(
			ctx,
			h.handlerStore.ShouldAllowDirectKeys(),
			h.config.GetHeaderMatcher(),
			h.config.GetMCPHeaderCombinedAllowlist(),
		)
		if bfCtx == nil {
			SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
			return
		}
		defer cancel()

		job, lookupErr := h.executor.RetrieveJob(bfCtx, jobID, getVirtualKeyFromContext(bfCtx), operationType)
		if lookupErr != nil {
			SendError(ctx, fasthttp.StatusNotFound, lookupErr.Error())
			return
		}

		body := job.ToResponse()
		// Still running: tell the client to keep polling with 202.
		if job.Status == schemas.AsyncJobStatusPending || job.Status == schemas.AsyncJobStatusProcessing {
			SendJSONWithStatus(ctx, body, fasthttp.StatusAccepted)
			return
		}
		// Completed or failed: the job response is final.
		SendJSON(ctx, body)
	}
}
// --- Helper functions ---
// getVirtualKeyFromContext looks up the virtual key stored in ctx under
// schemas.BifrostContextKeyVirtualKey.
// It returns a pointer to the value, or nil when the value is empty
// (e.g., direct key mode or no governance).
func getVirtualKeyFromContext(ctx *schemas.BifrostContext) *string {
	if vk := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey); vk != "" {
		return &vk
	}
	return nil
}
// getResultTTLFromHeaderWithDefault reads the result TTL from the
// schemas.AsyncHeaderResultTTL request header.
// defaultTTL is returned when the header is absent or empty, is not a valid
// integer, or is negative; otherwise the parsed value is returned.
func getResultTTLFromHeaderWithDefault(ctx *fasthttp.RequestCtx, defaultTTL int) int {
	raw := ctx.Request.Header.Peek(schemas.AsyncHeaderResultTTL)
	if len(raw) == 0 {
		return defaultTTL
	}
	ttl, parseErr := strconv.Atoi(string(raw))
	if parseErr == nil && ttl >= 0 {
		return ttl
	}
	return defaultTTL
}

View File

@@ -0,0 +1,61 @@
package handlers
import (
"github.com/fasthttp/router"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/plugins/semanticcache"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// CacheHandler serves the HTTP endpoints that clear semantic cache entries.
type CacheHandler struct {
// plugin is the semantic cache plugin the clear operations are delegated to.
plugin *semanticcache.Plugin
}
// NewCacheHandler creates a CacheHandler backed by the given plugin.
// plugin must be a *semanticcache.Plugin; any other value is reported through
// logger.Fatal.
func NewCacheHandler(plugin schemas.LLMPlugin) *CacheHandler {
	cachePlugin, isSemanticCache := plugin.(*semanticcache.Plugin)
	if !isSemanticCache {
		logger.Fatal("Cache handler requires a semantic cache plugin")
	}
	handler := new(CacheHandler)
	handler.plugin = cachePlugin
	return handler
}
// RegisterRoutes registers the cache-clearing routes, each wrapped with the
// supplied middlewares:
//
//	DELETE /api/cache/clear/{requestId}
//	DELETE /api/cache/clear-by-key/{cacheKey}
func (h *CacheHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	clearByRequest := lib.ChainMiddlewares(h.clearCache, middlewares...)
	r.DELETE("/api/cache/clear/{requestId}", clearByRequest)

	clearByKey := lib.ChainMiddlewares(h.clearCacheByKey, middlewares...)
	r.DELETE("/api/cache/clear-by-key/{cacheKey}", clearByKey)
}
// clearCache handles DELETE /api/cache/clear/{requestId}.
// It asks the semantic cache plugin to clear the cache for the request ID taken
// from the path. It answers 400 when the path value is not a string and 500
// when the plugin returns an error.
func (h *CacheHandler) clearCache(ctx *fasthttp.RequestCtx) {
	requestID, isString := ctx.UserValue("requestId").(string)
	if !isString {
		SendError(ctx, fasthttp.StatusBadRequest, "Invalid request ID")
		return
	}

	clearErr := h.plugin.ClearCacheForRequestID(requestID)
	if clearErr != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, "Failed to clear cache")
		return
	}

	SendJSON(ctx, map[string]any{"message": "Cache cleared successfully"})
}
// clearCacheByKey handles DELETE /api/cache/clear-by-key/{cacheKey}.
// It asks the semantic cache plugin to clear the cache for the cache key taken
// from the path. It answers 400 when the path value is not a string and 500
// when the plugin returns an error.
func (h *CacheHandler) clearCacheByKey(ctx *fasthttp.RequestCtx) {
	cacheKey, isString := ctx.UserValue("cacheKey").(string)
	if !isString {
		SendError(ctx, fasthttp.StatusBadRequest, "Invalid cache key")
		return
	}

	clearErr := h.plugin.ClearCacheForKey(cacheKey)
	if clearErr != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, "Failed to clear cache")
		return
	}

	SendJSON(ctx, map[string]any{"message": "Cache cleared successfully"})
}

View File

@@ -0,0 +1,887 @@
package handlers
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"slices"
"strings"
"github.com/fasthttp/router"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/network"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework"
"github.com/maximhq/bifrost/framework/configstore"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/framework/encrypt"
"github.com/maximhq/bifrost/framework/modelcatalog"
"github.com/maximhq/bifrost/plugins/compat"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// securityHeaders is the list of headers that cannot be configured in allowlist/denylist.
// These headers are always blocked for security reasons regardless of user configuration.
// All entries are written in lowercase.
// NOTE(review): presumably header names are lowercased before being compared
// against this list - confirm at the use site.
var securityHeaders = []string{
"authorization",
"proxy-authorization",
"cookie",
"host",
"content-length",
"connection",
"transfer-encoding",
"x-api-key",
"x-goog-api-key",
"x-bf-api-key",
"x-bf-vk",
}
// ConfigManager is the interface for the config manager.
// ConfigHandler calls into it to apply configuration changes to the running
// process. Notes marked "presumably" are inferred from method names and
// signatures only - NOTE(review): confirm against the implementation.
type ConfigManager interface {
// UpdateAuthConfig presumably applies a new auth configuration.
UpdateAuthConfig(ctx context.Context, authConfig *configstore.AuthConfig) error
// ReloadClientConfigFromConfigStore presumably re-reads the client config from the config store.
ReloadClientConfigFromConfigStore(ctx context.Context) error
UpdateSyncConfig(ctx context.Context) error
// ForceReloadPricing presumably triggers an immediate pricing reload.
ForceReloadPricing(ctx context.Context) error
// UpdateDropExcessRequests is called by updateConfig when the submitted
// drop_excess_requests value differs from the current one.
UpdateDropExcessRequests(ctx context.Context, value bool)
// UpdateMCPToolManagerConfig is called by updateConfig with the full set of
// MCP tool manager values whenever any of them changed.
UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string, disableAutoToolInject bool) error
ReloadPlugin(ctx context.Context, name string, path *string, pluginConfig any, placement *schemas.PluginPlacement, order *int) error
RemovePlugin(ctx context.Context, name string) error
ReloadProxyConfig(ctx context.Context, config *configstoreTables.GlobalProxyConfig) error
ReloadHeaderFilterConfig(ctx context.Context, config *configstoreTables.GlobalHeaderFilterConfig) error
}
// ConfigHandler manages runtime configuration updates for Bifrost.
// It provides endpoints to update and retrieve settings persisted via the ConfigStore backed by sql database.
type ConfigHandler struct {
// store holds the in-memory client and framework config and gives access to
// the ConfigStore, VectorStore and LogsStore (any of which may be nil).
store *lib.Config
// configManager applies configuration changes to the running process.
configManager ConfigManager
}
// NewConfigHandler creates a new handler for configuration management.
// configManager is used to apply configuration changes at runtime; store is
// the config the handler reads from and persists through.
func NewConfigHandler(configManager ConfigManager, store *lib.Config) *ConfigHandler {
	handler := new(ConfigHandler)
	handler.store = store
	handler.configManager = configManager
	return handler
}
// RegisterRoutes registers the configuration-related routes, each wrapped with
// the supplied middlewares:
//
//	GET  /api/config
//	PUT  /api/config
//	GET  /api/version
//	GET  /api/proxy-config
//	PUT  /api/proxy-config
//	POST /api/pricing/force-sync
func (h *ConfigHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	// Core configuration.
	readConfig := lib.ChainMiddlewares(h.getConfig, middlewares...)
	r.GET("/api/config", readConfig)
	writeConfig := lib.ChainMiddlewares(h.updateConfig, middlewares...)
	r.PUT("/api/config", writeConfig)

	// Version.
	readVersion := lib.ChainMiddlewares(h.getVersion, middlewares...)
	r.GET("/api/version", readVersion)

	// Proxy configuration.
	readProxy := lib.ChainMiddlewares(h.getProxyConfig, middlewares...)
	r.GET("/api/proxy-config", readProxy)
	writeProxy := lib.ChainMiddlewares(h.updateProxyConfig, middlewares...)
	r.PUT("/api/proxy-config", writeProxy)

	// Pricing.
	syncPricing := lib.ChainMiddlewares(h.forceSyncPricing, middlewares...)
	r.POST("/api/pricing/force-sync", syncPricing)
}
// getVersion handles GET /api/version.
// It responds with the package-level version value encoded as JSON.
func (h *ConfigHandler) getVersion(ctx *fasthttp.RequestCtx) {
	SendJSON(ctx, version)
}
// getConfig handles GET /api/config - Get the current configuration.
//
// With the query parameter from_db=true the client and framework config are
// read from the config store (503 when no store is configured, 500 when a read
// fails); otherwise the in-memory values are returned. The response also
// carries the dashboard auth config with the admin password redacted,
// connectivity flags for the config store, vector store and logs store, and -
// when a config store is present - the proxy config (password redacted) and
// the restart-required config.
func (h *ConfigHandler) getConfig(ctx *fasthttp.RequestCtx) {
	response := make(map[string]any)

	// Client and framework config: in-memory by default, database-backed on request.
	if fromDB := string(ctx.QueryArgs().Peek("from_db")); fromDB != "true" {
		response["client_config"] = h.store.ClientConfig
		resolved, _, _ := lib.ResolveFrameworkPricingConfig(nil, h.store.FrameworkConfig)
		response["framework_config"] = *resolved
	} else {
		if h.store.ConfigStore == nil {
			SendError(ctx, fasthttp.StatusServiceUnavailable, "config store not available")
			return
		}
		clientConfig, err := h.store.ConfigStore.GetClientConfig(ctx)
		if err != nil {
			SendError(ctx, fasthttp.StatusInternalServerError,
				fmt.Sprintf("failed to fetch config from db: %v", err))
			return
		}
		if clientConfig != nil {
			response["client_config"] = *clientConfig
		}
		frameworkConfig, err := h.store.ConfigStore.GetFrameworkConfig(ctx)
		if err != nil {
			SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to fetch framework config from db: %v", err))
			return
		}
		resolved, _, _ := lib.ResolveFrameworkPricingConfig(frameworkConfig, nil)
		response["framework_config"] = *resolved
	}

	// Dashboard auth config (username/password for dashboard authentication).
	// Start from empty defaults; they are replaced only when the store holds
	// an auth config.
	authView := map[string]any{
		"admin_username":            &schemas.EnvVar{Val: "", EnvVar: "", FromEnv: false},
		"admin_password":            &schemas.EnvVar{Val: "", EnvVar: "", FromEnv: false},
		"is_enabled":                false,
		"disable_auth_on_inference": false,
	}
	if h.store.ConfigStore != nil {
		authConfig, err := h.store.ConfigStore.GetAuthConfig(ctx)
		if err != nil {
			logger.Warn("failed to get auth config from store: %v", err)
			SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get auth config from store: %v", err))
			return
		}
		if authConfig != nil {
			// Never expose the password itself: an env-sourced password keeps
			// its env var reference with an empty value, anything else is
			// shown as "<redacted>".
			password := &schemas.EnvVar{Val: "<redacted>", EnvVar: "", FromEnv: false}
			if authConfig.AdminPassword != nil && authConfig.AdminPassword.IsFromEnv() {
				password = &schemas.EnvVar{
					Val:     "",
					EnvVar:  authConfig.AdminPassword.EnvVar,
					FromEnv: true,
				}
			}
			authView = map[string]any{
				"admin_username":            authConfig.AdminUserName,
				"admin_password":            password,
				"is_enabled":                authConfig.IsEnabled,
				"disable_auth_on_inference": authConfig.DisableAuthOnInference,
			}
		}
	}
	response["auth_config"] = authView

	// Connectivity flags.
	response["is_db_connected"] = h.store.ConfigStore != nil
	response["is_cache_connected"] = h.store.VectorStore != nil
	response["is_logs_connected"] = h.store.LogsStore != nil

	// Proxy and restart-required config are best effort: a read failure is
	// logged and the key is simply left out of the response.
	if h.store.ConfigStore != nil {
		proxyConfig, err := h.store.ConfigStore.GetProxyConfig(ctx)
		switch {
		case err != nil:
			logger.Warn("failed to get proxy config from store: %v", err)
		case proxyConfig != nil:
			// Redact password if present
			if proxyConfig.Password != "" {
				proxyConfig.Password = "<redacted>"
			}
			response["proxy_config"] = proxyConfig
		}

		restartConfig, err := h.store.ConfigStore.GetRestartRequiredConfig(ctx)
		switch {
		case err != nil:
			logger.Warn("failed to get restart required config from store: %v", err)
		case restartConfig != nil:
			response["restart_required"] = restartConfig
		}
	}

	SendJSON(ctx, response)
}
// updateConfig updates the core configuration settings.
// Currently, it supports hot-reloading of the `drop_excess_requests` setting.
// Note that settings like `prometheus_labels` cannot be changed at runtime.
func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) {
if h.store.ConfigStore == nil {
SendError(ctx, fasthttp.StatusInternalServerError, "Config store not initialized")
return
}
payload := struct {
ClientConfig configstore.ClientConfig `json:"client_config"`
FrameworkConfig configstoreTables.TableFrameworkConfig `json:"framework_config"`
AuthConfig *configstore.AuthConfig `json:"auth_config"`
}{}
if err := json.Unmarshal(ctx.PostBody(), &payload); err != nil {
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err))
return
}
// Validating framework config
if payload.FrameworkConfig.PricingURL != nil && *payload.FrameworkConfig.PricingURL != modelcatalog.DefaultPricingURL {
// Checking the accessibility of the pricing URL
resp, err := http.Get(*payload.FrameworkConfig.PricingURL)
if err != nil {
logger.Warn("failed to check the accessibility of the pricing URL: %v", err)
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to check the accessibility of the pricing URL: %v", err))
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.Warn("failed to check the accessibility of the pricing URL: %v", resp.StatusCode)
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to check the accessibility of the pricing URL: %v", resp.StatusCode))
return
}
}
// Checking the pricing sync interval
if payload.FrameworkConfig.PricingSyncInterval != nil && *payload.FrameworkConfig.PricingSyncInterval <= 0 {
logger.Warn("pricing sync interval must be greater than 0")
SendError(ctx, fasthttp.StatusBadRequest, "pricing sync interval must be greater than 0")
return
}
// Get current config with proper locking
currentConfig := h.store.ClientConfig
updatedConfig := currentConfig
var restartReasons []string
if payload.ClientConfig.DropExcessRequests != currentConfig.DropExcessRequests {
h.configManager.UpdateDropExcessRequests(ctx, payload.ClientConfig.DropExcessRequests)
updatedConfig.DropExcessRequests = payload.ClientConfig.DropExcessRequests
}
if payload.ClientConfig.MCPCodeModeBindingLevel != "" {
if payload.ClientConfig.MCPCodeModeBindingLevel != string(schemas.CodeModeBindingLevelServer) && payload.ClientConfig.MCPCodeModeBindingLevel != string(schemas.CodeModeBindingLevelTool) {
logger.Warn("mcp_code_mode_binding_level must be 'server' or 'tool'")
SendError(ctx, fasthttp.StatusBadRequest, "mcp_code_mode_binding_level must be 'server' or 'tool'")
return
}
}
shouldReloadMCPToolManagerConfig := false
// Only process MCPAgentDepth if explicitly provided (> 0) and different from current
if payload.ClientConfig.MCPAgentDepth > 0 && payload.ClientConfig.MCPAgentDepth != currentConfig.MCPAgentDepth {
updatedConfig.MCPAgentDepth = payload.ClientConfig.MCPAgentDepth
shouldReloadMCPToolManagerConfig = true
}
// Only process MCPToolExecutionTimeout if explicitly provided (> 0) and different from current
if payload.ClientConfig.MCPToolExecutionTimeout > 0 && payload.ClientConfig.MCPToolExecutionTimeout != currentConfig.MCPToolExecutionTimeout {
updatedConfig.MCPToolExecutionTimeout = payload.ClientConfig.MCPToolExecutionTimeout
shouldReloadMCPToolManagerConfig = true
}
if payload.ClientConfig.MCPCodeModeBindingLevel != "" && payload.ClientConfig.MCPCodeModeBindingLevel != currentConfig.MCPCodeModeBindingLevel {
updatedConfig.MCPCodeModeBindingLevel = payload.ClientConfig.MCPCodeModeBindingLevel
shouldReloadMCPToolManagerConfig = true
}
if payload.ClientConfig.MCPDisableAutoToolInject != currentConfig.MCPDisableAutoToolInject {
updatedConfig.MCPDisableAutoToolInject = payload.ClientConfig.MCPDisableAutoToolInject
shouldReloadMCPToolManagerConfig = true
}
// Reload MCP tool manager config with all current values in one call
if shouldReloadMCPToolManagerConfig && h.store.MCPConfig != nil {
if err := h.configManager.UpdateMCPToolManagerConfig(ctx, updatedConfig.MCPAgentDepth, updatedConfig.MCPToolExecutionTimeout, updatedConfig.MCPCodeModeBindingLevel, updatedConfig.MCPDisableAutoToolInject); err != nil {
logger.Warn(fmt.Sprintf("failed to update mcp tool manager config: %v", err))
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to update mcp tool manager config: %v", err))
return
}
}
if !slices.Equal(payload.ClientConfig.PrometheusLabels, currentConfig.PrometheusLabels) {
updatedConfig.PrometheusLabels = payload.ClientConfig.PrometheusLabels
restartReasons = append(restartReasons, "Prometheus labels")
}
if !slices.Equal(payload.ClientConfig.AllowedOrigins, currentConfig.AllowedOrigins) {
updatedConfig.AllowedOrigins = payload.ClientConfig.AllowedOrigins
restartReasons = append(restartReasons, "Allowed origins")
}
if !slices.Equal(payload.ClientConfig.AllowedHeaders, currentConfig.AllowedHeaders) {
updatedConfig.AllowedHeaders = payload.ClientConfig.AllowedHeaders
restartReasons = append(restartReasons, "Allowed headers")
}
// Only update InitialPoolSize if explicitly provided (> 0) to avoid clearing stored value
if payload.ClientConfig.InitialPoolSize > 0 {
if payload.ClientConfig.InitialPoolSize != currentConfig.InitialPoolSize {
restartReasons = append(restartReasons, "Initial pool size")
}
updatedConfig.InitialPoolSize = payload.ClientConfig.InitialPoolSize
}
if payload.ClientConfig.EnableLogging != nil {
payloadLogging := *payload.ClientConfig.EnableLogging
currentLogging := currentConfig.EnableLogging == nil || *currentConfig.EnableLogging
if payloadLogging != currentLogging {
restartReasons = append(restartReasons, "Logging changed")
}
updatedConfig.EnableLogging = payload.ClientConfig.EnableLogging
}
if payload.ClientConfig.DisableContentLogging != currentConfig.DisableContentLogging {
restartReasons = append(restartReasons, "Content logging")
}
updatedConfig.DisableContentLogging = payload.ClientConfig.DisableContentLogging
updatedConfig.DisableDBPingsInHealth = payload.ClientConfig.DisableDBPingsInHealth
updatedConfig.AllowDirectKeys = payload.ClientConfig.AllowDirectKeys
updatedConfig.EnforceAuthOnInference = payload.ClientConfig.EnforceAuthOnInference
// Sync deprecated columns to match new field so they stay consistent in the DB
updatedConfig.EnforceGovernanceHeader = payload.ClientConfig.EnforceAuthOnInference
updatedConfig.EnforceSCIMAuth = payload.ClientConfig.EnforceAuthOnInference
// Only update MaxRequestBodySizeMB if explicitly provided (> 0) to avoid clearing stored value
if payload.ClientConfig.MaxRequestBodySizeMB > 0 {
if payload.ClientConfig.MaxRequestBodySizeMB != currentConfig.MaxRequestBodySizeMB {
restartReasons = append(restartReasons, "Max request body size")
}
updatedConfig.MaxRequestBodySizeMB = payload.ClientConfig.MaxRequestBodySizeMB
}
// Handle compat plugin toggle
newCompat := payload.ClientConfig.Compat
oldCompat := currentConfig.Compat
if newCompat != oldCompat {
newEnabled := newCompat.ConvertTextToChat || newCompat.ConvertChatToResponses || newCompat.ShouldDropParams || newCompat.ShouldConvertParams
if newEnabled {
compatCfg := &compat.Config{
ConvertTextToChat: newCompat.ConvertTextToChat,
ConvertChatToResponses: newCompat.ConvertChatToResponses,
ShouldDropParams: newCompat.ShouldDropParams,
ShouldConvertParams: newCompat.ShouldConvertParams,
}
if err := h.configManager.ReloadPlugin(ctx, compat.PluginName, nil, compatCfg, nil, nil); err != nil {
logger.Warn("failed to load compat plugin: %v", err)
SendError(ctx, 400, "Failed to load compat plugin")
return
}
} else {
disabledCtx := context.WithValue(ctx, PluginDisabledKey, true)
if err := h.configManager.RemovePlugin(disabledCtx, compat.PluginName); err != nil {
logger.Warn("failed to remove compat plugin: %v", err)
SendError(ctx, 400, "Failed to remove compat plugin")
return
}
}
}
updatedConfig.Compat = newCompat
// Only update MCP fields if explicitly provided (non-zero) to avoid clearing stored values
if payload.ClientConfig.MCPAgentDepth > 0 {
updatedConfig.MCPAgentDepth = payload.ClientConfig.MCPAgentDepth
}
if payload.ClientConfig.MCPToolExecutionTimeout > 0 {
updatedConfig.MCPToolExecutionTimeout = payload.ClientConfig.MCPToolExecutionTimeout
}
// Only update MCPCodeModeBindingLevel if payload is non-empty to avoid clearing stored value
if payload.ClientConfig.MCPCodeModeBindingLevel != "" {
updatedConfig.MCPCodeModeBindingLevel = payload.ClientConfig.MCPCodeModeBindingLevel
}
// Only update AsyncJobResultTTL if explicitly provided (> 0) to avoid clearing stored value
if payload.ClientConfig.AsyncJobResultTTL > 0 {
updatedConfig.AsyncJobResultTTL = payload.ClientConfig.AsyncJobResultTTL
}
// Handle RequiredHeaders changes (no restart needed - governance plugin reads via pointer)
updatedConfig.RequiredHeaders = payload.ClientConfig.RequiredHeaders
// Handle LoggingHeaders changes (no restart needed - logging plugin reads via pointer)
updatedConfig.LoggingHeaders = payload.ClientConfig.LoggingHeaders
// Handle WhitelistedRoutes changes (updated dynamically via AuthMiddleware)
updatedConfig.WhitelistedRoutes = payload.ClientConfig.WhitelistedRoutes
// Toggle whether deleted virtual keys should appear in logs filter data.
updatedConfig.HideDeletedVirtualKeysInFilters = payload.ClientConfig.HideDeletedVirtualKeysInFilters
// No restart needed - routing engine reads via pointer, change is effective immediately.
if payload.ClientConfig.RoutingChainMaxDepth > 0 {
updatedConfig.RoutingChainMaxDepth = payload.ClientConfig.RoutingChainMaxDepth
}
// Handle HeaderFilterConfig changes
if !headerFilterConfigEqual(payload.ClientConfig.HeaderFilterConfig, currentConfig.HeaderFilterConfig) {
// Validate that no security headers are in the allowlist or denylist
if err := validateHeaderFilterConfig(payload.ClientConfig.HeaderFilterConfig); err != nil {
logger.Warn("invalid header filter config: %v", err)
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
return
}
updatedConfig.HeaderFilterConfig = payload.ClientConfig.HeaderFilterConfig
if err := h.configManager.ReloadHeaderFilterConfig(ctx, payload.ClientConfig.HeaderFilterConfig); err != nil {
logger.Warn("failed to reload header filter config: %v", err)
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to reload header filter config: %v", err))
return
}
}
// Validate LogRetentionDays
if payload.ClientConfig.LogRetentionDays < 1 {
logger.Warn("log_retention_days must be at least 1")
SendError(ctx, fasthttp.StatusBadRequest, "log_retention_days must be at least 1")
return
}
updatedConfig.LogRetentionDays = payload.ClientConfig.LogRetentionDays
// Update the store with the new config
h.store.ClientConfig = updatedConfig
if err := h.store.ConfigStore.UpdateClientConfig(ctx, updatedConfig); err != nil {
logger.Warn("failed to save configuration: %v", err)
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to save configuration: %v", err))
return
}
// Reloading client config from config store
if err := h.configManager.ReloadClientConfigFromConfigStore(ctx); err != nil {
logger.Warn("failed to reload client config from config store: %v", err)
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to reload client config from config store: %v", err))
return
}
// Fetching existing framework config
frameworkConfig, err := h.store.ConfigStore.GetFrameworkConfig(ctx)
if err != nil {
logger.Warn("failed to get framework config from store: %v", err)
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get framework config from store: %v", err))
return
}
// if framework config is nil, we will use the default pricing config
if frameworkConfig == nil {
frameworkConfig = &configstoreTables.TableFrameworkConfig{
ID: 0,
PricingURL: bifrost.Ptr(modelcatalog.DefaultPricingURL),
PricingSyncInterval: bifrost.Ptr(int64(modelcatalog.DefaultSyncInterval.Seconds())),
}
}
// Handling individual nil cases
if frameworkConfig.PricingURL == nil {
frameworkConfig.PricingURL = bifrost.Ptr(modelcatalog.DefaultPricingURL)
}
if frameworkConfig.PricingSyncInterval == nil {
frameworkConfig.PricingSyncInterval = bifrost.Ptr(int64(modelcatalog.DefaultSyncInterval.Seconds()))
}
// Updating framework config
shouldReloadFrameworkConfig := false
if payload.FrameworkConfig.PricingURL != nil && *payload.FrameworkConfig.PricingURL != *frameworkConfig.PricingURL {
// Checking the accessibility of the pricing URL
resp, err := http.Get(*payload.FrameworkConfig.PricingURL)
if err != nil {
logger.Warn("failed to check the accessibility of the pricing URL: %v", err)
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to check the accessibility of the pricing URL: %v", err))
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.Warn("failed to check the accessibility of the pricing URL: %v", resp.StatusCode)
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to check the accessibility of the pricing URL: %v", resp.StatusCode))
return
}
frameworkConfig.PricingURL = payload.FrameworkConfig.PricingURL
shouldReloadFrameworkConfig = true
}
if payload.FrameworkConfig.PricingSyncInterval != nil {
syncInterval := int64(*payload.FrameworkConfig.PricingSyncInterval)
if syncInterval != *frameworkConfig.PricingSyncInterval {
frameworkConfig.PricingSyncInterval = &syncInterval
shouldReloadFrameworkConfig = true
}
}
// Reload config if required
if shouldReloadFrameworkConfig {
var syncSeconds int64
if frameworkConfig.PricingSyncInterval != nil {
syncSeconds = *frameworkConfig.PricingSyncInterval
} else {
syncSeconds = int64(modelcatalog.DefaultSyncInterval.Seconds())
}
h.store.FrameworkConfig = &framework.FrameworkConfig{
Pricing: &modelcatalog.Config{
PricingURL: frameworkConfig.PricingURL,
PricingSyncInterval: &syncSeconds,
},
}
// Saving framework config
if err := h.store.ConfigStore.UpdateFrameworkConfig(ctx, frameworkConfig); err != nil {
logger.Warn("failed to save framework configuration: %v", err)
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to save framework configuration: %v", err))
return
}
// Reloading pricing manager
h.configManager.UpdateSyncConfig(ctx)
}
// Checking auth config and trying to update if required
if payload.AuthConfig != nil {
// Getting current governance config
authConfig, err := h.store.ConfigStore.GetAuthConfig(ctx)
if err != nil {
if !errors.Is(err, configstore.ErrNotFound) {
logger.Warn("failed to get auth config from store: %v", err)
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get auth config from store: %v", err))
return
}
}
// Check if auth config has changed
authChanged := false
if authConfig == nil {
// No existing config, any enabled state is a change
if payload.AuthConfig.IsEnabled {
authChanged = true
}
} else {
// Compare with existing config using value comparison (not pointer comparison)
// Password is considered changed only if it's NOT redacted and has a value
// (IsRedacted() returns true for <redacted>, asterisk patterns, and env var references)
passwordChanged := payload.AuthConfig.AdminPassword != nil &&
!payload.AuthConfig.AdminPassword.IsRedacted() &&
payload.AuthConfig.AdminPassword.GetValue() != ""
usernameChanged := payload.AuthConfig.AdminUserName != nil &&
!payload.AuthConfig.AdminUserName.Equals(authConfig.AdminUserName)
if payload.AuthConfig.IsEnabled != authConfig.IsEnabled ||
usernameChanged ||
passwordChanged {
authChanged = true
}
}
if payload.AuthConfig.IsEnabled {
// Initialize nil pointers to empty EnvVar to prevent nil-pointer dereference
if payload.AuthConfig.AdminUserName == nil {
payload.AuthConfig.AdminUserName = &schemas.EnvVar{}
}
if payload.AuthConfig.AdminPassword == nil {
payload.AuthConfig.AdminPassword = &schemas.EnvVar{}
}
// Validate env variables are set if referenced
if payload.AuthConfig.AdminUserName.IsFromEnv() && payload.AuthConfig.AdminUserName.GetValue() == "" {
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("environment variable %s is not set", payload.AuthConfig.AdminUserName.EnvVar))
return
}
if payload.AuthConfig.AdminPassword.IsFromEnv() && payload.AuthConfig.AdminPassword.GetValue() == "" {
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("environment variable %s is not set", payload.AuthConfig.AdminPassword.EnvVar))
return
}
if authConfig == nil && (payload.AuthConfig.AdminUserName.GetValue() == "" || payload.AuthConfig.AdminPassword.GetValue() == "") {
SendError(ctx, fasthttp.StatusBadRequest, "auth username and password must be provided")
return
}
// Fetching current Auth config
if payload.AuthConfig.AdminUserName.GetValue() != "" {
if payload.AuthConfig.AdminPassword.IsRedacted() {
if authConfig == nil || authConfig.AdminPassword.GetValue() == "" {
SendError(ctx, fasthttp.StatusBadRequest, "auth password must be provided")
return
}
// Assuming that password hasn't been changed
payload.AuthConfig.AdminPassword = authConfig.AdminPassword
} else {
// Password has been changed
// We will hash the password
hashedPassword, err := encrypt.Hash(payload.AuthConfig.AdminPassword.GetValue())
if err != nil {
logger.Warn("failed to hash password: %v", err)
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to hash password: %v", err))
return
}
// Preserve env-var metadata when storing hashed password
payload.AuthConfig.AdminPassword = &schemas.EnvVar{
Val: hashedPassword,
FromEnv: payload.AuthConfig.AdminPassword.IsFromEnv(),
EnvVar: payload.AuthConfig.AdminPassword.EnvVar,
}
}
}
// Save auth config - this handles both first-time creation and updates
err = h.configManager.UpdateAuthConfig(ctx, payload.AuthConfig)
if err != nil {
logger.Warn("failed to update auth config: %v", err)
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to update auth config: %v", err))
return
}
} else if authConfig != nil {
// Auth is being disabled but there's an existing config - preserve credentials and update disabled state
if payload.AuthConfig.AdminPassword == nil || payload.AuthConfig.AdminPassword.IsRedacted() || payload.AuthConfig.AdminPassword.GetValue() == "" {
payload.AuthConfig.AdminPassword = authConfig.AdminPassword
}
if payload.AuthConfig.AdminUserName == nil || payload.AuthConfig.AdminUserName.GetValue() == "" {
payload.AuthConfig.AdminUserName = authConfig.AdminUserName
}
err = h.configManager.UpdateAuthConfig(ctx, payload.AuthConfig)
if err != nil {
logger.Warn("failed to update auth config: %v", err)
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to update auth config: %v", err))
return
}
}
// Flush all existing sessions if auth details have been changed
if authChanged {
if err := h.store.ConfigStore.FlushSessions(ctx); err != nil {
logger.Warn("updated auth config but failed to flush existing sessions, please restart the server: %v", err)
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("updated auth config but failed to flush existing sessions, please restart the server: %v", err))
return
}
}
// Note: AuthMiddleware is updated via ServerCallbacks.UpdateAuthConfig (handled by BifrostHTTPServer)
}
// Set restart required flag if any restart-requiring configs changed
if len(restartReasons) > 0 {
reason := fmt.Sprintf("%s settings have been updated. A restart is required for changes to take full effect.", strings.Join(restartReasons, ", "))
if err := h.store.ConfigStore.SetRestartRequiredConfig(ctx, &configstoreTables.RestartRequiredConfig{
Required: true,
Reason: reason,
}); err != nil {
logger.Warn("failed to set restart required config: %v", err)
}
}
ctx.SetStatusCode(fasthttp.StatusOK)
SendJSON(ctx, map[string]any{
"status": "success",
"message": "configuration updated successfully",
})
}
// forceSyncPricing asks the config manager to force-reload pricing data.
//
// Responses:
//   - 503 when no config store is configured,
//   - 500 when the config manager reports an error,
//   - 200 with a small JSON success payload otherwise.
func (h *ConfigHandler) forceSyncPricing(ctx *fasthttp.RequestCtx) {
	if h.store.ConfigStore == nil {
		SendError(ctx, fasthttp.StatusServiceUnavailable, "config store not available")
		return
	}

	syncErr := h.configManager.ForceReloadPricing(ctx)
	if syncErr != nil {
		logger.Warn("failed to force pricing sync: %v", syncErr)
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to force pricing sync: %v", syncErr))
		return
	}

	response := map[string]any{
		"status":  "success",
		"message": "pricing sync triggered",
	}
	ctx.SetStatusCode(fasthttp.StatusOK)
	SendJSON(ctx, response)
}
// getProxyConfig handles GET /api/proxy-config and returns the current proxy configuration.
//
// When the store holds no proxy config, a disabled HTTP-type default is returned.
// A non-empty stored password is replaced with the "<redacted>" placeholder before the
// config is serialized (the replacement is made on the value returned by the store).
func (h *ConfigHandler) getProxyConfig(ctx *fasthttp.RequestCtx) {
	if h.store.ConfigStore == nil {
		SendError(ctx, fasthttp.StatusServiceUnavailable, "config store not available")
		return
	}

	stored, err := h.store.ConfigStore.GetProxyConfig(ctx)
	switch {
	case err != nil:
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get proxy config: %v", err))
	case stored == nil:
		// Nothing stored yet: answer with the default empty config.
		SendJSON(ctx, configstoreTables.GlobalProxyConfig{
			Enabled: false,
			Type:    network.GlobalProxyTypeHTTP,
		})
	default:
		// Never leak the real password to API clients.
		if stored.Password != "" {
			stored.Password = "<redacted>"
		}
		SendJSON(ctx, stored)
	}
}
// updateProxyConfig handles PUT /api/proxy-config - Update the proxy configuration.
//
// The request body is decoded into a GlobalProxyConfig. When the proxy is enabled, only the
// HTTP proxy type is accepted, and it requires a non-empty URL and a non-negative timeout;
// SOCKS5 and TCP are rejected as not yet supported and anything else as invalid.
// A password equal to "<redacted>" (the placeholder getProxyConfig emits) means "keep the
// password currently stored". After saving, the config is read back from the store, passed to
// the config manager for reload, and the restart-required flag is set.
//
// Responses: 400 for malformed or invalid payloads, 500 for store/reload failures (including a
// missing config store), 200 with a JSON success payload otherwise.
func (h *ConfigHandler) updateProxyConfig(ctx *fasthttp.RequestCtx) {
	if h.store.ConfigStore == nil {
		SendError(ctx, fasthttp.StatusInternalServerError, "config store not initialized")
		return
	}

	var payload configstoreTables.GlobalProxyConfig
	if err := json.Unmarshal(ctx.PostBody(), &payload); err != nil {
		SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("invalid request format: %v", err))
		return
	}

	// Validate proxy config. Only the HTTP case leaves the switch without returning, so its
	// URL/timeout checks are the only ones needed.
	if payload.Enabled {
		switch payload.Type {
		case network.GlobalProxyTypeHTTP:
			if payload.URL == "" {
				SendError(ctx, fasthttp.StatusBadRequest, "proxy URL is required when proxy is enabled")
				return
			}
			if payload.Timeout < 0 {
				SendError(ctx, fasthttp.StatusBadRequest, "proxy timeout must be non-negative")
				return
			}
		case network.GlobalProxyTypeSOCKS5, network.GlobalProxyTypeTCP:
			SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("proxy type %s is not yet supported", payload.Type))
			return
		default:
			SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("invalid proxy type: %s", payload.Type))
			return
		}
	}

	// Handle password - if it's "<redacted>", keep the existing password.
	if payload.Password == "<redacted>" {
		existingConfig, err := h.store.ConfigStore.GetProxyConfig(ctx)
		if err != nil && !errors.Is(err, configstore.ErrNotFound) {
			SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get existing proxy config: %v", err))
			return
		}
		if existingConfig != nil {
			payload.Password = existingConfig.Password
		} else {
			// No stored config to inherit from: never persist the placeholder itself.
			payload.Password = ""
		}
	}

	// Save proxy config.
	if err := h.store.ConfigStore.UpdateProxyConfig(ctx, &payload); err != nil {
		logger.Warn("failed to save proxy configuration: %v", err)
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to save proxy configuration: %v", err))
		return
	}

	// Read the config back from the store so the reload uses what was actually persisted.
	newProxyConfig, err := h.store.ConfigStore.GetProxyConfig(ctx)
	if err != nil {
		logger.Warn("failed to get proxy config from store: %v", err)
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get proxy config from store: %v", err))
		return
	}
	if newProxyConfig == nil {
		// Fall back to a disabled HTTP-type config; all other fields stay at their zero values.
		newProxyConfig = &configstoreTables.GlobalProxyConfig{
			Enabled: false,
			Type:    network.GlobalProxyTypeHTTP,
		}
	}

	// Reload proxy config in the server.
	if err := h.configManager.ReloadProxyConfig(ctx, newProxyConfig); err != nil {
		logger.Warn("failed to reload proxy config: %v", err)
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to reload proxy config: %v", err))
		return
	}

	// Set restart required flag for proxy config changes. Failure here is logged but does not
	// fail the request, because the config itself has already been saved and reloaded.
	if err := h.store.ConfigStore.SetRestartRequiredConfig(ctx, &configstoreTables.RestartRequiredConfig{
		Required: true,
		Reason:   "Proxy configuration has been updated. A restart is required for all changes to take full effect.",
	}); err != nil {
		logger.Warn("failed to set restart required config: %v", err)
	}

	ctx.SetStatusCode(fasthttp.StatusOK)
	SendJSON(ctx, map[string]any{
		"status":  "success",
		"message": "proxy configuration updated successfully",
	})
}
// headerFilterConfigEqual reports whether two GlobalHeaderFilterConfig values are equal.
// Two nil configs are equal; a nil and a non-nil config are not. Otherwise both the
// allowlists and the denylists must match element by element, in order.
func headerFilterConfigEqual(a, b *configstoreTables.GlobalHeaderFilterConfig) bool {
	if a == nil || b == nil {
		// With at least one side nil, the configs are equal only when both are nil.
		return a == b
	}
	if !slices.Equal(a.Allowlist, b.Allowlist) {
		return false
	}
	return slices.Equal(a.Denylist, b.Denylist)
}
// validateHeaderFilterConfig validates that no exact security header names are in the allowlist or denylist
// and that wildcard patterns use valid syntax (only trailing * is supported).
// Wildcard patterns that would match security headers are allowed because security headers
// are unconditionally stripped at runtime regardless of configuration.
//
// On success the config's Allowlist and Denylist are replaced with normalized copies
// (trimmed, lowercased, empty entries dropped). On failure the config is left untouched.
// A nil config is valid and returns nil.
//
// Returns an error if any exact security headers are found or patterns are invalid.
func validateHeaderFilterConfig(config *configstoreTables.GlobalHeaderFilterConfig) error {
	if config == nil {
		return nil
	}

	allowlist, err := normalizeHeaderFilterList(config.Allowlist)
	if err != nil {
		return err
	}
	denylist, err := normalizeHeaderFilterList(config.Denylist)
	if err != nil {
		return err
	}

	// Collect exact security header names from both lists, each name reported once.
	// Wildcard patterns are skipped: security headers are always stripped at runtime,
	// regardless of allowlist/denylist configuration.
	var foundSecurityHeaders []string
	for _, list := range [][]string{allowlist, denylist} {
		for _, header := range list {
			if strings.Contains(header, "*") {
				continue
			}
			if slices.Contains(securityHeaders, header) && !slices.Contains(foundSecurityHeaders, header) {
				foundSecurityHeaders = append(foundSecurityHeaders, header)
			}
		}
	}
	if len(foundSecurityHeaders) > 0 {
		return fmt.Errorf("the following headers are not allowed to be configured: %s. These headers are security headers and are always blocked", strings.Join(foundSecurityHeaders, ", "))
	}

	// Commit the normalized lists only after every check has passed, so a rejected
	// config is never left half-rewritten.
	config.Allowlist = allowlist
	config.Denylist = denylist
	return nil
}

// normalizeHeaderFilterList returns a trimmed, lowercased copy of list with empty entries
// removed. It returns an error if any entry contains a wildcard (*) anywhere other than as
// its final character. The input slice is never modified; a nil input yields a nil result
// and a non-nil input yields a non-nil (possibly empty) result.
func normalizeHeaderFilterList(list []string) ([]string, error) {
	if list == nil {
		return nil, nil
	}
	normalized := make([]string, 0, len(list))
	for _, raw := range list {
		h := strings.ToLower(strings.TrimSpace(raw))
		if h == "" {
			continue
		}
		if idx := strings.Index(h, "*"); idx != -1 && idx != len(h)-1 {
			return nil, fmt.Errorf("invalid pattern %q: wildcard (*) is only supported at the end of a pattern", h)
		}
		normalized = append(normalized, h)
	}
	return normalized, nil
}

View File

@@ -0,0 +1,198 @@
package handlers
import (
"testing"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
)
// TestValidateHeaderFilterConfig runs validateHeaderFilterConfig over a table of configs and
// checks, per case, whether an error is returned and (when expected) that its message
// contains a given substring.
func TestValidateHeaderFilterConfig(t *testing.T) {
	type testCase struct {
		name      string
		config    *configstoreTables.GlobalHeaderFilterConfig
		wantErr   bool
		errSubstr string
	}

	cases := []testCase{
		{
			name:   "nil config",
			config: nil,
		},
		{
			name:   "empty lists",
			config: &configstoreTables.GlobalHeaderFilterConfig{},
		},
		{
			name: "empty allowlist and denylist slices",
			config: &configstoreTables.GlobalHeaderFilterConfig{
				Allowlist: []string{},
				Denylist:  []string{},
			},
		},
		{
			name: "valid allowlist patterns",
			config: &configstoreTables.GlobalHeaderFilterConfig{
				Allowlist: []string{"anthropic-beta", "x-custom-*", "content-type"},
			},
		},
		{
			name: "valid denylist patterns",
			config: &configstoreTables.GlobalHeaderFilterConfig{
				Denylist: []string{"x-internal-*", "x-debug"},
			},
		},
		{
			name: "valid allowlist and denylist together",
			config: &configstoreTables.GlobalHeaderFilterConfig{
				Allowlist: []string{"anthropic-*", "content-type"},
				Denylist:  []string{"x-internal-*"},
			},
		},

		// Blank entries are dropped silently rather than reported as errors.
		{
			name: "whitespace-only entries in allowlist are dropped",
			config: &configstoreTables.GlobalHeaderFilterConfig{
				Allowlist: []string{" ", "anthropic-beta", ""},
			},
		},
		{
			name: "whitespace-only entries in denylist are dropped",
			config: &configstoreTables.GlobalHeaderFilterConfig{
				Denylist: []string{"", "x-debug", " "},
			},
		},
		{
			name: "all-empty allowlist becomes effectively empty",
			config: &configstoreTables.GlobalHeaderFilterConfig{
				Allowlist: []string{"", " ", "\t"},
			},
		},

		// Exact security header names are rejected; wildcards that cover them are not.
		{
			name: "security header in allowlist rejected",
			config: &configstoreTables.GlobalHeaderFilterConfig{
				Allowlist: []string{"authorization"},
			},
			wantErr:   true,
			errSubstr: "not allowed to be configured",
		},
		{
			name: "security header in denylist rejected",
			config: &configstoreTables.GlobalHeaderFilterConfig{
				Denylist: []string{"x-api-key"},
			},
			wantErr:   true,
			errSubstr: "not allowed to be configured",
		},
		{
			name: "wildcard matching security header allowed (runtime strips security headers)",
			config: &configstoreTables.GlobalHeaderFilterConfig{
				Allowlist: []string{"authorization*"},
			},
		},
		{
			name: "wildcard prefix matching security headers allowed (runtime strips security headers)",
			config: &configstoreTables.GlobalHeaderFilterConfig{
				Allowlist: []string{"x-api-*"},
			},
		},
		{
			name: "bare wildcard in allowlist allowed (runtime strips security headers)",
			config: &configstoreTables.GlobalHeaderFilterConfig{
				Allowlist: []string{"*"},
			},
		},

		// A wildcard is only valid as the last character of a pattern.
		{
			name: "wildcard in middle of pattern rejected",
			config: &configstoreTables.GlobalHeaderFilterConfig{
				Allowlist: []string{"x-*-header"},
			},
			wantErr:   true,
			errSubstr: "wildcard (*) is only supported at the end",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := validateHeaderFilterConfig(tc.config)

			// Success cases: any error is a failure, nothing else to check.
			if !tc.wantErr {
				if err != nil {
					t.Fatalf("unexpected error: %v", err)
				}
				return
			}

			// Failure cases: an error must be present and, if requested, mention errSubstr.
			if err == nil {
				t.Fatalf("expected error containing %q, got nil", tc.errSubstr)
			}
			if tc.errSubstr != "" && !contains(err.Error(), tc.errSubstr) {
				t.Fatalf("expected error containing %q, got %q", tc.errSubstr, err.Error())
			}
		})
	}
}
// TestValidateHeaderFilterConfig_EmptyEntriesDropped checks that validation rewrites the
// config in place: blank and whitespace-only entries disappear from both lists while the
// remaining entries keep their relative order.
func TestValidateHeaderFilterConfig_EmptyEntriesDropped(t *testing.T) {
	cfg := &configstoreTables.GlobalHeaderFilterConfig{
		Allowlist: []string{" ", "anthropic-beta", "", "content-type", "\t"},
		Denylist:  []string{"", "x-debug", " "},
	}

	err := validateHeaderFilterConfig(cfg)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Allowlist: exactly the two real header names survive.
	if got := len(cfg.Allowlist); got != 2 {
		t.Fatalf("expected allowlist length 2, got %d: %v", got, cfg.Allowlist)
	}
	if !(cfg.Allowlist[0] == "anthropic-beta" && cfg.Allowlist[1] == "content-type") {
		t.Fatalf("unexpected allowlist: %v", cfg.Allowlist)
	}

	// Denylist: only "x-debug" survives.
	if got := len(cfg.Denylist); got != 1 {
		t.Fatalf("expected denylist length 1, got %d: %v", got, cfg.Denylist)
	}
	if first := cfg.Denylist[0]; first != "x-debug" {
		t.Fatalf("unexpected denylist: %v", cfg.Denylist)
	}
}
// TestValidateHeaderFilterConfig_EmptyConfigStillForwardsHeaders checks that a config whose
// entries are all blank validates down to two empty lists, and that a header matcher built
// from that validated config allows every sampled header.
func TestValidateHeaderFilterConfig_EmptyConfigStillForwardsHeaders(t *testing.T) {
	// Every entry here is empty or whitespace-only.
	cfg := &configstoreTables.GlobalHeaderFilterConfig{
		Allowlist: []string{"", " ", "\t"},
		Denylist:  []string{"", " "},
	}

	err := validateHeaderFilterConfig(cfg)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Validation must have emptied both lists.
	if len(cfg.Allowlist) != 0 {
		t.Fatalf("expected empty allowlist, got %v", cfg.Allowlist)
	}
	if len(cfg.Denylist) != 0 {
		t.Fatalf("expected empty denylist, got %v", cfg.Denylist)
	}

	// A matcher compiled from the emptied config should let every header through.
	matcher := lib.NewHeaderMatcher(cfg)
	sampleHeaders := []string{"anthropic-beta", "x-custom-header", "content-type", "x-anything"}
	for _, name := range sampleHeaders {
		if matcher.ShouldAllow(name) {
			continue
		}
		t.Errorf("expected header %q to be allowed with empty config, but it was denied", name)
	}
}
// contains reports whether substr occurs anywhere within s.
// A substr longer than s can never match, so that case is answered without scanning.
func contains(s, substr string) bool {
	if len(substr) > len(s) {
		return false
	}
	return searchString(s, substr)
}
// searchString reports whether substr occurs in s by comparing substr against every
// window of s that has the same length. An empty substr matches any s (including "").
func searchString(s, substr string) bool {
	width := len(substr)
	// end is the exclusive upper bound of the current window s[end-width:end].
	for end := width; end <= len(s); end++ {
		if s[end-width:end] == substr {
			return true
		}
	}
	return false
}

View File

@@ -0,0 +1,778 @@
//go:build dev
package handlers
import (
"bytes"
"os"
"regexp"
"runtime"
"runtime/pprof"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/fasthttp/router"
"github.com/google/pprof/profile"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// Tuning constants for the dev-build metrics collector and pprof responses.
const (
// Collection interval for metrics
metricsCollectionInterval = 10 * time.Second
// Number of data points to keep (5 minutes / 10 seconds = 30 points).
// Also used as the initial capacity of the collector's history slice.
historySize = 30
// Top allocations to return per table (cumulative and in-use)
topAllocationsCount = 50
)
// MemoryStats represents memory statistics at a point in time.
// Every field is an unsigned 64-bit value serialized to JSON under a snake_case key.
// NOTE(review): the field names match fields of runtime.MemStats; presumably they are
// copied from a runtime.ReadMemStats snapshot — confirm against the collection code.
type MemoryStats struct {
Alloc uint64 `json:"alloc"`
TotalAlloc uint64 `json:"total_alloc"`
HeapInuse uint64 `json:"heap_inuse"`
HeapObjects uint64 `json:"heap_objects"`
Sys uint64 `json:"sys"`
}
// CPUStats represents CPU statistics.
// NOTE(review): units are not visible in this declaration; UserTime and SystemTime are
// presumably process CPU time and UsagePercent a percentage — confirm against the code
// that fills this struct.
type CPUStats struct {
UsagePercent float64 `json:"usage_percent"`
UserTime float64 `json:"user_time"`
SystemTime float64 `json:"system_time"`
}
// RuntimeStats represents runtime statistics.
// NOTE(review): NumGoroutine, NumCPU and GOMAXPROCS share names with functions in the
// runtime package, and NumGC with a runtime.MemStats field; presumably they carry those
// values — confirm against the code that fills this struct.
type RuntimeStats struct {
NumGoroutine int `json:"num_goroutine"`
NumGC uint32 `json:"num_gc"`
GCPauseNs uint64 `json:"gc_pause_ns"`
NumCPU int `json:"num_cpu"`
GOMAXPROCS int `json:"gomaxprocs"`
}
// AllocationInfo represents a single allocation site.
// Function, File and Line locate the site; Bytes and Count are its totals; Stack holds
// stack frames as strings.
// NOTE(review): whether Bytes/Count are cumulative or in-use presumably depends on which
// PprofData table (TopAllocations vs InuseAllocations) the entry appears in — confirm.
type AllocationInfo struct {
Function string `json:"function"`
File string `json:"file"`
Line int `json:"line"`
Bytes int64 `json:"bytes"`
Count int64 `json:"count"`
Stack []string `json:"stack"`
}
// GoroutineGroup represents a group of goroutines with the same stack trace.
// Count is the number of goroutines in the group. WaitReason and WaitMinutes are omitted
// from the JSON output when empty/zero.
type GoroutineGroup struct {
Count int `json:"count"`
State string `json:"state"`
WaitReason string `json:"wait_reason,omitempty"`
WaitMinutes int `json:"wait_minutes,omitempty"` // Parsed wait time in minutes
TopFunc string `json:"top_func"`
Stack []string `json:"stack"`
Category string `json:"category"` // "background", "per-request", "unknown"
}
// GoroutineProfile represents the goroutine profile response.
// It carries a timestamp string, the total goroutine count, the grouped goroutines, and a
// GoroutineSummary. RawProfile is omitted from the JSON output when empty.
type GoroutineProfile struct {
Timestamp string `json:"timestamp"`
TotalGoroutines int `json:"total_goroutines"`
Groups []GoroutineGroup `json:"groups"`
Summary GoroutineSummary `json:"summary"`
RawProfile string `json:"raw_profile,omitempty"`
}
// GoroutineSummary provides a quick overview of goroutine health.
// Each field is a goroutine count for the bucket described by its inline comment.
type GoroutineSummary struct {
Background int `json:"background"` // Expected long-running goroutines
PerRequest int `json:"per_request"` // Goroutines that should complete with requests
LongWaiting int `json:"long_waiting"` // Goroutines waiting > 1 minute (potential leaks)
PotentiallyStuck int `json:"potentially_stuck"` // Per-request goroutines waiting > 1 minute
}
// HistoryPoint represents a single point in the metrics history.
// MetricsCollector.collect creates one per collection cycle.
type HistoryPoint struct {
	Timestamp  string  `json:"timestamp"`   // RFC 3339 time of the collection
	Alloc      uint64  `json:"alloc"`       // runtime.MemStats.Alloc
	HeapInuse  uint64  `json:"heap_inuse"`  // runtime.MemStats.HeapInuse
	Goroutines int     `json:"goroutines"`  // runtime.NumGoroutine()
	GCPauseNs  uint64  `json:"gc_pause_ns"` // runtime.MemStats.PauseNs[(NumGC+255)%256]
	CPUPercent float64 `json:"cpu_percent"` // CPUStats.UsagePercent for the window ending at this point
}
// PprofData represents the complete pprof response returned by getPprof.
type PprofData struct {
	Timestamp        string           `json:"timestamp"`         // RFC 3339 time the response was built
	Memory           MemoryStats      `json:"memory"`            // snapshot from runtime.ReadMemStats
	CPU              CPUStats         `json:"cpu"`               // most recent value stored by the collector
	Runtime          RuntimeStats     `json:"runtime"`           // goroutine / GC / CPU counts
	TopAllocations   []AllocationInfo `json:"top_allocations"`   // cumulative table from getAllocations
	InuseAllocations []AllocationInfo `json:"inuse_allocations"` // in-use table from getAllocations
	History          []HistoryPoint   `json:"history"`           // copy of the collector's history
}
// cpuSample holds a CPU time sample for calculating usage. Two consecutive
// samples are compared by calculateCPUUsage.
type cpuSample struct {
	timestamp  time.Time     // wall-clock time the sample was taken
	userTime   time.Duration // user CPU time reported by getCPUSample
	systemTime time.Duration // system CPU time reported by getCPUSample
}
// MetricsCollector collects and stores runtime metrics.
type MetricsCollector struct {
	mu            sync.RWMutex   // held by Start, Stop, collect (write) and the get* accessors (read)
	history       []HistoryPoint // most recent points, at most historySize entries, oldest first
	stopCh        chan struct{}  // closed by Stop to end collectLoop
	started       bool           // true between Start and Stop
	lastCPUSample cpuSample      // previous sample; read and written by collectLoop/collect without taking mu
	currentCPU    CPUStats       // latest usage computed by collect
}
// DevPprofHandler handles development profiling endpoints
// (GET /api/dev/pprof and GET /api/dev/pprof/goroutines).
type DevPprofHandler struct {
	collector *MetricsCollector // shared collector obtained from getOrCreateCollector
}
// globalCollector is the process-wide collector instance. It is assigned
// exactly once, inside collectorOnce, by getOrCreateCollector.
var globalCollector *MetricsCollector

// collectorOnce guards the one-time initialization of globalCollector.
var collectorOnce sync.Once
// IsDevMode reports whether dev mode is enabled, i.e. whether the
// BIFROST_UI_DEV environment variable is set to exactly "true".
func IsDevMode() bool {
	flag := os.Getenv("BIFROST_UI_DEV")
	return flag == "true"
}
// getOrCreateCollector returns the global metrics collector. The first call
// builds it with an empty history (capacity historySize) and a fresh stop
// channel; later calls return the same instance.
func getOrCreateCollector() *MetricsCollector {
	initCollector := func() {
		collector := new(MetricsCollector)
		collector.history = make([]HistoryPoint, 0, historySize)
		collector.stopCh = make(chan struct{})
		globalCollector = collector
	}
	collectorOnce.Do(initCollector)
	return globalCollector
}
// NewDevPprofHandler creates a dev pprof handler backed by the shared global
// metrics collector.
func NewDevPprofHandler() *DevPprofHandler {
	handler := new(DevPprofHandler)
	handler.collector = getOrCreateCollector()
	return handler
}
// Start begins the background metrics collection. It is a no-op when the
// collector is already running. Each run gets its own stop channel, which is
// handed to the collection goroutine so the goroutine never has to read
// c.stopCh after it has been launched.
func (c *MetricsCollector) Start() {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.started {
		return
	}
	stopCh := make(chan struct{})
	c.stopCh = stopCh
	c.started = true
	go c.collectLoop(stopCh)
}

// Stop stops the background metrics collection. It is a no-op when the
// collector is not running. Closing the channel wakes the goroutine launched
// by the matching Start call.
func (c *MetricsCollector) Stop() {
	c.mu.Lock()
	defer c.mu.Unlock()
	if !c.started {
		return
	}
	close(c.stopCh)
	c.stopCh = nil
	c.started = false
}

// collectLoop takes an initial CPU sample, waits briefly so the first usage
// reading covers a non-zero window, collects once, and then collects every
// metricsCollectionInterval until stopCh is closed.
//
// stopCh is the channel created by the Start call that launched this loop.
// Selecting on the parameter instead of c.stopCh matters: Stop sets c.stopCh
// to nil, and a receive from a nil channel blocks forever, so re-reading the
// field after a Stop would leak this goroutine.
func (c *MetricsCollector) collectLoop(stopCh <-chan struct{}) {
	// Initialize CPU sample
	c.lastCPUSample = getCPUSample()

	// Wait a bit before first collection to get accurate CPU reading, but
	// remain responsive to Stop during the wait.
	warmup := time.NewTimer(100 * time.Millisecond)
	select {
	case <-warmup.C:
	case <-stopCh:
		warmup.Stop()
		return
	}

	// Collect immediately on start
	c.collect()

	ticker := time.NewTicker(metricsCollectionInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			c.collect()
		case <-stopCh:
			return
		}
	}
}
// calculateCPUUsage calculates CPU usage between two samples.
//
// UsagePercent is (CPU time used / wall time) * 100 and is NOT divided by the
// number of CPUs: 100 means one core fully busy, and the value is capped at
// 100 * numCPU to absorb measurement errors. UserTime and SystemTime are the
// CPU seconds consumed between the samples.
//
// A zero CPUStats is returned when curr is not later than prev. Negative CPU
// deltas are clamped to zero; they occur when curr carries zero CPU times
// (getCPUSample returns only a timestamp when the underlying call fails).
func calculateCPUUsage(prev, curr cpuSample, numCPU int) CPUStats {
	elapsed := curr.timestamp.Sub(prev.timestamp)
	if elapsed <= 0 {
		return CPUStats{}
	}

	userDelta := curr.userTime - prev.userTime
	if userDelta < 0 {
		userDelta = 0
	}
	systemDelta := curr.systemTime - prev.systemTime
	if systemDelta < 0 {
		systemDelta = 0
	}
	totalCPUTime := userDelta + systemDelta

	// Calculate percentage: (CPU time used / wall time) * 100
	cpuPercent := (float64(totalCPUTime) / float64(elapsed)) * 100.0

	// Cap at 100% * numCPU (in case of measurement errors)
	maxPercent := float64(numCPU) * 100.0
	if cpuPercent > maxPercent {
		cpuPercent = maxPercent
	}

	return CPUStats{
		UsagePercent: cpuPercent,
		UserTime:     userDelta.Seconds(),
		SystemTime:   systemDelta.Seconds(),
	}
}
// collect takes one metrics snapshot: it reads memory stats, derives CPU
// usage from the previous and current CPU samples, and records the resulting
// HistoryPoint. The history holds at most historySize points; once full, the
// oldest point is dropped to make room for the newest.
func (c *MetricsCollector) collect() {
	var ms runtime.MemStats
	runtime.ReadMemStats(&ms)

	// CPU usage over the window since the previous sample.
	sample := getCPUSample()
	usage := calculateCPUUsage(c.lastCPUSample, sample, runtime.NumCPU())
	c.lastCPUSample = sample

	var point HistoryPoint
	point.Timestamp = time.Now().Format(time.RFC3339)
	point.Alloc = ms.Alloc
	point.HeapInuse = ms.HeapInuse
	point.Goroutines = runtime.NumGoroutine()
	point.GCPauseNs = ms.PauseNs[(ms.NumGC+255)%256]
	point.CPUPercent = usage.UsagePercent

	c.mu.Lock()
	defer c.mu.Unlock()

	// Keep the latest CPU stats available for the API response.
	c.currentCPU = usage

	if len(c.history) < historySize {
		c.history = append(c.history, point)
		return
	}
	// Buffer is full: slide everything one slot left and overwrite the tail.
	moved := copy(c.history, c.history[1:])
	c.history[moved] = point
}
// getHistory returns a copy of the recorded history, taken under the read
// lock, so callers never share the collector's backing array.
func (c *MetricsCollector) getHistory() []HistoryPoint {
	c.mu.RLock()
	defer c.mu.RUnlock()

	snapshot := make([]HistoryPoint, 0, len(c.history))
	return append(snapshot, c.history...)
}
// getCPUStats returns the CPU stats stored by the most recent collect call,
// read under the read lock. It is the zero CPUStats until collect has run.
func (c *MetricsCollector) getCPUStats() CPUStats {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.currentCPU
}
// getAllocations writes a heap profile, parses it, and returns two allocation
// tables aggregated by full call stack:
//   - cumulative: alloc_space / alloc_objects
//   - inuse:      inuse_space / inuse_objects
//
// Both tables come from a single pprof.WriteHeapProfile call and are sorted
// and capped by flattenAndTopN. If the profile cannot be written or parsed,
// both results are nil.
func getAllocations() (cumulative, inuse []AllocationInfo) {
	var heapDump bytes.Buffer
	if err := pprof.WriteHeapProfile(&heapDump); err != nil {
		return nil, nil
	}
	parsed, err := profile.Parse(&heapDump)
	if err != nil {
		return nil, nil
	}

	// Locate the value column of each sample type we care about; -1 = absent.
	column := map[string]int{
		"alloc_objects": -1,
		"alloc_space":   -1,
		"inuse_objects": -1,
		"inuse_space":   -1,
	}
	for i, st := range parsed.SampleType {
		if _, tracked := column[st.Type]; tracked {
			column[st.Type] = i
		}
	}
	haveAlloc := column["alloc_space"] >= 0 && column["alloc_objects"] >= 0
	haveInuse := column["inuse_space"] >= 0 && column["inuse_objects"] >= 0

	allocTable := make(map[string]*AllocationInfo)
	inuseTable := make(map[string]*AllocationInfo)

	for _, s := range parsed.Sample {
		if len(s.Location) == 0 || len(s.Location[0].Line) == 0 {
			continue
		}
		head := s.Location[0].Line[0]
		// Only the top frame is filtered: dropping samples because of inner
		// frames would hide real user allocations that merely pass through
		// runtime/profiler code.
		if head.Function == nil || isProfilerFunction(head.Function.Name, head.Function.Filename) {
			continue
		}

		// Full stack in goroutine-dump format: alternating "funcName" and
		// "\tfile:line" entries, top-down, matching GoroutineGroup.Stack.
		frames := make([]string, 0, 2*len(s.Location))
		for _, loc := range s.Location {
			if len(loc.Line) == 0 || loc.Line[0].Function == nil {
				continue
			}
			fn := loc.Line[0].Function
			frames = append(frames,
				fn.Name,
				"\t"+fn.Filename+":"+strconv.FormatInt(loc.Line[0].Line, 10))
		}
		if len(frames) == 0 {
			continue
		}
		signature := strings.Join(frames, "\n")

		// record adds one sample's values to the entry for this stack,
		// creating the entry on first sight.
		record := func(table map[string]*AllocationInfo, size, objects int64) {
			if entry, seen := table[signature]; seen {
				entry.Bytes += size
				entry.Count += objects
				return
			}
			table[signature] = &AllocationInfo{
				Function: head.Function.Name,
				File:     head.Function.Filename,
				Line:     int(head.Line),
				Bytes:    size,
				Count:    objects,
				Stack:    frames,
			}
		}

		if haveAlloc {
			record(allocTable, s.Value[column["alloc_space"]], s.Value[column["alloc_objects"]])
		}
		if haveInuse {
			size := s.Value[column["inuse_space"]]
			objects := s.Value[column["inuse_objects"]]
			// Samples with nothing live are left out so the in-use table is
			// not padded with noise.
			if size != 0 || objects != 0 {
				record(inuseTable, size, objects)
			}
		}
	}

	return flattenAndTopN(allocTable), flattenAndTopN(inuseTable)
}
// flattenAndTopN copies the map's entries into a slice ordered by Bytes,
// largest first, and keeps at most topAllocationsCount of them. The result is
// never nil.
func flattenAndTopN(m map[string]*AllocationInfo) []AllocationInfo {
	ranked := make([]AllocationInfo, 0, len(m))
	for _, info := range m {
		ranked = append(ranked, *info)
	}
	sort.Slice(ranked, func(a, b int) bool {
		return ranked[b].Bytes < ranked[a].Bytes
	})
	if len(ranked) <= topAllocationsCount {
		return ranked
	}
	return ranked[:topAllocationsCount]
}
// RegisterRoutes starts the metrics collector and registers the dev pprof
// GET routes, each wrapped with the supplied middlewares.
func (h *DevPprofHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	// The collector is started as soon as routes are registered.
	h.collector.Start()

	routes := []struct {
		path    string
		handler fasthttp.RequestHandler
	}{
		{"/api/dev/pprof", h.getPprof},
		{"/api/dev/pprof/goroutines", h.getGoroutines},
	}
	for _, route := range routes {
		r.GET(route.path, lib.ChainMiddlewares(route.handler, middlewares...))
	}
}
// getPprof handles GET /api/dev/pprof. It responds with a PprofData document
// combining a fresh memory/runtime snapshot, the collector's latest CPU stats
// and history, and the allocation tables from getAllocations.
func (h *DevPprofHandler) getPprof(ctx *fasthttp.RequestCtx) {
	var ms runtime.MemStats
	runtime.ReadMemStats(&ms)

	var data PprofData
	data.Timestamp = time.Now().Format(time.RFC3339)
	data.Memory = MemoryStats{
		Alloc:       ms.Alloc,
		TotalAlloc:  ms.TotalAlloc,
		HeapInuse:   ms.HeapInuse,
		HeapObjects: ms.HeapObjects,
		Sys:         ms.Sys,
	}
	data.CPU = h.collector.getCPUStats()
	data.Runtime = RuntimeStats{
		NumGoroutine: runtime.NumGoroutine(),
		NumGC:        ms.NumGC,
		GCPauseNs:    ms.PauseNs[(ms.NumGC+255)%256],
		NumCPU:       runtime.NumCPU(),
		GOMAXPROCS:   runtime.GOMAXPROCS(0),
	}
	data.History = h.collector.getHistory()
	data.TopAllocations, data.InuseAllocations = getAllocations()

	SendJSON(ctx, data)
}
// getGoroutines handles GET /api/dev/pprof/goroutines.
//
// It dumps all goroutine stacks, groups them by stack signature, categorizes
// each group, drops the profiler's own goroutines, and responds with the
// remaining groups plus summary counters. Groups are ordered: potentially
// stuck (per-request with WaitMinutes >= 1) first, then by wait time, then by
// count. The raw text dump is included only when the query has raw=true.
// On a profile write failure it responds with status 500 and an error body.
func (h *DevPprofHandler) getGoroutines(ctx *fasthttp.RequestCtx) {
	// Check if raw output is requested
	includeRaw := string(ctx.QueryArgs().Peek("raw")) == "true"

	// Get goroutine profile (debug=2: one full stack trace per goroutine)
	var buf bytes.Buffer
	if err := pprof.Lookup("goroutine").WriteTo(&buf, 2); err != nil {
		ctx.SetStatusCode(fasthttp.StatusInternalServerError)
		SendJSON(ctx, map[string]string{"error": "failed to get goroutine profile"})
		return
	}
	rawProfile := buf.String()
	allGroups := parseGoroutineProfile(rawProfile)

	// Filter out profiler goroutines; accumulate the summary and the total
	// over the groups that remain.
	groups := make([]GoroutineGroup, 0, len(allGroups))
	var summary GoroutineSummary
	total := 0
	for i := range allGroups {
		g := &allGroups[i]
		categorizeGoroutine(g)

		// Skip profiler's own goroutines
		if isProfilerGoroutine(g) {
			continue
		}
		groups = append(groups, *g)
		total += g.Count

		switch g.Category {
		case "background":
			summary.Background += g.Count
		case "per-request":
			summary.PerRequest += g.Count
		}
		if g.WaitMinutes >= 1 {
			summary.LongWaiting += g.Count
			if g.Category == "per-request" {
				summary.PotentiallyStuck += g.Count
			}
		}
	}

	// Sort: potentially stuck first, then by wait time, then by count
	stuck := func(g *GoroutineGroup) bool {
		return g.Category == "per-request" && g.WaitMinutes >= 1
	}
	sort.Slice(groups, func(i, j int) bool {
		a, b := &groups[i], &groups[j]
		if aStuck, bStuck := stuck(a), stuck(b); aStuck != bStuck {
			return aStuck
		}
		if a.WaitMinutes != b.WaitMinutes {
			return a.WaitMinutes > b.WaitMinutes
		}
		return a.Count > b.Count
	})

	response := GoroutineProfile{
		Timestamp:       time.Now().Format(time.RFC3339),
		TotalGoroutines: total,
		Groups:          groups,
		Summary:         summary,
	}
	if includeRaw {
		response.RawProfile = rawProfile
	}
	SendJSON(ctx, response)
}
// categorizeGoroutine sets g.WaitMinutes (parsed from g.WaitReason) and
// g.Category ("background", "per-request" or "unknown") by substring-matching
// the group's stack.
//
// Matching is tiered, most specific first:
//  1. named background workers  -> "background"
//  2. per-request function names -> "per-request"
//  3. generic blocking frames    -> "background"
//
// The generic frames (time.Sleep, network polling, ...) are checked last
// because they also appear beneath per-request code that is merely blocked;
// checking them first would label such goroutines "background" and hide them
// from the potentially-stuck count.
func categorizeGoroutine(g *GoroutineGroup) {
	// Parse wait time from wait reason (e.g., "5 minutes" -> 5)
	g.WaitMinutes = parseWaitMinutes(g.WaitReason)

	stackStr := strings.Join(g.Stack, " ")
	matchesAny := func(patterns []string) bool {
		for _, pattern := range patterns {
			if strings.Contains(stackStr, pattern) {
				return true
			}
		}
		return false
	}

	// Named background workers - expected to run forever
	namedBackground := []string{
		"requestWorker",              // Provider queue workers
		"collectLoop",                // Metrics collector
		"cleanupWorker",              // Various cleanup workers
		"startAccumulatorMapCleanup", // Stream accumulator cleanup
		"cleanupOldTraces",           // Trace store cleanup
		"startCleanup",               // Generic cleanup
		"monitorLoop",                // Health monitor
		"StartHeartbeat",             // WebSocket heartbeat
	}

	// Per-request goroutines - should complete when request ends
	perRequest := []string{
		"PreLLMHook",
		"PostLLMHook",
		"PreMCPHook",
		"PostMCPHook",
		"HTTPTransportPreHook",
		"HTTPTransportPostHook",
		"CompleteAndFlushTrace",
		"ProcessAndSend",
		"handleProvider",
		"Inject",                // Observability plugin inject
		"insertInitialLogEntry", // Logging
		"updateLogEntry",        // Logging
		"retryOnNotFound",
		"BroadcastLogUpdate",
	}

	// Generic blocking frames - treated as background only when nothing more
	// specific matched
	genericBlocking := []string{
		"time.Sleep",                     // Ticker-based workers
		"runtime.gopark",                 // Runtime parking (often tickers)
		"sync.(*Cond).Wait",              // Condition variable waits
		"net/http.(*persistConn)",        // HTTP connection pool
		"internal/poll.runtime_pollWait", // Network polling
	}

	switch {
	case matchesAny(namedBackground):
		g.Category = "background"
	case matchesAny(perRequest):
		g.Category = "per-request"
	case matchesAny(genericBlocking):
		g.Category = "background"
	default:
		g.Category = "unknown"
	}
}
// Patterns used by parseWaitMinutes, compiled once at package scope instead
// of on every call.
var (
	waitMinuteRegex = regexp.MustCompile(`(\d+)\s*minute`)
	waitHourRegex   = regexp.MustCompile(`(\d+)\s*hour`)
	waitSecondRegex = regexp.MustCompile(`(\d+)\s*second`)
)

// parseWaitMinutes extracts a wait time in whole minutes from a wait reason
// string such as "5 minutes", "2 hours" or "90 seconds".
//
// Units are tried in the order minutes, hours, seconds and the first one that
// yields a parseable number wins. Hours are multiplied by 60; seconds are
// integer-divided by 60, so anything under 60 seconds is 0. It returns 0 for
// an empty string or when no unit matches.
func parseWaitMinutes(waitReason string) int {
	if waitReason == "" {
		return 0
	}

	// firstNumber returns the integer captured by re, if re matches and the
	// capture fits in an int.
	firstNumber := func(re *regexp.Regexp) (int, bool) {
		matches := re.FindStringSubmatch(waitReason)
		if len(matches) < 2 {
			return 0, false
		}
		n, err := strconv.Atoi(matches[1])
		if err != nil {
			return 0, false
		}
		return n, true
	}

	if mins, ok := firstNumber(waitMinuteRegex); ok {
		return mins
	}
	if hours, ok := firstNumber(waitHourRegex); ok {
		return hours * 60
	}
	if secs, ok := firstNumber(waitSecondRegex); ok {
		return secs / 60 // Convert to minutes, will be 0 for < 60 seconds
	}
	return 0
}
// goroutineHeaderRegex matches a goroutine header at the start of a line:
// "goroutine N [state, wait reason]:". Examples:
//
//	goroutine 1 [running]:
//	goroutine 42 [select, 5 minutes]:
//	goroutine 100 [chan receive]:
var goroutineHeaderRegex = regexp.MustCompile(`^goroutine \d+ \[([^\]]+)\]:`)

// parseGoroutineProfile parses the text output of a goroutine stack dump and
// groups goroutines by state plus their first 10 stack lines.
//
// The dump is scanned line by line and a new goroutine starts only at a line
// matching goroutineHeaderRegex. Splitting on the bare text "goroutine " is
// not safe: stack lines of the form "created by F in goroutine N" contain it
// too, which would truncate that line and drop the file:line after it.
//
// Each returned group carries the state, wait reason, top function and stack
// of the first goroutine seen with that signature, and Count is the number of
// goroutines sharing it. Groups are returned in first-seen order. Goroutines
// with no stack lines are ignored.
func parseGoroutineProfile(profile string) []GoroutineGroup {
	// Number of leading stack lines that form the grouping signature.
	const maxSignatureFrames = 10

	groupMap := make(map[string]*GoroutineGroup)
	var order []string // signatures in first-seen order

	// State of the goroutine block currently being read.
	var (
		inBlock    bool
		state      string
		waitReason string
		topFunc    string
		stackLines []string
	)

	// flush folds the current block, if it has any stack lines, into groupMap.
	flush := func() {
		if !inBlock || len(stackLines) == 0 {
			return
		}
		n := len(stackLines)
		if n > maxSignatureFrames {
			n = maxSignatureFrames
		}
		signature := state + "|" + strings.Join(stackLines[:n], "|")
		if existing, ok := groupMap[signature]; ok {
			existing.Count++
			return
		}
		groupMap[signature] = &GoroutineGroup{
			Count:      1,
			State:      state,
			WaitReason: waitReason,
			TopFunc:    topFunc,
			Stack:      stackLines,
		}
		order = append(order, signature)
	}

	for _, rawLine := range strings.Split(profile, "\n") {
		if matches := goroutineHeaderRegex.FindStringSubmatch(rawLine); matches != nil {
			flush()
			inBlock = true

			// Parse state and wait reason
			// (e.g., "select, 5 minutes" -> state="select", waitReason="5 minutes")
			state, waitReason = matches[1], ""
			if idx := strings.Index(matches[1], ","); idx != -1 {
				state = strings.TrimSpace(matches[1][:idx])
				waitReason = strings.TrimSpace(matches[1][idx+1:])
			}
			// Start a fresh slice: the previous one may now be owned by a group.
			topFunc, stackLines = "", nil
			continue
		}
		if !inBlock {
			continue
		}
		line := strings.TrimSpace(rawLine)
		if line == "" {
			continue
		}
		stackLines = append(stackLines, line)
		// First function line (not a file:line) is the top function
		if topFunc == "" && !strings.HasPrefix(line, "/") && !strings.Contains(line, ".go:") {
			topFunc = line
		}
	}
	flush()

	groups := make([]GoroutineGroup, 0, len(order))
	for _, signature := range order {
		groups = append(groups, *groupMap[signature])
	}
	return groups
}
// profilerPatterns contains substrings that identify profiler-related code.
// isProfilerFunction matches them against function and file names;
// isProfilerGoroutine matches them against a goroutine's joined stack.
var profilerPatterns = []string{
	"devpprof",
	"pprof.WriteHeapProfile",
	"pprof.Lookup",
	"profile.Parse",
	"MetricsCollector",
	"collectLoop",
	"getAllocations",
	"flattenAndTopN",
	"parseGoroutineProfile",
	"getGoroutines",
	"getCPUSample",
}
// isProfilerFunction reports whether funcName or fileName contains any of the
// profilerPatterns substrings, i.e. whether the frame belongs to the profiler
// itself.
func isProfilerFunction(funcName, fileName string) bool {
	for i := range profilerPatterns {
		needle := profilerPatterns[i]
		switch {
		case strings.Contains(funcName, needle), strings.Contains(fileName, needle):
			return true
		}
	}
	return false
}
// isProfilerGoroutine reports whether the group's stack, joined with spaces,
// contains any of the profilerPatterns substrings.
func isProfilerGoroutine(g *GoroutineGroup) bool {
	joined := strings.Join(g.Stack, " ")
	matched := false
	for i := 0; i < len(profilerPatterns) && !matched; i++ {
		matched = strings.Contains(joined, profilerPatterns[i])
	}
	return matched
}
// Cleanup stops the metrics collector, if any. It is safe to call on a nil
// handler, so a caller holding a nil *DevPprofHandler does not panic here.
func (h *DevPprofHandler) Cleanup() {
	if h == nil || h.collector == nil {
		return
	}
	h.collector.Stop()
}

View File

@@ -0,0 +1,23 @@
//go:build !dev
package handlers
import (
"github.com/fasthttp/router"
"github.com/maximhq/bifrost/core/schemas"
)
// DevPprofHandler is a no-op stub for production builds (built without the "dev" tag).
// It has no fields; its methods do nothing and never dereference the receiver.
type DevPprofHandler struct{}
// IsDevMode always reports false in production builds (built without the
// "dev" tag).
func IsDevMode() bool {
	const devMode = false
	return devMode
}
// NewDevPprofHandler returns nil in production builds. The stub's
// RegisterRoutes and Cleanup methods are no-ops that do not touch the
// receiver, so they may be called on the nil result.
func NewDevPprofHandler() *DevPprofHandler { return nil }
// RegisterRoutes is a no-op in production builds: no routes are registered
// and both arguments are ignored.
func (h *DevPprofHandler) RegisterRoutes(_ *router.Router, _ ...schemas.BifrostHTTPMiddleware) {}
// Cleanup is a no-op in production builds; there is no collector to stop.
func (h *DevPprofHandler) Cleanup() {}

View File

@@ -0,0 +1,26 @@
//go:build dev && !windows
package handlers
import (
"syscall"
"time"
)
// getCPUSample returns the process's cumulative user and system CPU time as
// reported by syscall.Getrusage(RUSAGE_SELF), stamped with the current time.
// If Getrusage fails, the sample carries only the timestamp and zero CPU times.
func getCPUSample() cpuSample {
	var usage syscall.Rusage
	err := syscall.Getrusage(syscall.RUSAGE_SELF, &usage)

	sample := cpuSample{timestamp: time.Now()}
	if err != nil {
		return sample
	}

	// toDuration converts a seconds + microseconds timeval into a Duration.
	toDuration := func(tv syscall.Timeval) time.Duration {
		return time.Duration(tv.Sec)*time.Second + time.Duration(tv.Usec)*time.Microsecond
	}
	sample.userTime = toDuration(usage.Utime)
	sample.systemTime = toDuration(usage.Stime)
	return sample
}

View File

@@ -0,0 +1,12 @@
//go:build dev && windows
package handlers
import "time"
// getCPUSample returns a CPU sample with only the timestamp set on Windows;
// userTime and systemTime are always zero.
// Windows does not support syscall.Getrusage
func getCPUSample() cpuSample {
	return cpuSample{timestamp: time.Now()}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,337 @@
package handlers
import (
"context"
"encoding/json"
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/plugins/governance"
"github.com/valyala/fasthttp"
)
// mockGovernanceManagerForVK embeds the interface so unimplemented methods panic
// (the embedded GovernanceManager is left nil).
// Only GetGovernanceData is needed for the getVirtualKeys handler path.
type mockGovernanceManagerForVK struct {
	GovernanceManager
}
// GetGovernanceData is the mock implementation; it always returns nil.
func (m *mockGovernanceManagerForVK) GetGovernanceData(ctx context.Context) *governance.GovernanceData {
	return nil
}
// mockConfigStoreForVK embeds the interface so unimplemented methods panic
// (the embedded configstore.ConfigStore is left nil).
// Only GetVirtualKeysPaginated is called in the non-from_memory path.
type mockConfigStoreForVK struct {
	configstore.ConfigStore
}
// GetVirtualKeysPaginated is the mock implementation; it ignores its
// arguments and returns no keys, a total of 0 and no error.
func (m *mockConfigStoreForVK) GetVirtualKeysPaginated(_ context.Context, _ configstore.VirtualKeyQueryParams) ([]configstoreTables.TableVirtualKey, int64, error) {
	return nil, 0, nil
}
// GetVirtualKeys is the mock implementation; it returns no keys and no error.
func (m *mockConfigStoreForVK) GetVirtualKeys(_ context.Context) ([]configstoreTables.TableVirtualKey, error) {
	return nil, nil
}
// TestGetVirtualKeys_PaginatedEndpoint_ResponseShape checks that the
// paginated virtual keys endpoint answers 200 with a JSON object containing
// exactly the expected top-level fields, each of the expected JSON type.
func TestGetVirtualKeys_PaginatedEndpoint_ResponseShape(t *testing.T) {
	SetLogger(&mockLogger{})

	handler := &GovernanceHandler{
		configStore:       &mockConfigStoreForVK{},
		governanceManager: &mockGovernanceManagerForVK{},
	}

	reqCtx := new(fasthttp.RequestCtx)
	reqCtx.Request.Header.SetMethod("GET")
	reqCtx.Request.SetRequestURI("/api/governance/virtual-keys?limit=10&offset=0")
	handler.getVirtualKeys(reqCtx)

	if status := reqCtx.Response.StatusCode(); status != 200 {
		t.Fatalf("expected status 200, got %d: %s", status, string(reqCtx.Response.Body()))
	}

	var payload map[string]interface{}
	if err := json.Unmarshal(reqCtx.Response.Body(), &payload); err != nil {
		t.Fatalf("failed to parse JSON response: %v", err)
	}

	// Every expected field must be present with the right JSON type.
	type fieldSpec struct {
		key      string
		wantType string
	}
	expected := []fieldSpec{
		{key: "virtual_keys", wantType: "array"},
		{key: "total_count", wantType: "number"},
		{key: "count", wantType: "number"},
		{key: "limit", wantType: "number"},
		{key: "offset", wantType: "number"},
	}
	for _, spec := range expected {
		value, present := payload[spec.key]
		if !present {
			t.Errorf("response missing required field %q", spec.key)
			continue
		}
		switch spec.wantType {
		case "array":
			// JSON null decodes to nil, which is accepted for an empty array.
			if _, isArray := value.([]interface{}); !isArray && value != nil {
				t.Errorf("field %q: expected array, got %T", spec.key, value)
			}
		case "number":
			if _, isNumber := value.(float64); !isNumber {
				t.Errorf("field %q: expected number, got %T", spec.key, value)
			}
		}
	}

	// No top-level field beyond the expected ones is allowed.
	known := make(map[string]bool, len(expected))
	for _, spec := range expected {
		known[spec.key] = true
	}
	for key := range payload {
		if !known[key] {
			t.Errorf("unexpected field %q in response", key)
		}
	}
}
// TestGetVirtualKeys_PaginatedEndpoint_QueryParams verifies query parameters are
// parsed and reflected in the response.
//
// The limit/offset values are read with checked type assertions so that a
// missing or non-numeric field fails the subtest with a message instead of
// panicking.
func TestGetVirtualKeys_PaginatedEndpoint_QueryParams(t *testing.T) {
	SetLogger(&mockLogger{})

	h := &GovernanceHandler{
		configStore:       &mockConfigStoreForVK{},
		governanceManager: &mockGovernanceManagerForVK{},
	}

	tests := []struct {
		name       string
		uri        string
		wantLimit  float64
		wantOffset float64
	}{
		{
			name:       "explicit limit and offset",
			uri:        "/api/governance/virtual-keys?limit=10&offset=5",
			wantLimit:  10,
			wantOffset: 5,
		},
		{
			name:       "no params uses defaults",
			uri:        "/api/governance/virtual-keys",
			wantLimit:  0,
			wantOffset: 0,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := &fasthttp.RequestCtx{}
			ctx.Request.Header.SetMethod("GET")
			ctx.Request.SetRequestURI(tt.uri)
			h.getVirtualKeys(ctx)

			if ctx.Response.StatusCode() != 200 {
				t.Fatalf("expected status 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
			}

			var resp map[string]interface{}
			if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
				t.Fatalf("failed to parse JSON: %v", err)
			}

			gotLimit, ok := resp["limit"].(float64)
			if !ok {
				t.Fatalf("limit: expected number, got %T", resp["limit"])
			}
			if gotLimit != tt.wantLimit {
				t.Errorf("limit: got %v, want %v", gotLimit, tt.wantLimit)
			}

			gotOffset, ok := resp["offset"].(float64)
			if !ok {
				t.Fatalf("offset: expected number, got %T", resp["offset"])
			}
			if gotOffset != tt.wantOffset {
				t.Errorf("offset: got %v, want %v", gotOffset, tt.wantOffset)
			}
		})
	}
}
// Compile-time check that *mockLogger satisfies schemas.Logger. mockLogger is
// defined in middlewares_test.go; it is in the same package, so it is only
// referenced here, not redeclared.
var _ schemas.Logger = (*mockLogger)(nil)
// TestBudgetRemovalRequestDetection pins which UpdateBudgetRequest values
// isBudgetRemovalRequest treats as a request to remove the budget.
func TestBudgetRemovalRequestDetection(t *testing.T) {
	type budgetCase struct {
		name string
		req  *UpdateBudgetRequest
		want bool
	}

	cases := []budgetCase{
		{name: "nil request is not removal", req: nil, want: false},
		{name: "empty object is removal", req: &UpdateBudgetRequest{}, want: true},
		{name: "max limit present is not removal", req: &UpdateBudgetRequest{MaxLimit: bifrostFloat(10)}, want: false},
		{name: "reset duration only is not removal", req: &UpdateBudgetRequest{ResetDuration: bifrostString("1h")}, want: false},
		{name: "calendar aligned only is treated as removal", req: &UpdateBudgetRequest{CalendarAligned: bifrostBool(true)}, want: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := isBudgetRemovalRequest(tc.req)
			if got != tc.want {
				t.Fatalf("isBudgetRemovalRequest() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestRateLimitRemovalRequestDetection pins which UpdateRateLimitRequest
// values isRateLimitRemovalRequest treats as a request to remove the rate
// limit.
func TestRateLimitRemovalRequestDetection(t *testing.T) {
	type rateLimitCase struct {
		name string
		req  *UpdateRateLimitRequest
		want bool
	}

	durationsOnly := &UpdateRateLimitRequest{
		TokenResetDuration:   bifrostString("1h"),
		RequestResetDuration: bifrostString("1h"),
	}

	cases := []rateLimitCase{
		{name: "nil request is not removal", req: nil, want: false},
		{name: "empty object is removal", req: &UpdateRateLimitRequest{}, want: true},
		{name: "token limit present is not removal", req: &UpdateRateLimitRequest{TokenMaxLimit: bifrostInt64(100)}, want: false},
		{name: "request limit present is not removal", req: &UpdateRateLimitRequest{RequestMaxLimit: bifrostInt64(10)}, want: false},
		{name: "durations only is not removal", req: durationsOnly, want: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := isRateLimitRemovalRequest(tc.req)
			if got != tc.want {
				t.Fatalf("isRateLimitRemovalRequest() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestCollectProviderConfigDeleteIDs checks that
// collectProviderConfigDeleteIDs appends the provider config's budget IDs and
// rate limit ID to the slices it is given, and adds nothing when the config
// has none.
func TestCollectProviderConfigDeleteIDs(t *testing.T) {
	budgetID := "budget-1"
	rateLimitID := "rate-limit-1"

	// Shared by the first two cases: one budget and one rate limit.
	populated := configstoreTables.TableVirtualKeyProviderConfig{
		Budgets:     []configstoreTables.TableBudget{{ID: budgetID}},
		RateLimitID: &rateLimitID,
	}

	cases := []struct {
		name             string
		config           configstoreTables.TableVirtualKeyProviderConfig
		initialBudgetIDs []string
		initialRateIDs   []string
		wantBudgetIDs    []string
		wantRateIDs      []string
	}{
		{
			name:          "collects both IDs",
			config:        populated,
			wantBudgetIDs: []string{budgetID},
			wantRateIDs:   []string{rateLimitID},
		},
		{
			name:             "appends to existing slices",
			config:           populated,
			initialBudgetIDs: []string{"budget-0"},
			initialRateIDs:   []string{"rate-limit-0"},
			wantBudgetIDs:    []string{"budget-0", budgetID},
			wantRateIDs:      []string{"rate-limit-0", rateLimitID},
		},
		{
			name:   "ignores missing IDs",
			config: configstoreTables.TableVirtualKeyProviderConfig{},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// requireSameIDs fails the subtest on the first length or element
			// mismatch between got and want.
			requireSameIDs := func(label string, got, want []string) {
				if len(got) != len(want) {
					t.Fatalf("%s length = %d, want %d", label, len(got), len(want))
				}
				for i, id := range got {
					if id != want[i] {
						t.Fatalf("%s[%d] = %q, want %q", label, i, id, want[i])
					}
				}
			}

			gotBudgetIDs, gotRateIDs := collectProviderConfigDeleteIDs(tc.config, tc.initialBudgetIDs, tc.initialRateIDs)
			requireSameIDs("budget IDs", gotBudgetIDs, tc.wantBudgetIDs)
			requireSameIDs("rate limit IDs", gotRateIDs, tc.wantRateIDs)
		})
	}
}
// bifrostFloat returns a pointer to a fresh copy of v.
func bifrostFloat(v float64) *float64 {
	ptr := new(float64)
	*ptr = v
	return ptr
}
// bifrostInt64 returns a pointer to a fresh copy of v.
func bifrostInt64(v int64) *int64 {
	ptr := new(int64)
	*ptr = v
	return ptr
}
// bifrostString returns a pointer to a fresh copy of v.
func bifrostString(v string) *string {
	ptr := new(string)
	*ptr = v
	return ptr
}
// bifrostBool returns a pointer to a fresh copy of v.
func bifrostBool(v bool) *bool {
	ptr := new(bool)
	*ptr = v
	return ptr
}

View File

@@ -0,0 +1,90 @@
package handlers
import (
"context"
"sync"
"time"
"github.com/fasthttp/router"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// HealthHandler manages HTTP requests for health checks.
type HealthHandler struct {
	config *lib.Config // read by getHealth for the ping toggle and the stores to ping
}
// NewHealthHandler creates a health handler that reads its settings and
// stores from the given config.
func NewHealthHandler(config *lib.Config) *HealthHandler {
	handler := new(HealthHandler)
	handler.config = config
	return handler
}
// RegisterRoutes registers the health-related routes: GET /health, wrapped
// with the supplied middlewares.
func (h *HealthHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	r.GET("/health", lib.ChainMiddlewares(h.getHealth, middlewares...))
}
// getHealth handles GET /health - Get the health status of the server.
//
// When DB pings are disabled in the client config it answers "ok" right away.
// Otherwise it pings every configured store (config, logs, vector)
// concurrently under a shared 10-second timeout. If any ping fails it
// responds 503 with the failure messages joined by "; ", always in the order
// config store, log store, vector store, so the body does not depend on
// goroutine scheduling. With a single failure the message is that store's
// message alone.
func (h *HealthHandler) getHealth(ctx *fasthttp.RequestCtx) {
	// If DB pings are disabled, just return OK
	if h.config.ClientConfig.DisableDBPingsInHealth {
		SendJSON(ctx, map[string]any{"status": "ok", "components": map[string]any{"db_pings": "disabled"}})
		return
	}

	reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	// One slot per store. Each goroutine writes only its own slot, and
	// wg.Wait orders those writes before the reads below, so no mutex is
	// needed.
	var (
		failures [3]string
		wg       sync.WaitGroup
	)

	// Pinging config store
	if h.config.ConfigStore != nil {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := h.config.ConfigStore.Ping(reqCtx); err != nil {
				failures[0] = "config store not available"
			}
		}()
	}

	// Pinging log store
	if h.config.LogsStore != nil {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := h.config.LogsStore.Ping(reqCtx); err != nil {
				failures[1] = "log store not available"
			}
		}()
	}

	// Pinging vector store
	if h.config.VectorStore != nil {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := h.config.VectorStore.Ping(reqCtx); err != nil {
				failures[2] = "vector store not available"
			}
		}()
	}

	wg.Wait()

	// Join the failures in slot order.
	message := ""
	for _, failure := range failures {
		if failure == "" {
			continue
		}
		if message != "" {
			message += "; "
		}
		message += failure
	}
	if message != "" {
		SendError(ctx, fasthttp.StatusServiceUnavailable, message)
		return
	}

	SendJSON(ctx, map[string]any{"status": "ok", "components": map[string]any{"db_pings": "ok"}})
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,20 @@
package handlers
import "github.com/maximhq/bifrost/core/schemas"
// version holds the application version; set by SetVersion, read by GetVersion.
var version string

// logger is the package-level logger; set by SetLogger.
var logger schemas.Logger
// SetLogger sets the logger for the application by storing l in the
// package-level logger variable.
func SetLogger(l schemas.Logger) {
	logger = l
}
// SetVersion sets the version for the application by storing v in the
// package-level version variable.
func SetVersion(v string) {
	version = v
}
// GetVersion returns the version last stored by SetVersion, or the empty
// string if SetVersion has not been called.
func GetVersion() string {
	return version
}

View File

@@ -0,0 +1,111 @@
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
// This file contains integration management handlers for AI provider integrations.
package handlers
import (
"github.com/fasthttp/router"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/integrations"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
)
// IntegrationHandler manages HTTP requests for AI provider integrations.
// The four handler pointers are optional; RegisterRoutes and Close skip the
// ones that are nil.
type IntegrationHandler struct {
	extensions            []integrations.ExtensionRouter // integration routers built in NewIntegrationHandler
	wsResponses           *WSResponsesHandler
	wsRealtime            *WSRealtimeHandler
	webrtcRealtime        *WebRTCRealtimeHandler
	realtimeClientSecrets *RealtimeClientSecretsHandler
}
// NewIntegrationHandler creates an integration handler holding one router per
// supported integration plus the optional WebSocket / realtime handlers.
// Any of the handler arguments may be nil when that feature is not configured.
func NewIntegrationHandler(client *bifrost.Bifrost, handlerStore lib.HandlerStore, wsResponses *WSResponsesHandler, wsRealtime *WSRealtimeHandler, webrtcRealtime *WebRTCRealtimeHandler, realtimeClientSecrets *RealtimeClientSecretsHandler) *IntegrationHandler {
	handler := new(IntegrationHandler)

	// All available integration routers, in registration order.
	handler.extensions = []integrations.ExtensionRouter{
		integrations.NewOpenAIRouter(client, handlerStore, logger),
		integrations.NewAnthropicRouter(client, handlerStore, logger),
		integrations.NewGenAIRouter(client, handlerStore, logger),
		integrations.NewLiteLLMRouter(client, handlerStore, logger),
		integrations.NewCohereRouter(client, handlerStore, logger),
		integrations.NewLangChainRouter(client, handlerStore, logger),
		integrations.NewPydanticAIRouter(client, handlerStore, logger),
		integrations.NewBedrockRouter(client, handlerStore, logger),
		// passthrough routers
		integrations.NewGenAIPassthroughRouter(client, handlerStore, logger),
		integrations.NewOpenAIPassthroughRouter(client, handlerStore, logger),
		integrations.NewAnthropicPassthroughRouter(client, handlerStore, logger),
		integrations.NewAzurePassthroughRouter(client, handlerStore, logger),
		integrations.NewCursorRouter(client, handlerStore, logger),
	}

	handler.wsResponses = wsResponses
	handler.wsRealtime = wsRealtime
	handler.webrtcRealtime = webrtcRealtime
	handler.realtimeClientSecrets = realtimeClientSecrets
	return handler
}
// RegisterRoutes registers the routes of every integration router on r, then the
// routes of each optional WebSocket/WebRTC/realtime handler that is non-nil.
// The given middlewares are forwarded unchanged to every registration call.
func (h *IntegrationHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	// Integration routers first, in the order they were constructed.
	for i := range h.extensions {
		h.extensions[i].RegisterRoutes(r, middlewares...)
	}

	// Optional handlers: each one is registered only when it was configured.
	if ws := h.wsResponses; ws != nil {
		ws.RegisterRoutes(r, middlewares...)
	}
	if rt := h.wsRealtime; rt != nil {
		rt.RegisterRoutes(r, middlewares...)
	}
	if rtc := h.webrtcRealtime; rtc != nil {
		rtc.RegisterRoutes(r, middlewares...)
	}
	if secrets := h.realtimeClientSecrets; secrets != nil {
		secrets.RegisterRoutes(r, middlewares...)
	}
}
// Close closes the WebSocket responses, WebSocket realtime and WebRTC realtime
// handlers that are configured. It is safe to call on a nil receiver.
//
// NOTE(review): realtimeClientSecrets is not closed here; confirm that handler
// holds no resources that need releasing.
func (h *IntegrationHandler) Close() {
	if h == nil {
		return
	}
	if ws := h.wsResponses; ws != nil {
		ws.Close()
	}
	if rt := h.wsRealtime; rt != nil {
		rt.Close()
	}
	if rtc := h.webrtcRealtime; rtc != nil {
		rtc.Close()
	}
}
// SetLargePayloadHook installs hook on every integration router that exposes a
// SetLargePayloadHook method; routers without that method are left untouched.
// This is used by enterprise to inject large payload optimization.
func (h *IntegrationHandler) SetLargePayloadHook(hook integrations.LargePayloadHook) {
	// payloadHookSetter is the optional capability a router may implement.
	type payloadHookSetter interface {
		SetLargePayloadHook(integrations.LargePayloadHook)
	}

	for _, ext := range h.extensions {
		setter, supported := ext.(payloadHookSetter)
		if !supported {
			continue
		}
		setter.SetLargePayloadHook(hook)
	}
}
// SetLargeResponseHook installs hook on every integration router that exposes a
// SetLargeResponseHook method; routers without that method are left untouched.
// Enterprise uses this to inject Phase B usage extraction into the response
// stream without embedding scanning logic in the OSS router.
func (h *IntegrationHandler) SetLargeResponseHook(hook integrations.LargeResponseHook) {
	// responseHookSetter is the optional capability a router may implement.
	type responseHookSetter interface {
		SetLargeResponseHook(integrations.LargeResponseHook)
	}

	for _, ext := range h.extensions {
		setter, supported := ext.(responseHookSetter)
		if !supported {
			continue
		}
		setter.SetLargeResponseHook(hook)
	}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,112 @@
package handlers
import (
"fmt"
"strings"
"github.com/bytedance/sonic"
"github.com/fasthttp/router"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// MCPInferenceHandler serves the MCP tool execution endpoint
// (POST /v1/mcp/tool/execute).
type MCPInferenceHandler struct {
// client executes the MCP tool calls.
client *bifrost.Bifrost
// config supplies the header matcher and MCP header allowlist used when
// converting the request into a Bifrost context.
config *lib.Config
}
// NewMCPInferenceHandler returns an MCPInferenceHandler that executes tools
// through client and reads request-conversion settings from config.
func NewMCPInferenceHandler(client *bifrost.Bifrost, config *lib.Config) *MCPInferenceHandler {
	h := new(MCPInferenceHandler)
	h.client = client
	h.config = config
	return h
}
// RegisterRoutes registers POST /v1/mcp/tool/execute on r, wrapping the handler
// with the supplied middlewares.
func (h *MCPInferenceHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	execute := lib.ChainMiddlewares(h.executeTool, middlewares...)
	r.POST("/v1/mcp/tool/execute", execute)
}
// executeTool handles POST /v1/mcp/tool/execute.
//
// The case-insensitive "format" query parameter selects the request schema:
// "chat" (or an absent/empty value) dispatches to executeChatMCPTool and
// "responses" dispatches to executeResponsesMCPTool. Any other value is
// rejected with 400 Bad Request.
func (h *MCPInferenceHandler) executeTool(ctx *fasthttp.RequestCtx) {
	format := strings.ToLower(string(ctx.QueryArgs().Peek("format")))

	if format == "" || format == "chat" {
		h.executeChatMCPTool(ctx)
		return
	}
	if format == "responses" {
		h.executeResponsesMCPTool(ctx)
		return
	}
	SendError(ctx, fasthttp.StatusBadRequest, "Invalid format value, must be 'chat' or 'responses'")
}
// executeChatMCPTool handles POST /v1/mcp/tool/execute?format=chat.
//
// The body is decoded as a chat tool call. The handler answers 400 Bad Request
// when the body cannot be decoded, when the function name is missing or empty,
// or when the request cannot be converted to a Bifrost context. An execution
// error is reported through SendBifrostError; on success the resulting tool
// message is written as JSON.
func (h *MCPInferenceHandler) executeChatMCPTool(ctx *fasthttp.RequestCtx) {
	var toolCall schemas.ChatAssistantMessageToolCall
	if decodeErr := sonic.Unmarshal(ctx.PostBody(), &toolCall); decodeErr != nil {
		SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", decodeErr))
		return
	}

	// The tool name is the only required field.
	if name := toolCall.Function.Name; name == nil || *name == "" {
		SendError(ctx, fasthttp.StatusBadRequest, "Tool function name is required")
		return
	}

	bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
	defer cancel() // release the converted context when the handler returns
	if bifrostCtx == nil {
		SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
		return
	}

	result, execErr := h.client.ExecuteChatMCPTool(bifrostCtx, &toolCall)
	if execErr != nil {
		SendBifrostError(ctx, execErr)
		return
	}
	SendJSON(ctx, result)
}
// executeResponsesMCPTool handles POST /v1/mcp/tool/execute?format=responses.
//
// The body is decoded as a Responses tool message. The handler answers
// 400 Bad Request when the body cannot be decoded, when the tool name is
// missing or empty, or when the request cannot be converted to a Bifrost
// context. An execution error is reported through SendBifrostError; on success
// the resulting message is written as JSON.
func (h *MCPInferenceHandler) executeResponsesMCPTool(ctx *fasthttp.RequestCtx) {
	var toolCall schemas.ResponsesToolMessage
	if decodeErr := sonic.Unmarshal(ctx.PostBody(), &toolCall); decodeErr != nil {
		SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", decodeErr))
		return
	}

	// The tool name is the only required field.
	if name := toolCall.Name; name == nil || *name == "" {
		SendError(ctx, fasthttp.StatusBadRequest, "Tool function name is required")
		return
	}

	bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
	defer cancel() // release the converted context when the handler returns
	if bifrostCtx == nil {
		SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
		return
	}

	result, execErr := h.client.ExecuteResponsesMCPTool(bifrostCtx, &toolCall)
	if execErr != nil {
		SendBifrostError(ctx, execErr)
		return
	}
	SendJSON(ctx, result)
}

View File

@@ -0,0 +1,568 @@
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
// This file contains MCP (Model Context Protocol) server implementation for HTTP streaming.
package handlers
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/bytedance/sonic"
"github.com/fasthttp/router"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/plugins/governance"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// MCPToolManager defines the methods MCPServerHandler needs to list the
// available MCP tools and to execute them in chat or Responses format.
type MCPToolManager interface {
// GetAvailableMCPTools returns the MCP tools available for ctx.
GetAvailableMCPTools(ctx context.Context) []schemas.ChatTool
// ExecuteChatMCPTool executes a chat-format tool call.
ExecuteChatMCPTool(ctx context.Context, toolCall *schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, *schemas.BifrostError)
// ExecuteResponsesMCPTool executes a Responses-format tool call.
ExecuteResponsesMCPTool(ctx context.Context, toolCall *schemas.ResponsesToolMessage) (*schemas.ResponsesMessage, *schemas.BifrostError)
}
// MCPServerHandler manages HTTP requests for MCP server operations
// It implements the MCP protocol over HTTP streaming (SSE) for MCP clients
type MCPServerHandler struct {
// toolManager lists and executes the MCP tools exposed by the servers.
toolManager MCPToolManager
// globalMCPServer is returned by getMCPServerForRequest for requests that are
// not bound to a virtual key.
globalMCPServer *server.MCPServer
vkMCPServers map[string]*server.MCPServer // Map of vk value -> mcp server
config *lib.Config
// mu is held by the sync/delete methods (write lock) and by
// getMCPServerForRequest (read lock) around vkMCPServers access.
mu sync.RWMutex
}
// NewMCPServerHandler creates an MCPServerHandler with a global MCP server and
// performs an initial SyncAllMCPServers so tools are registered before the
// handler is returned.
//
// It returns an error when config or toolManager is nil, or when the initial
// sync fails.
func NewMCPServerHandler(ctx context.Context, config *lib.Config, toolManager MCPToolManager) (*MCPServerHandler, error) {
	if config == nil {
		return nil, fmt.Errorf("config is required")
	}
	if toolManager == nil {
		return nil, fmt.Errorf("tool manager is required")
	}

	// Create MCP server instance using mcp-go
	globalMCPServer := server.NewMCPServer(
		"global",
		version,
		server.WithToolCapabilities(true),
	)
	handler := &MCPServerHandler{
		toolManager:     toolManager,
		globalMCPServer: globalMCPServer,
		config:          config,
		vkMCPServers:    make(map[string]*server.MCPServer),
	}

	// Register the per-request tool filter so x-bf-mcp-include-clients and
	// x-bf-mcp-include-tools are respected on tools/list. It is applied exactly
	// once, matching the per-virtual-key servers; a second registration would
	// only make the same filter (and its tool lookup) run twice per listing.
	server.WithToolFilter(handler.makeIncludeClientsFilter())(handler.globalMCPServer)

	if err := handler.SyncAllMCPServers(ctx); err != nil {
		return nil, fmt.Errorf("failed to sync all MCP servers: %w", err)
	}
	return handler, nil
}
// RegisterRoutes registers the MCP server endpoint on r: POST /mcp for JSON-RPC
// messages and GET /mcp for the SSE stream. Both handlers are wrapped with the
// supplied middlewares.
func (h *MCPServerHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	const mcpPath = "/mcp"

	jsonRPC := lib.ChainMiddlewares(h.handleMCPServer, middlewares...)
	r.POST(mcpPath, jsonRPC)

	sse := lib.ChainMiddlewares(h.handleMCPServerSSE, middlewares...)
	r.GET(mcpPath, sse)
}
// injectMCPSessionIdentity sets the MCP gateway flag and, if a per-user OAuth
// session exists, injects the session token and identity (VK / User ID) directly
// into the BifrostContext. This avoids header-based identity propagation which
// would be vulnerable to spoofing by upstream callers.
//
// It is called by the MCP POST handler (handleMCPServer, which serves MCP
// JSON-RPC 2.0 messages) and by the SSE handler before the request is processed.
//
// Governance context keys are set here intentionally (bypassing governance plugin)
// because in the MCP gateway path, identity is pre-authenticated via the OAuth session.
func injectMCPSessionIdentity(bifrostCtx *schemas.BifrostContext, session *tables.TablePerUserOAuthSession) {
	// The gateway flag is set for every MCP request, with or without a session.
	bifrostCtx.SetValue(schemas.BifrostContextKeyIsMCPGateway, true)
	if session == nil {
		return
	}

	if token := session.AccessToken; token != "" {
		bifrostCtx.SetValue(schemas.BifrostContextKeyMCPUserSession, token)
	}

	// The virtual key is propagated only when the session names a VK ID and the
	// associated VK record carries a value.
	hasVKID := session.VirtualKeyID != nil && *session.VirtualKeyID != ""
	if hasVKID && session.VirtualKey != nil && session.VirtualKey.Value != "" {
		bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, session.VirtualKey.Value)
	}

	if userID := session.UserID; userID != nil && *userID != "" {
		bifrostCtx.SetValue(schemas.BifrostContextKeyUserID, *userID)
	}
}
// handleMCPServer handles POST /mcp requests carrying MCP JSON-RPC 2.0 messages.
//
// It resolves the MCP server for the caller (401 on failure), converts the
// request into a Bifrost context (400 if conversion yields nil, matching the
// other handlers in this package), injects the pre-authenticated session
// identity and hands the raw body to the mcp-go server. A nil response (a
// notification) is answered with 202 Accepted; otherwise the response is
// written as JSON.
func (h *MCPServerHandler) handleMCPServer(ctx *fasthttp.RequestCtx) {
	mcpServer, session, err := h.getMCPServerForRequest(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusUnauthorized, err.Error())
		return
	}

	// Convert context
	bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
	defer cancel()
	// Guard before use: injectMCPSessionIdentity calls methods on bifrostCtx.
	if bifrostCtx == nil {
		SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
		return
	}
	injectMCPSessionIdentity(bifrostCtx, session)

	// Use mcp-go server to handle the request
	// HandleMessage processes JSON-RPC messages and returns appropriate responses
	response := mcpServer.HandleMessage(bifrostCtx, ctx.PostBody())

	// Check if response is nil (notification - no response needed)
	if response == nil {
		ctx.SetStatusCode(fasthttp.StatusAccepted)
		return
	}

	// Marshal and send response
	responseJSON, err := sonic.Marshal(response)
	if err != nil {
		// Pass the error as a format argument, as the other logger calls do, so
		// a '%' inside the error text is never interpreted as a formatting verb.
		logger.Warn("Failed to marshal MCP response: %v", err)
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to encode response: %v", err))
		return
	}
	ctx.SetContentType("application/json")
	ctx.SetBody(responseJSON)
}
// handleMCPServerSSE handles GET /mcp requests for MCP Server-Sent Events streaming.
//
// After authenticating the caller (401 on failure) and converting the request
// into a Bifrost context (400 if conversion yields nil), it starts an SSE body
// stream, emits a single "connection/opened" JSON-RPC message and keeps the
// stream open until the Bifrost context is done.
func (h *MCPServerHandler) handleMCPServerSSE(ctx *fasthttp.RequestCtx) {
	_, session, err := h.getMCPServerForRequest(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusUnauthorized, err.Error())
		return
	}

	// Convert context before any SSE header is written so a failure can still
	// be reported as a regular error response.
	bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
	if bifrostCtx == nil {
		// A nil context would be dereferenced below, both here and inside the
		// background goroutine where a panic cannot be handled by this handler.
		if cancel != nil {
			cancel()
		}
		SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
		return
	}
	injectMCPSessionIdentity(bifrostCtx, session)

	// Set SSE headers
	ctx.SetContentType("text/event-stream")
	ctx.Response.Header.Set("Cache-Control", "no-cache")
	ctx.Response.Header.Set("Connection", "keep-alive")

	// Use SSEStreamReader to bypass fasthttp's internal pipe batching
	reader := lib.NewSSEStreamReader()
	ctx.Response.SetBodyStream(reader, -1)

	go func() {
		// Release the context and end the stream when the goroutine exits.
		defer func() {
			cancel()
			reader.Done()
		}()

		// Send initial connection message
		initMessage := map[string]interface{}{
			"jsonrpc": "2.0",
			"method":  "connection/opened",
		}
		if initJSON, err := sonic.Marshal(initMessage); err == nil {
			// Frame as one SSE event: "data: <json>\n\n".
			buf := make([]byte, 0, len(initJSON)+8)
			buf = append(buf, "data: "...)
			buf = append(buf, initJSON...)
			buf = append(buf, '\n', '\n')
			reader.Send(buf)
		}

		// Wait for context cancellation (client disconnect or server-side cancel)
		<-(*bifrostCtx).Done()
	}()
}
// Sync methods for MCP servers

// SyncAllMCPServers re-registers tools on the global MCP server and rebuilds the
// per-virtual-key servers from the config store.
//
// The handler's write lock is held for the whole operation. When no config
// store is configured only the global server is synced. An error is returned
// if the virtual keys cannot be loaded; in that case the existing per-VK map is
// left as it was.
func (h *MCPServerHandler) SyncAllMCPServers(ctx context.Context) error {
	h.mu.Lock()
	defer h.mu.Unlock()

	globalTools := h.toolManager.GetAvailableMCPTools(ctx)
	h.syncServer(h.globalMCPServer, globalTools, nil)
	logger.Debug("Synced global MCP server with %d tools", len(globalTools))

	if h.config.ConfigStore == nil {
		return nil
	}

	virtualKeys, err := h.config.ConfigStore.GetVirtualKeys(ctx)
	if err != nil {
		return fmt.Errorf("failed to get virtual keys: %w", err)
	}

	// Start from an empty map so servers of removed virtual keys are dropped.
	h.vkMCPServers = make(map[string]*server.MCPServer)
	for idx := range virtualKeys {
		vk := &virtualKeys[idx]

		vkServer := server.NewMCPServer(
			vk.Name,
			version,
			server.WithToolCapabilities(true),
		)
		server.WithToolFilter(h.makeIncludeClientsFilter())(vkServer)
		h.vkMCPServers[vk.Value] = vkServer

		vkTools, vkFilter := h.fetchToolsForVK(vk)
		h.syncServer(vkServer, vkTools, vkFilter)
		logger.Debug("Synced MCP server for virtual key '%s' with %d tools", vk.Name, len(vkTools))
	}
	return nil
}
// SyncVKMCPServer refreshes the MCP server of a single virtual key, creating
// and storing the server (keyed by vk.Value) when none exists yet. The
// handler's write lock is held for the duration of the call.
func (h *MCPServerHandler) SyncVKMCPServer(vk *tables.TableVirtualKey) {
	h.mu.Lock()
	defer h.mu.Unlock()

	vkServer, exists := h.vkMCPServers[vk.Value]
	if !exists {
		// First time this virtual key is seen: build its server and store it.
		vkServer = server.NewMCPServer(
			vk.Name,
			version,
			server.WithToolCapabilities(true),
		)
		server.WithToolFilter(h.makeIncludeClientsFilter())(vkServer)
		h.vkMCPServers[vk.Value] = vkServer
	}

	tools, filter := h.fetchToolsForVK(vk)
	h.syncServer(vkServer, tools, filter)
	logger.Debug("Synced MCP server for virtual key '%s' with %d tools", vk.Name, len(tools))
}
// DeleteVKMCPServer removes the MCP server stored for the virtual key value
// vkValue, if any. It takes the handler's write lock while modifying the map.
func (h *MCPServerHandler) DeleteVKMCPServer(vkValue string) {
h.mu.Lock()
defer h.mu.Unlock()
delete(h.vkMCPServers, vkValue)
}
// syncServer replaces the tool set of an MCP server with availableTools.
//
// Every currently registered tool is deleted, then each function tool in
// availableTools is registered with a handler that forwards the call to the
// tool manager in chat tool-call format. When toolFilter is non-nil it is
// attached to the execution context of every call under
// schemas.MCPContextKeyIncludeTools.
//
// Note that the server parameter shadows the imported server package inside
// this function.
func (h *MCPServerHandler) syncServer(server *server.MCPServer, availableTools []schemas.ChatTool, toolFilter []string) {
	// Clear existing tools
	for registered := range server.ListTools() {
		server.DeleteTools(registered)
	}

	// Register tools from all connected clients
	for _, tool := range availableTools {
		// Only function tools are exposed; custom tools are skipped.
		fn := tool.Function
		if fn == nil {
			continue
		}

		// Captured by the closure below; its address is used for the tool call name.
		toolName := fn.Name

		handler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
			// Inject tool filter into execution context if present
			if toolFilter != nil {
				ctx = context.WithValue(ctx, schemas.MCPContextKeyIncludeTools, toolFilter)
			}

			argsJSON, marshalErr := sonic.Marshal(request.GetArguments())
			if marshalErr != nil {
				return mcp.NewToolResultError(fmt.Sprintf("Failed to marshal tool arguments: %v", marshalErr)), nil
			}

			// Convert to Bifrost tool call format
			callType := "function"
			callID := fmt.Sprintf("mcp-%s", toolName)
			call := schemas.ChatAssistantMessageToolCall{
				ID:   &callID,
				Type: &callType,
				Function: schemas.ChatAssistantMessageToolCallFunction{
					Name:      &toolName,
					Arguments: string(argsJSON),
				},
			}

			// Failures are reported as MCP tool-error results, never as Go errors.
			toolMessage, bifrostErr := h.toolManager.ExecuteChatMCPTool(ctx, &call)
			if bifrostErr != nil {
				if authRequired := bifrostErr.ExtraFields.MCPAuthRequired; authRequired != nil {
					return mcp.NewToolResultError(fmt.Sprintf(
						"Authentication required for %s. Open this URL to connect your account: %s",
						authRequired.MCPClientName, authRequired.AuthorizeURL,
					)), nil
				}
				return mcp.NewToolResultError(fmt.Sprintf("Tool execution failed: %v", bifrost.GetErrorMessage(bifrostErr))), nil
			}

			// Extract text: a plain string wins, otherwise text blocks are concatenated.
			var resultText string
			if toolMessage != nil && toolMessage.Content != nil {
				switch content := toolMessage.Content; {
				case content.ContentStr != nil:
					resultText = *content.ContentStr
				case content.ContentBlocks != nil:
					for _, block := range content.ContentBlocks {
						if block.Type == schemas.ChatContentBlockTypeText && block.Text != nil {
							resultText += *block.Text
						}
					}
				}
			}
			return mcp.NewToolResultText(resultText), nil
		}

		// Description is optional on the Bifrost side.
		description := ""
		if fn.Description != nil {
			description = *fn.Description
		}

		// Convert Parameters to mcp.ToolInputSchema
		var inputSchema mcp.ToolInputSchema
		if params := fn.Parameters; params != nil {
			inputSchema.Type = params.Type
			if params.Properties != nil {
				props := make(map[string]any)
				params.Properties.Range(func(key string, value interface{}) bool {
					props[key] = value
					return true
				})
				inputSchema.Properties = props
			}
			if params.Required != nil {
				inputSchema.Required = params.Required
			}
		} else {
			// Default to empty object schema if no parameters
			inputSchema.Type = "object"
			inputSchema.Properties = make(map[string]any)
		}

		// Map Bifrost annotations back to MCP tool annotations
		var toolAnnotation mcp.ToolAnnotation
		if annotations := tool.Annotations; annotations != nil {
			toolAnnotation = mcp.ToolAnnotation{
				Title:           annotations.Title,
				ReadOnlyHint:    annotations.ReadOnlyHint,
				DestructiveHint: annotations.DestructiveHint,
				IdempotentHint:  annotations.IdempotentHint,
				OpenWorldHint:   annotations.OpenWorldHint,
			}
		}

		// Register tool with the server
		server.AddTool(mcp.Tool{
			Name:        toolName,
			Description: description,
			InputSchema: inputSchema,
			Annotations: toolAnnotation,
		}, handler)
	}
}
// fetchToolsForVK computes the tool allow-list for the virtual key vk and
// returns the tools the tool manager reports for that allow-list, together
// with the allow-list itself (to be applied again at execution time).
//
// Entries have the form "<client name>-<tool>" or "<client name>-*" for all
// tools of a client. Explicit VK MCP configs are processed first and take
// precedence; clients flagged AllowOnAllVirtualKeys that have no explicit
// config contribute a wildcard entry. The returned filter is never nil: an
// empty list means deny-all.
func (h *MCPServerHandler) fetchToolsForVK(vk *tables.TableVirtualKey) ([]schemas.ChatTool, []string) {
	// clientID -> clientName for clients allowed on every virtual key.
	allowAllVKsClients := h.config.GetAllowOnAllVirtualKeysClients()
	if allowAllVKsClients == nil {
		allowAllVKsClients = make(map[string]string)
	}

	includeTools := []string{}
	explicitlyConfigured := make(map[string]bool)

	for _, cfg := range vk.MCPConfigs {
		clientID := cfg.MCPClient.ClientID
		if _, allowedEverywhere := allowAllVKsClients[clientID]; allowedEverywhere {
			// Explicit config exists — it takes precedence; mark handled regardless of tool list.
			explicitlyConfigured[clientID] = true
		}

		switch {
		case cfg.ToolsToExecute.IsEmpty():
			continue
		case cfg.ToolsToExecute.IsUnrestricted():
			includeTools = append(includeTools, fmt.Sprintf("%s-*", cfg.MCPClient.Name))
			continue
		}

		for _, toolName := range cfg.ToolsToExecute {
			if toolName == "" {
				continue
			}
			// Add the tool - client config filtering will be handled by mcp.go
			// Note: Use '-' separator for individual tools (wildcard uses '-*' after client name, e.g., "client-*")
			includeTools = append(includeTools, fmt.Sprintf("%s-%s", cfg.MCPClient.Name, toolName))
		}
	}

	// For AllowOnAllVirtualKeys clients with no explicit VK config, allow all their tools.
	for clientID, clientName := range allowAllVKsClients {
		if explicitlyConfigured[clientID] {
			continue
		}
		includeTools = append(includeTools, fmt.Sprintf("%s-*", clientName))
	}

	// Always set the include-tools filter (empty = deny-all when no MCPConfigs and no AllowOnAllVirtualKeys clients)
	lookupCtx := context.WithValue(context.Background(), schemas.MCPContextKeyIncludeTools, includeTools)
	return h.toolManager.GetAvailableMCPTools(lookupCtx), includeTools
}
// makeIncludeClientsFilter returns a ToolFilterFunc that narrows a tools/list
// response according to the include-clients / include-tools values carried in
// the request context (populated from the x-bf-mcp-include-clients and
// x-bf-mcp-include-tools headers).
//
// When neither value is present the tools are returned unchanged. Otherwise
// only tools whose name is reported by the tool manager for that context are
// kept, in their original order.
func (h *MCPServerHandler) makeIncludeClientsFilter() server.ToolFilterFunc {
	return func(ctx context.Context, tools []mcp.Tool) []mcp.Tool {
		noClientFilter := ctx.Value(schemas.MCPContextKeyIncludeClients) == nil
		noToolFilter := ctx.Value(schemas.MCPContextKeyIncludeTools) == nil
		if noClientFilter && noToolFilter {
			return tools
		}

		// Build the set of tool names permitted for this request.
		permitted := h.toolManager.GetAvailableMCPTools(ctx)
		permittedNames := make(map[string]struct{}, len(permitted))
		for _, candidate := range permitted {
			if candidate.Function == nil {
				continue
			}
			permittedNames[candidate.Function.Name] = struct{}{}
		}

		filtered := make([]mcp.Tool, 0, len(tools))
		for _, tool := range tools {
			if _, ok := permittedNames[tool.Name]; ok {
				filtered = append(filtered, tool)
			}
		}
		return filtered
	}
}
// Utility methods

// getMCPServerForRequest selects the MCP server that should serve the request
// and returns it together with the caller's per-user OAuth session, if any.
//
// Resolution order:
//  1. A failed OAuth session lookup is returned as an error.
//  2. If per-user OAuth MCP clients are configured and the request carries
//     neither a valid session nor a virtual key, a WWW-Authenticate discovery
//     header is set on the response and an error is returned.
//  3. With a session: the global server when VK enforcement is off and the
//     session has no virtual key, otherwise the server of the session's VK.
//  4. Without a session: the global server when VK enforcement is off and no
//     virtual key was sent, otherwise the server of the supplied virtual key.
//
// Both callers in this file answer any returned error with 401 Unauthorized.
//
// NOTE(review): the handler's read lock is held for the whole function,
// including the session lookup, which queries the config store.
func (h *MCPServerHandler) getMCPServerForRequest(ctx *fasthttp.RequestCtx) (*server.MCPServer, *tables.TablePerUserOAuthSession, error) {
h.mu.RLock()
defer h.mu.RUnlock()
// Snapshot the enforcement flag under the config lock.
h.config.Mu.RLock()
enforceVK := h.config.ClientConfig.EnforceAuthOnInference
h.config.Mu.RUnlock()
vk := getVKFromRequest(ctx)
// Check for Bifrost per-user OAuth Bearer token (not a VK)
userOauthSession, sessionErr := h.getPerUserOAuthSession(ctx)
if sessionErr != nil {
return nil, nil, fmt.Errorf("failed to look up OAuth session: %w", sessionErr)
}
// If per_user_oauth MCP clients are configured and no valid auth, return 401 with discovery
if clients := h.config.GetPerUserOAuthMCPClients(); len(clients) > 0 && userOauthSession == nil && vk == "" {
// Derive the external scheme from TLS state or the X-Forwarded-Proto header.
scheme := "http"
if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" {
scheme = "https"
}
host := string(ctx.Host())
resourceMetadataURL := fmt.Sprintf("%s://%s/.well-known/oauth-protected-resource", scheme, host)
ctx.Response.Header.Set("WWW-Authenticate",
fmt.Sprintf(`Bearer resource_metadata="%s"`, resourceMetadataURL))
return nil, nil, fmt.Errorf("oauth authentication required for mcp access")
}
if userOauthSession != nil {
// Session without a virtual key: allowed only when VK enforcement is off.
if !enforceVK && (userOauthSession.VirtualKeyID == nil || *userOauthSession.VirtualKeyID == "") {
return h.globalMCPServer, userOauthSession, nil
}
if userOauthSession.VirtualKeyID == nil || *userOauthSession.VirtualKeyID == "" || userOauthSession.VirtualKey == nil {
return nil, nil, fmt.Errorf("virtual key required in oauth session to access mcp server, please re-authenticate with a virtual key")
}
vkServer, ok := h.vkMCPServers[userOauthSession.VirtualKey.Value]
if !ok {
return nil, nil, fmt.Errorf("virtual key not found")
}
return vkServer, userOauthSession, nil
}
// Return global MCP server if not enforcing virtual key header and no virtual key is provided
if !enforceVK && vk == "" {
return h.globalMCPServer, nil, nil
}
if vk == "" {
return nil, nil, fmt.Errorf("virtual key header required to access mcp server")
}
vkServer, ok := h.vkMCPServers[vk]
if !ok {
return nil, nil, fmt.Errorf("virtual key not found")
}
return vkServer, nil, nil
}
// getPerUserOAuthSession looks up the per-user OAuth session identified by the
// Bearer token in the Authorization header.
//
// It returns (nil, nil) when the header is absent or not a Bearer credential,
// when the token is empty or carries the virtual-key prefix, when no config
// store is configured, or when the session is unknown or expired. A store
// lookup failure is logged and returned as the error.
func (h *MCPServerHandler) getPerUserOAuthSession(ctx *fasthttp.RequestCtx) (*tables.TablePerUserOAuthSession, error) {
	const bearerPrefix = "bearer "

	authHeader := strings.TrimSpace(string(ctx.Request.Header.Peek("Authorization")))
	if authHeader == "" {
		return nil, nil
	}
	if !strings.HasPrefix(strings.ToLower(authHeader), bearerPrefix) {
		return nil, nil
	}

	token := strings.TrimSpace(authHeader[len(bearerPrefix):])
	switch {
	case token == "":
		return nil, nil
	case strings.HasPrefix(strings.ToLower(token), governance.VirtualKeyPrefix):
		return nil, nil // It's a virtual key, not a per-user OAuth token
	case h.config.ConfigStore == nil:
		return nil, nil
	}

	session, lookupErr := h.config.ConfigStore.GetPerUserOAuthSessionByAccessToken(ctx, token)
	if lookupErr != nil {
		logger.Warn("[mcp/auth] GetPerUserOAuthSessionByAccessToken error: %v", lookupErr)
		return nil, lookupErr
	}
	if session == nil {
		logger.Debug("[mcp/auth] Session not found for token")
		return nil, nil
	}

	// An expired session is treated the same as a missing one.
	if time.Now().After(session.ExpiresAt) {
		logger.Debug("[mcp/auth] Session expired: session_id=%s expires_at=%v", session.ID, session.ExpiresAt)
		return nil, nil
	}
	return session, nil
}
// getVKFromRequest extracts a virtual key from the request headers, or returns
// "" when none is present.
//
// Sources are tried in this order:
//  1. the dedicated virtual-key header (any non-empty value is accepted),
//  2. an "Authorization: Bearer <token>" header whose token carries the
//     virtual-key prefix,
//  3. an "x-api-key" header whose value carries the virtual-key prefix.
//
// Prefix matching is case-insensitive and values are trimmed of surrounding
// whitespace.
func getVKFromRequest(ctx *fasthttp.RequestCtx) string {
	headers := &ctx.Request.Header

	// hasVKPrefix reports whether s starts with the virtual-key prefix.
	hasVKPrefix := func(s string) bool {
		return strings.HasPrefix(strings.ToLower(s), governance.VirtualKeyPrefix)
	}

	if direct := strings.TrimSpace(string(headers.Peek(string(schemas.BifrostContextKeyVirtualKey)))); direct != "" {
		return direct
	}

	authHeader := strings.TrimSpace(string(headers.Peek("Authorization")))
	if authHeader != "" && strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
		token := strings.TrimSpace(authHeader[7:])
		if token != "" && hasVKPrefix(token) {
			return token
		}
	}

	apiKey := strings.TrimSpace(string(headers.Peek("x-api-key")))
	if apiKey != "" && hasVKPrefix(apiKey) {
		return apiKey
	}
	return ""
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,320 @@
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
// This file contains OAuth 2.0 authentication flow handlers.
package handlers
import (
"context"
"encoding/json"
"errors"
"fmt"
"html"
"net/url"
"strings"
"github.com/fasthttp/router"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/oauth2"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// OAuthHandler manages HTTP requests for OAuth2 operations
type OAuthHandler struct {
client *bifrost.Bifrost
// store gives access to the config store used to read and update OAuth configs.
store *lib.Config
// oauthProvider completes, initiates and revokes OAuth flows.
oauthProvider *oauth2.OAuth2Provider
}
// NewOAuthHandler returns an OAuthHandler backed by the given OAuth provider,
// Bifrost client and configuration store.
func NewOAuthHandler(oauthProvider *oauth2.OAuth2Provider, client *bifrost.Bifrost, store *lib.Config) *OAuthHandler {
	h := new(OAuthHandler)
	h.oauthProvider = oauthProvider
	h.client = client
	h.store = store
	return h
}
// RegisterRoutes registers the OAuth endpoints on r, each wrapped with the
// supplied middlewares:
//
//	GET    /api/oauth/callback            provider callback
//	GET    /api/oauth/config/{id}/status  status of an OAuth config
//	DELETE /api/oauth/config/{id}         revoke an OAuth config
func (h *OAuthHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	callback := lib.ChainMiddlewares(h.handleOAuthCallback, middlewares...)
	r.GET("/api/oauth/callback", callback)

	status := lib.ChainMiddlewares(h.getOAuthConfigStatus, middlewares...)
	r.GET("/api/oauth/config/{id}/status", status)

	revoke := lib.ChainMiddlewares(h.revokeOAuthConfig, middlewares...)
	r.DELETE("/api/oauth/config/{id}", revoke)
}
// handleOAuthCallback handles the OAuth provider callback.
// GET /api/oauth/callback?state=xxx&code=yyy&error=zzz
//
// Processing order:
//  1. A non-empty "error" parameter is delegated to handleCallbackError.
//  2. Missing "state" or "code" yields 400 Bad Request.
//  3. The per-user OAuth flow is attempted. Any error other than
//     schemas.ErrOAuth2NotPerUserSession yields 500. On success with a
//     "flow:<flowID>:..." session token the browser is redirected to the MCPs
//     consent page; with any other non-empty token a success page is shown.
//  4. Otherwise the standard OAuth flow is completed (500 on failure) and a
//     success page is shown.
func (h *OAuthHandler) handleOAuthCallback(ctx *fasthttp.RequestCtx) {
	// Script embedded in the success pages: notifies the opening window and
	// closes the popup when the page was opened as one.
	const notifyOpenerScript = `
if (window.opener) {
window.opener.postMessage({ type: 'oauth_success' }, window.location.origin);
window.close();
}
`
	// writeSuccessPage renders the HTML success page with the given message.
	writeSuccessPage := func(message string) {
		ctx.SetStatusCode(fasthttp.StatusOK)
		ctx.SetContentType("text/html")
		ctx.SetBodyString(oauthSuccessPage(notifyOpenerScript, "Authorization Successful", message))
	}

	query := ctx.QueryArgs()
	state := string(query.Peek("state"))
	code := string(query.Peek("code"))
	errorParam := string(query.Peek("error"))
	errorDescription := string(query.Peek("error_description"))

	// Handle authorization denial
	if errorParam != "" {
		h.handleCallbackError(ctx, state, errorParam, errorDescription)
		return
	}

	// Validate required parameters
	if state == "" || code == "" {
		SendError(ctx, fasthttp.StatusBadRequest, "Missing required parameters: state and code")
		return
	}

	// Try per-user OAuth runtime flow first (state from oauth_user_sessions table).
	// This handles the case where an end-user authenticates during inference.
	sessionToken, perUserErr := h.oauthProvider.CompleteUserOAuthFlow(context.Background(), state, code)
	if perUserErr != nil && !errors.Is(perUserErr, schemas.ErrOAuth2NotPerUserSession) {
		// Real per-user error (not "state not found") — don't fall through to admin flow
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Per-user OAuth flow failed: %v", perUserErr))
		return
	}
	if perUserErr == nil && sessionToken != "" {
		// Consent flow: session token is a flow proxy ("flow:<flowID>:<mcpClientID>").
		// Redirect back to the MCPs consent page so the user can continue.
		if strings.HasPrefix(sessionToken, "flow:") {
			flowID, _, _ := strings.Cut(strings.TrimPrefix(sessionToken, "flow:"), ":")
			ctx.Redirect(fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID)), fasthttp.StatusFound)
			return
		}
		// Per-user runtime OAuth flow completed — show success page.
		writeSuccessPage("You can close this tab.")
		return
	}

	// Fall through to standard OAuth flow (handles both admin test logins for
	// per_user_oauth setup and regular server-level OAuth).
	if flowErr := h.oauthProvider.CompleteOAuthFlow(context.Background(), state, code); flowErr != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("OAuth flow completion failed: %v", flowErr))
		return
	}

	// Show the success page (or close the popup)
	writeSuccessPage("OAuth authorization successful! You can close this window.")
}
// handleCallbackError responds to an OAuth callback that carries an error.
//
// When state is non-empty and matches a stored OAuth config, that config is
// marked "failed" (best effort). The response is always a 400 HTML error page
// showing "<error>" or "<error>: <description>".
//
// NOTE(review): the result of UpdateOauthConfig is discarded, so a failed
// status update goes unnoticed.
func (h *OAuthHandler) handleCallbackError(ctx *fasthttp.RequestCtx, state, errorParam, errorDescription string) {
	// Update OAuth config status to failed if state is provided
	if state != "" {
		if cfg, lookupErr := h.store.ConfigStore.GetOauthConfigByState(context.Background(), state); lookupErr == nil && cfg != nil {
			cfg.Status = "failed"
			h.store.ConfigStore.UpdateOauthConfig(context.Background(), cfg)
		}
	}

	message := errorParam
	if errorDescription != "" {
		message = errorParam + ": " + errorDescription
	}

	// The message is embedded twice in the page, so it is escaped per context:
	// JSON-encoded for the JavaScript context (prevents JS injection) and
	// HTML-escaped for the HTML body (prevents HTML injection).
	jsEscaped, _ := json.Marshal(message)
	htmlEscaped := html.EscapeString(message)

	// Show error page
	ctx.SetStatusCode(fasthttp.StatusBadRequest)
	ctx.SetContentType("text/html")
	ctx.SetBodyString(oauthErrorPage(string(jsEscaped), htmlEscaped))
}
// getOAuthConfigStatus returns the current status of an OAuth config.
// GET /api/oauth/config/{id}/status
//
// It answers 500 when the lookup fails and 404 when no config has the given
// id. The JSON response holds id, status, created_at and expires_at; for an
// "authorized" config with a token it also holds token_id and, when the token
// can be loaded, token_expires_at and token_scopes.
func (h *OAuthHandler) getOAuthConfigStatus(ctx *fasthttp.RequestCtx) {
	configID := ctx.UserValue("id").(string)

	oauthConfig, lookupErr := h.store.ConfigStore.GetOauthConfigByID(context.Background(), configID)
	switch {
	case lookupErr != nil:
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get OAuth config: %v", lookupErr))
		return
	case oauthConfig == nil:
		SendError(ctx, fasthttp.StatusNotFound, "OAuth config not found")
		return
	}

	payload := map[string]interface{}{
		"id":         oauthConfig.ID,
		"status":     oauthConfig.Status,
		"created_at": oauthConfig.CreatedAt,
		"expires_at": oauthConfig.ExpiresAt,
	}

	if tokenID := oauthConfig.TokenID; oauthConfig.Status == "authorized" && tokenID != nil {
		payload["token_id"] = *tokenID
		// Token metadata is optional: a failed or empty lookup is skipped silently.
		token, tokenErr := h.store.ConfigStore.GetOauthTokenByID(context.Background(), *tokenID)
		if tokenErr == nil && token != nil {
			payload["token_expires_at"] = token.ExpiresAt
			payload["token_scopes"] = token.Scopes
		}
	}

	SendJSON(ctx, payload)
}
// revokeOAuthConfig revokes an OAuth configuration and its associated token.
// DELETE /api/oauth/config/{id}
//
// It answers 500 when the provider fails to revoke the token, otherwise a JSON
// confirmation message.
func (h *OAuthHandler) revokeOAuthConfig(ctx *fasthttp.RequestCtx) {
	configID := ctx.UserValue("id").(string)

	revokeErr := h.oauthProvider.RevokeToken(context.Background(), configID)
	if revokeErr != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to revoke OAuth token: %v", revokeErr))
		return
	}

	confirmation := map[string]interface{}{
		"message": "OAuth token revoked successfully",
	}
	SendJSON(ctx, confirmation)
}
// OAuthInitiationRequest represents the request to initiate an OAuth flow.
// InitiateOAuthFlow copies every field into a schemas.OAuth2Config before
// handing it to the OAuth provider.
type OAuthInitiationRequest struct {
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
AuthorizeURL string `json:"authorize_url"`
TokenURL string `json:"token_url"`
// RegistrationURL is optional: an empty value is forwarded to the provider as nil.
RegistrationURL string `json:"registration_url"`
RedirectURI string `json:"redirect_uri"`
Scopes []string `json:"scopes"`
ServerURL string `json:"server_url"` // For OAuth discovery
}
// InitiateOAuthFlow converts req into a schemas.OAuth2Config and delegates to
// the OAuth provider, returning whatever flow-initiation result it produces.
// This is called internally by the MCP client creation endpoint.
func (h *OAuthHandler) InitiateOAuthFlow(ctx context.Context, req OAuthInitiationRequest) (*schemas.OAuth2FlowInitiation, error) {
	cfg := schemas.OAuth2Config{
		ClientID:     req.ClientID,
		ClientSecret: req.ClientSecret,
		AuthorizeURL: req.AuthorizeURL,
		TokenURL:     req.TokenURL,
		RedirectURI:  req.RedirectURI,
		Scopes:       req.Scopes,
		ServerURL:    req.ServerURL, // MCP server URL for OAuth discovery
	}

	// An empty registration URL is passed on as nil rather than as a pointer to "".
	if req.RegistrationURL != "" {
		cfg.RegistrationURL = &req.RegistrationURL
	}

	return h.oauthProvider.InitiateOAuthFlow(ctx, &cfg)
}
// StorePendingMCPClient hands an MCP client config to the OAuth provider to be
// kept, keyed by oauthConfigID, until the OAuth flow completes.
// Per the original note, persistence is what lets a callback that lands on a
// different server instance still find the pending client.
func (h *OAuthHandler) StorePendingMCPClient(oauthConfigID string, mcpClientConfig schemas.MCPClientConfig) error {
	provider := h.oauthProvider
	return provider.StorePendingMCPClient(oauthConfigID, mcpClientConfig)
}
// GetPendingMCPClient looks up the pending MCP client config associated with
// oauthConfigID by delegating to the OAuth provider.
func (h *OAuthHandler) GetPendingMCPClient(oauthConfigID string) (*schemas.MCPClientConfig, error) {
	pending, err := h.oauthProvider.GetPendingMCPClient(oauthConfigID)
	return pending, err
}
// GetPendingMCPClientByState looks up a pending MCP client config using the
// OAuth state token and forwards the provider's three results unchanged.
func (h *OAuthHandler) GetPendingMCPClientByState(state string) (*schemas.MCPClientConfig, string, error) {
	pending, id, err := h.oauthProvider.GetPendingMCPClientByState(state)
	return pending, id, err
}
// RemovePendingMCPClient asks the OAuth provider to drop the pending MCP
// client stored for oauthConfigID; intended to run after OAuth completion.
func (h *OAuthHandler) RemovePendingMCPClient(oauthConfigID string) error {
	provider := h.oauthProvider
	return provider.RemovePendingMCPClient(oauthConfigID)
}
// GetAccessToken returns the access token the OAuth provider holds for
// oauthConfigID.
// Used during per-user OAuth setup to get the admin's temporary token for verification.
func (h *OAuthHandler) GetAccessToken(ctx context.Context, oauthConfigID string) (string, error) {
	token, err := h.oauthProvider.GetAccessToken(ctx, oauthConfigID)
	return token, err
}
// oauthSuccessPage builds the HTML document shown after a successful OAuth step.
//
// title is HTML-escaped and used for both <title> and <h1>; message is
// HTML-escaped and rendered below the heading. extraScript is placed inside a
// <script> element WITHOUT any escaping, so the caller must guarantee it is safe.
func oauthSuccessPage(extraScript, title, message string) string {
	const page = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>%s</title>
<style>%s
.icon{font-size:2.5rem;margin-bottom:16px}
.msg{font-size:0.9rem;color:oklch(0.552 0.016 285.938);margin-top:8px}
</style>
<script>%s</script>
</head>
<body>
<div class="card" style="text-align:center">
<div class="icon">&#10003;</div>
<h1>%s</h1>
<p class="msg">%s</p>
</div>
</body>
</html>`

	safeTitle := html.EscapeString(title)
	safeMessage := html.EscapeString(message)
	return fmt.Sprintf(page, safeTitle, bifrostPageCSS, extraScript, safeTitle, safeMessage)
}
// oauthErrorPage builds the HTML document shown when an OAuth step fails.
//
// Neither argument is escaped here:
//   - jsEscapedError is spliced into inline JavaScript as an expression, so it
//     must already be a JSON-encoded string literal (quotes included).
//   - htmlError is spliced into the page body, so it must already be HTML-escaped.
//
// When the page has a window.opener, the inline script posts an
// 'oauth_failed' message to it and closes the window.
func oauthErrorPage(jsEscapedError, htmlError string) string {
	const page = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Authorization Failed</title>
<style>%s
.icon{font-size:2.5rem;margin-bottom:16px;color:oklch(0.50 0.18 27)}
.err-msg{font-size:0.9rem;color:oklch(0.552 0.016 285.938);margin-top:8px}
.hint{font-size:0.8rem;color:oklch(0.65 0.01 286);margin-top:16px}
</style>
<script>
if (window.opener) {
window.opener.postMessage({ type: 'oauth_failed', error: %s }, window.location.origin);
window.close();
}
</script>
</head>
<body>
<div class="card" style="text-align:center">
<div class="icon">&#10007;</div>
<h1>Authorization Failed</h1>
<p class="err-msg">%s</p>
<p class="hint">You can close this window.</p>
</div>
</body>
</html>`

	return fmt.Sprintf(page, bifrostPageCSS, jsEscapedError, htmlError)
}
// jsEscapeString JSON-encodes s and returns the result, surrounding quotes
// included, so it can be dropped into JavaScript source as a string literal.
// If encoding fails the empty string is returned.
func jsEscapeString(s string) string {
	encoded, err := json.Marshal(s)
	if err != nil {
		return ""
	}
	return string(encoded)
}
// RevokeToken asks the OAuth provider to revoke the token held for oauthConfigID.
// Used during per-user OAuth setup to discard the admin's temporary token after verification.
func (h *OAuthHandler) RevokeToken(ctx context.Context, oauthConfigID string) error {
	provider := h.oauthProvider
	return provider.RevokeToken(ctx, oauthConfigID)
}

View File

@@ -0,0 +1,643 @@
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
// This file implements the per-user OAuth consent flow — the intermediate screens
// shown between the MCP client's authorize request and the final authorization code
// issuance. The flow is:
//
// 1. GET /oauth/consent?flow_id=xxx → identity selection page (HTML: User ID, Virtual Key, or skip)
// 2. POST /api/oauth/per-user/consent/{vk,user-id,skip} → record the chosen identity on the PendingFlow, redirect
// 3. GET /oauth/consent/mcps?flow_id=xxx → MCPs page (HTML, server-rendered)
// 4. POST /api/oauth/per-user/consent/submit → create session + code, redirect to client
package handlers
import (
"errors"
"fmt"
"html"
"net/url"
"sort"
"strings"
"time"
"github.com/fasthttp/router"
"github.com/google/uuid"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// ConsentHandler manages the per-user OAuth consent flow screens.
// It serves the server-rendered consent pages and the POST endpoints that
// advance a pending flow.
type ConsentHandler struct {
// store is the shared Bifrost configuration. The handlers read its
// ConfigStore (pending flows, virtual keys, tokens) and, under store.Mu,
// ClientConfig.EnforceAuthOnInference.
store *lib.Config
}
// NewConsentHandler returns a ConsentHandler backed by the given configuration.
func NewConsentHandler(store *lib.Config) *ConsentHandler {
	handler := new(ConsentHandler)
	handler.store = store
	return handler
}
// RegisterRoutes wires the consent flow endpoints onto r.
// No auth middleware is attached: these routes are part of the OAuth flow that
// unauthenticated users go through to acquire credentials.
func (h *ConsentHandler) RegisterRoutes(r *router.Router) {
	type route struct {
		path    string
		handler fasthttp.RequestHandler
	}

	// Server-rendered HTML pages (GET).
	pages := []route{
		{"/oauth/consent", h.handleIdentityPage},
		{"/oauth/consent/mcps", h.handleMCPsPage},
	}
	for _, p := range pages {
		r.GET(p.path, p.handler)
	}

	// State-mutating actions (POST only).
	// NOTE: per the original design note, CSRF protection relies on the
	// SameSite=Lax, HttpOnly+Secure browser-binding cookie
	// (__bifrost_flow_secret) combined with the flow_id; SameSite=Lax keeps
	// the cookie off cross-site POSTs. That was judged sufficient for this
	// threat model.
	actions := []route{
		{"/api/oauth/per-user/consent/vk", h.handleSubmitVK},
		{"/api/oauth/per-user/consent/user-id", h.handleSubmitUserID},
		{"/api/oauth/per-user/consent/skip", h.handleSkip},
		{"/api/oauth/per-user/consent/submit", h.handleSubmit},
	}
	for _, a := range actions {
		r.POST(a.path, a.handler)
	}
}
// ---------- HTML pages ----------
// handleIdentityPage serves the identity selection page of the consent flow.
// GET /oauth/consent?flow_id=xxx[&error=xxx]
//
// The page offers a User ID form and a Virtual Key form; a "Skip for now"
// option is added only when ClientConfig.EnforceAuthOnInference is off. An
// optional "error" query parameter is HTML-escaped and shown as a banner.
//
// Plain-text failures: 400 for a missing flow_id or an unknown/expired flow,
// 503 when no config store is configured, 500 when the flow cannot be loaded,
// and 403 when validateFlowBrowserSecret rejects the request.
func (h *ConsentHandler) handleIdentityPage(ctx *fasthttp.RequestCtx) {
	// fail writes a plain-text error response with the given status.
	fail := func(status int, body string) {
		ctx.SetStatusCode(status)
		ctx.SetBodyString(body)
	}

	query := ctx.QueryArgs()
	flowID := string(query.Peek("flow_id"))
	bannerText := string(query.Peek("error"))

	if flowID == "" {
		fail(fasthttp.StatusBadRequest, "Missing flow_id")
		return
	}
	if h.store.ConfigStore == nil {
		fail(fasthttp.StatusServiceUnavailable, "Config store unavailable")
		return
	}

	flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
	switch {
	case err != nil:
		fail(fasthttp.StatusInternalServerError, "Failed to load consent flow.")
		return
	case flow == nil || time.Now().After(flow.ExpiresAt):
		fail(fasthttp.StatusBadRequest, "Invalid or expired consent flow. Please restart the authentication process.")
		return
	case !validateFlowBrowserSecret(ctx, flow):
		fail(fasthttp.StatusForbidden, "Flow does not belong to this browser session. Please restart the authentication process.")
		return
	}

	h.store.Mu.RLock()
	identityRequired := h.store.ClientConfig.EnforceAuthOnInference
	h.store.Mu.RUnlock()

	// flow_id is echoed into hidden form fields, so it is escaped once here.
	escapedFlowID := html.EscapeString(flowID)

	var banner string
	if bannerText != "" {
		banner = fmt.Sprintf(`<div class="error-banner">%s</div>`, html.EscapeString(bannerText))
	}

	// Skipping is only offered when an identity is not mandatory.
	var skipBlock string
	if !identityRequired {
		skipBlock = fmt.Sprintf(`
<div class="option">
<span class="option-title">Skip for now</span>
<span class="option-desc">Connect to services when a tool is called</span>
<form action="/api/oauth/per-user/consent/skip" method="POST" style="margin-top:10px">
<input type="hidden" name="flow_id" value="%s">
<button type="submit" class="btn btn-ghost">Skip</button>
</form>
</div>`, escapedFlowID)
	}

	const page = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Connect to Bifrost</title>
<style>
%s
.option{border:1px solid oklch(0.92 0.004 286.32);border-radius:0.5rem;padding:16px 18px;margin-bottom:10px}
.option-title{display:block;font-size:0.9rem;font-weight:600;color:oklch(0.141 0.005 285.823);margin-bottom:2px}
.option-desc{display:block;font-size:0.8rem;color:oklch(0.552 0.016 285.938);margin-bottom:12px}
</style>
</head>
<body>
<div class="card">
<h1>Connect to Bifrost</h1>
<p class="subtitle">Choose how to identify yourself for this session.</p>
<p style="font-size:0.75rem;color:oklch(0.65 0.01 286);margin-bottom:18px">This setup page expires in 15 minutes.</p>
%s
<div class="option">
<span class="option-title">User ID</span>
<span class="option-desc">Use a stable identifier — access all available services</span>
<form action="/api/oauth/per-user/consent/user-id" method="POST">
<input type="hidden" name="flow_id" value="%s">
<label for="user_id">User ID</label>
<input type="text" id="user_id" name="user_id" placeholder="e.g. alice" autocomplete="off" spellcheck="false" autocapitalize="none" autocorrect="off">
<button type="submit" class="btn btn-primary">Continue with User ID</button>
</form>
</div>
<div class="option">
<span class="option-title">Virtual Key</span>
<span class="option-desc">Use a VK — access services within your key's limits</span>
<form action="/api/oauth/per-user/consent/vk" method="POST">
<input type="hidden" name="flow_id" value="%s">
<label for="vk">Virtual Key</label>
<input type="password" id="vk" name="vk" placeholder="sk-bf-..." autocomplete="off" spellcheck="false" autocapitalize="none">
<button type="submit" class="btn btn-primary">Continue with Virtual Key</button>
</form>
</div>
%s
</div>
</body>
</html>`

	ctx.SetStatusCode(fasthttp.StatusOK)
	ctx.SetContentType("text/html; charset=utf-8")
	ctx.SetBodyString(fmt.Sprintf(page, bifrostPageCSS, banner, escapedFlowID, escapedFlowID, skipBlock))
}
// handleMCPsPage renders the MCP authentication list page.
// GET /oauth/consent/mcps?flow_id=xxx
func (h *ConsentHandler) handleMCPsPage(ctx *fasthttp.RequestCtx) {
flowID := string(ctx.QueryArgs().Peek("flow_id"))
if flowID == "" {
ctx.SetStatusCode(fasthttp.StatusBadRequest)
ctx.SetBodyString("Missing flow_id")
return
}
if h.store.ConfigStore == nil {
ctx.SetStatusCode(fasthttp.StatusServiceUnavailable)
ctx.SetBodyString("Config store unavailable")
return
}
flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
if err != nil {
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
ctx.SetBodyString("Failed to load consent flow.")
return
}
if flow == nil || time.Now().After(flow.ExpiresAt) {
ctx.SetStatusCode(fasthttp.StatusBadRequest)
ctx.SetBodyString("Invalid or expired consent flow. Please restart the authentication process.")
return
}
if !validateFlowBrowserSecret(ctx, flow) {
ctx.SetStatusCode(fasthttp.StatusForbidden)
ctx.SetBodyString("Flow does not belong to this browser session. Please restart the authentication process.")
return
}
// Find which MCP clients the user has already authed.
// Check both: tokens stored under the flow proxy (connected during this flow)
// and tokens already stored under the VK/user identity (connected in a prior flow).
completedTokens, err := h.store.ConfigStore.GetOauthUserTokensByGatewaySessionID(ctx, flowID)
if err != nil {
completedTokens = nil // non-fatal; just show no checkmarks
}
completedMCPs := make(map[string]bool, len(completedTokens))
for _, t := range completedTokens {
completedMCPs[t.MCPClientID] = true
}
// Per_user_oauth MCP clients visible to this identity — sorted for deterministic rendering.
// When a VK is set on the flow, only show clients that VK is allowed to use.
perUserClients := h.store.GetPerUserOAuthMCPClientsForVirtualKey(ctx, strVal(flow.VirtualKeyID))
clientIDs := make([]string, 0, len(perUserClients))
for id := range perUserClients {
clientIDs = append(clientIDs, id)
}
sort.Strings(clientIDs)
safeFlowID := html.EscapeString(flowID)
// Determine if user skipped identity selection.
isSkipped := strVal(flow.VirtualKeyID) == "" && strVal(flow.UserID) == ""
// Build MCP rows — only show connect buttons if user has an identity.
var mcpRows strings.Builder
if isSkipped {
mcpRows.WriteString(`<p style="color:#6b7280;font-size:14px;">You skipped identity selection. Services will be connected when you first use their tools. Since no identity is attached, your connections will only persist as long as the service keeps the OAuth token active — they will not be remembered across sessions.</p>`)
} else {
for _, clientID := range clientIDs {
clientName := perUserClients[clientID]
safeName := html.EscapeString(clientName)
// Also check if a token already exists under the user's identity (e.g. from a prior LLM gateway auth).
alreadyConnected := completedMCPs[clientID]
if !alreadyConnected && (strVal(flow.VirtualKeyID) != "" || strVal(flow.UserID) != "") {
existing, tokenErr := h.store.ConfigStore.GetOauthUserTokenByIdentity(ctx, strVal(flow.VirtualKeyID), strVal(flow.UserID), "", clientID)
if tokenErr != nil {
logger.Warn("[consent/mcps] failed to check existing token: mcp_client_id=%s err=%v", clientID, tokenErr)
}
alreadyConnected = existing != nil
}
if alreadyConnected {
mcpRows.WriteString(fmt.Sprintf(`
<div class="mcp-row">
<div class="mcp-name">%s</div>
<span class="badge connected">&#10003; Connected</span>
</div>`, safeName))
} else {
connectURL := fmt.Sprintf("/api/oauth/per-user/upstream/authorize?mcp_client_id=%s&flow_id=%s",
url.QueryEscape(clientID), url.QueryEscape(flowID))
mcpRows.WriteString(fmt.Sprintf(`
<div class="mcp-row">
<div class="mcp-name">%s</div>
<a class="badge connect" href="%s">Connect</a>
</div>`, safeName, html.EscapeString(connectURL)))
}
}
if len(perUserClients) == 0 {
mcpRows.WriteString(`<p style="color:#6b7280;font-size:14px;">No MCP services require authentication.</p>`)
}
}
ctx.SetStatusCode(fasthttp.StatusOK)
ctx.SetContentType("text/html; charset=utf-8")
ctx.SetBodyString(fmt.Sprintf(`<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Connect Your Apps — Bifrost</title>
<style>
%s
.mcp-row{display:flex;align-items:center;justify-content:space-between;padding:12px 0;border-bottom:1px solid oklch(0.92 0.004 286.32)}
.mcp-row:last-of-type{border-bottom:none}
.mcp-name{font-size:0.9rem;font-weight:500;color:oklch(0.141 0.005 285.823)}
.badge{font-size:0.8rem;font-weight:500;padding:4px 12px;border-radius:20px;text-decoration:none;display:inline-block}
.badge.connected{background:oklch(0.95 0.05 160);color:oklch(0.35 0.08 160)}
.badge.connect{background:oklch(0.5081 0.1049 165.61);color:oklch(0.985 0 0);cursor:pointer;
padding:8px 18px;border-radius:0.5rem;font-weight:500;
transition:background .15s}
.badge.connect:hover{background:oklch(0.43 0.1049 165.61)}
.mcp-list{margin-bottom:4px}
</style>
</head>
<body>
<div class="card">
<h1>Connect Your Apps</h1>
<p class="subtitle">Authenticate with the services below to enable their tools.</p>
<p style="font-size:0.75rem;color:oklch(0.65 0.01 286);margin-bottom:18px">This setup page expires in 15 minutes.</p>
<div class="mcp-list">%s</div>
<form action="/api/oauth/per-user/consent/submit" method="POST" style="margin-top:24px">
<input type="hidden" name="flow_id" value="%s">
<button type="submit" class="btn btn-primary">Finish Setup</button>
</form>
<div style="text-align:center;margin-top:12px">
<a href="/oauth/consent?flow_id=%s" style="font-size:0.8rem;color:oklch(0.552 0.016 285.938);text-decoration:none">Change identity</a>
</div>
</div>
</body>
</html>`, bifrostPageCSS, mcpRows.String(), safeFlowID, safeFlowID))
}
// ---------- API action handlers ----------
// handleSubmitVK attaches a submitted Virtual Key to the pending flow and
// sends the browser on to the MCPs page.
// POST /api/oauth/per-user/consent/vk (form: flow_id, vk)
//
// Problems with the flow itself (missing flow_id, load failure, unknown or
// expired flow, browser-secret mismatch) are answered with SendError. Problems
// with the key (empty, lookup failure, not found, inactive, save failure)
// redirect back to the identity page with an error message. On success the
// flow's UserID is cleared so that only one identity is selected.
func (h *ConsentHandler) handleSubmitVK(ctx *fasthttp.RequestCtx) {
	if h.store.ConfigStore == nil {
		SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable")
		return
	}

	flowID := string(ctx.FormValue("flow_id"))
	submittedKey := strings.TrimSpace(string(ctx.FormValue("vk")))
	if flowID == "" {
		SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required")
		return
	}

	flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
	switch {
	case err != nil:
		SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow")
		return
	case flow == nil || time.Now().After(flow.ExpiresAt):
		SendError(ctx, fasthttp.StatusBadRequest, "Invalid or expired consent flow")
		return
	case !validateFlowBrowserSecret(ctx, flow):
		SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session")
		return
	}

	if submittedKey == "" {
		redirectToIdentityPage(ctx, flowID, "Please enter a Virtual Key.")
		return
	}

	vk, err := h.store.ConfigStore.GetVirtualKeyByValue(ctx, submittedKey)
	switch {
	case err != nil:
		redirectToIdentityPage(ctx, flowID, "Failed to validate Virtual Key. Please try again.")
		return
	case vk == nil || !vk.IsActive:
		redirectToIdentityPage(ctx, flowID, "Virtual Key not found or inactive. Please check and try again.")
		return
	}

	// Selecting a Virtual Key replaces any previously chosen User ID.
	flow.VirtualKeyID = &vk.ID
	flow.UserID = nil
	if saveErr := h.store.ConfigStore.UpdatePerUserOAuthPendingFlow(ctx, flow); saveErr != nil {
		redirectToIdentityPage(ctx, flowID, "Failed to save Virtual Key. Please try again.")
		return
	}

	next := fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID))
	ctx.Redirect(next, fasthttp.StatusFound)
}
// handleSubmitUserID links a user-supplied User ID to the pending flow and proceeds to the MCPs page.
// SECURITY: The User ID is self-declared (typed in a form) with no server-side verification.
// This matches the trust model of X-Bf-User-Id in the LLM gateway path. Deployments requiring
// verified identity should use Virtual Keys or an auth layer in front of Bifrost.
// POST /api/oauth/per-user/consent/user-id (form: flow_id, user_id)
//
// Problems with the flow itself (missing flow_id, load failure, unknown or
// expired flow, browser-secret mismatch) are answered with SendError. An empty
// or over-long User ID, or a failure to save, redirects back to the identity
// page with an error message. On success the flow's VirtualKeyID is cleared so
// that only one identity is selected.
func (h *ConsentHandler) handleSubmitUserID(ctx *fasthttp.RequestCtx) {
	if h.store.ConfigStore == nil {
		SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable")
		return
	}
	flowID := string(ctx.FormValue("flow_id"))
	userID := strings.TrimSpace(string(ctx.FormValue("user_id")))
	if flowID == "" {
		SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required")
		return
	}
	flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
	if err != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow")
		return
	}
	if flow == nil || time.Now().After(flow.ExpiresAt) {
		SendError(ctx, fasthttp.StatusBadRequest, "Invalid or expired consent flow")
		return
	}
	if !validateFlowBrowserSecret(ctx, flow) {
		SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session")
		return
	}
	if userID == "" {
		redirectToIdentityPage(ctx, flowID, "Please enter a User ID.")
		return
	}
	// NOTE(review): len counts bytes, while the message says "characters";
	// a multi-byte ID can be rejected below 255 runes. Confirm which limit the
	// storage layer actually needs before changing this.
	if len(userID) > 255 {
		redirectToIdentityPage(ctx, flowID, "User ID is too long (max 255 characters).")
		return
	}
	// userID is known to be non-empty here, so it can be assigned unconditionally.
	flow.UserID = &userID
	flow.VirtualKeyID = nil // Clear other identity to keep selection exclusive
	if err := h.store.ConfigStore.UpdatePerUserOAuthPendingFlow(ctx, flow); err != nil {
		redirectToIdentityPage(ctx, flowID, "Failed to save User ID. Please try again.")
		return
	}
	ctx.Redirect(fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID)), fasthttp.StatusFound)
}
// handleSkip lets the user bypass identity selection and go straight to the
// MCPs page.
// POST /api/oauth/per-user/consent/skip (form: flow_id)
//
// When ClientConfig.EnforceAuthOnInference is set, skipping is refused and the
// browser is sent back to the identity page. Otherwise any identity already
// stored on the flow is cleared before redirecting. Problems with the flow
// itself (missing flow_id, load failure, unknown or expired flow,
// browser-secret mismatch) are answered with SendError.
func (h *ConsentHandler) handleSkip(ctx *fasthttp.RequestCtx) {
	if h.store.ConfigStore == nil {
		SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable")
		return
	}

	flowID := string(ctx.FormValue("flow_id"))
	if flowID == "" {
		SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required")
		return
	}

	flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
	switch {
	case err != nil:
		SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow")
		return
	case flow == nil || time.Now().After(flow.ExpiresAt):
		SendError(ctx, fasthttp.StatusBadRequest, "Invalid or expired consent flow")
		return
	case !validateFlowBrowserSecret(ctx, flow):
		SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session")
		return
	}

	h.store.Mu.RLock()
	identityRequired := h.store.ClientConfig.EnforceAuthOnInference
	h.store.Mu.RUnlock()
	if identityRequired {
		redirectToIdentityPage(ctx, flowID, "An identity (Virtual Key or User ID) is required. Please choose one to continue.")
		return
	}

	// Skipping must truly reset the flow: drop any identity chosen earlier.
	// The store is only written when there is something to clear.
	hadIdentity := strVal(flow.VirtualKeyID) != "" || strVal(flow.UserID) != ""
	if hadIdentity {
		flow.VirtualKeyID, flow.UserID = nil, nil
		if saveErr := h.store.ConfigStore.UpdatePerUserOAuthPendingFlow(ctx, flow); saveErr != nil {
			redirectToIdentityPage(ctx, flowID, "Failed to clear identity. Please try again.")
			return
		}
	}

	// With no identity attached, only lazy auth is available on the MCPs page.
	next := fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID))
	ctx.Redirect(next, fasthttp.StatusFound)
}
// handleSubmit finalises the consent flow:
//  1. Validates the pending flow and parses the client's redirect URI
//  2. Builds a Bifrost session (TablePerUserOAuthSession) and a
//     TablePerUserOAuthCode linked to it
//  3. Atomically consumes the PendingFlow and stores session + code
//  4. Transfers upstream tokens from the flow proxy to the real session (best-effort)
//  5. Redirects to the original MCP client callback URL with code + state
//
// The redirect URI is parsed BEFORE the flow is consumed: a malformed URI then
// fails the request while the flow is still intact, instead of leaving behind a
// session and code that can never be delivered and a flow that cannot be retried.
//
// Responses: 503 without a config store; 400 for a missing flow_id or an
// unknown/expired flow; 403 on browser-secret mismatch; 410 when finalisation
// reports the flow expired; 409 when the flow was already consumed; 500 on
// load, token-generation, redirect-URI or finalisation failures; otherwise 302.
//
// POST /api/oauth/per-user/consent/submit (form: flow_id)
func (h *ConsentHandler) handleSubmit(ctx *fasthttp.RequestCtx) {
	if h.store.ConfigStore == nil {
		SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable")
		return
	}
	flowID := string(ctx.FormValue("flow_id"))
	if flowID == "" {
		SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required")
		return
	}
	flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
	if err != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow")
		return
	}
	if flow == nil {
		SendError(ctx, fasthttp.StatusBadRequest, "Invalid consent flow")
		return
	}
	if time.Now().After(flow.ExpiresAt) {
		SendError(ctx, fasthttp.StatusBadRequest, "Consent flow has expired. Please restart the authentication process.")
		return
	}
	if !validateFlowBrowserSecret(ctx, flow) {
		SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session")
		return
	}

	// Server-side enforcement: reject if identity is required but not provided.
	h.store.Mu.RLock()
	enforceAuth := h.store.ClientConfig.EnforceAuthOnInference
	h.store.Mu.RUnlock()
	if enforceAuth && strVal(flow.VirtualKeyID) == "" && strVal(flow.UserID) == "" {
		redirectToIdentityPage(ctx, flowID, "An identity (Virtual Key or User ID) is required. Please choose one to continue.")
		return
	}

	// Parse the client callback up front so a bad URI cannot strand a consumed flow.
	redirectURL, err := url.Parse(flow.RedirectURI)
	if err != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, "Invalid redirect URI in pending flow")
		return
	}

	// Generate session credentials.
	accessToken, err := generateOpaqueToken(32)
	if err != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate session token")
		return
	}
	refreshToken, err := generateOpaqueToken(32)
	if err != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate refresh token")
		return
	}
	session := &tables.TablePerUserOAuthSession{
		ID:           uuid.New().String(),
		AccessToken:  accessToken,
		RefreshToken: refreshToken,
		ClientID:     flow.ClientID,
		VirtualKeyID: flow.VirtualKeyID,
		UserID:       flow.UserID,
		ExpiresAt:    time.Now().Add(24 * time.Hour),
	}

	// Generate the authorization code.
	code, err := generateOpaqueToken(32)
	if err != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate authorization code")
		return
	}
	codeRecord := &tables.TablePerUserOAuthCode{
		ID:            uuid.New().String(),
		Code:          code,
		ClientID:      flow.ClientID,
		RedirectURI:   flow.RedirectURI,
		CodeChallenge: flow.CodeChallenge,
		SessionID:     session.ID, // Links token endpoint to this session so it can return the same access token
		// Scopes intentionally omitted: the consent flow has no scope selection step.
		ExpiresAt: time.Now().Add(5 * time.Minute),
	}

	// Atomically consume the pending flow, create session, and create auth code.
	// If another concurrent request already consumed the flow, rowsAffected will be 0.
	rowsAffected, err := h.store.ConfigStore.FinalizePerUserOAuthConsent(ctx, flowID, session, codeRecord)
	if err != nil {
		if errors.Is(err, schemas.ErrPerUserOAuthPendingFlowExpired) {
			SendError(ctx, fasthttp.StatusGone, "Consent flow has expired. Please restart the authentication process.")
			return
		}
		SendError(ctx, fasthttp.StatusInternalServerError, "Failed to finalize consent flow")
		return
	}
	if rowsAffected == 0 {
		SendError(ctx, fasthttp.StatusConflict, "Consent flow has already been submitted")
		return
	}
	logger.Debug("[consent/submit] session created: session_id=%s flow_id=%s", session.ID, flowID)

	// Migrate upstream tokens from flow proxy sessions to the real session.
	if err := h.store.ConfigStore.TransferOauthUserTokensFromGatewaySession(ctx, flowID, accessToken, strVal(flow.VirtualKeyID), strVal(flow.UserID)); err != nil {
		// Non-fatal: tokens can be re-acquired on first tool use.
		logger.Warn("[consent/submit] failed to transfer upstream tokens: flow_id=%s err=%v", flowID, err)
	}

	// Redirect to the MCP client callback with code + original state.
	q := redirectURL.Query()
	q.Set("code", code)
	if flow.State != "" {
		q.Set("state", flow.State)
	}
	redirectURL.RawQuery = q.Encode()
	ctx.Redirect(redirectURL.String(), fasthttp.StatusFound)
}
// ---------- helpers ----------
// bifrostPageCSS is the shared inline CSS for all Go-rendered consent/callback pages.
// It provides the base reset and body layout plus the .card, .subtitle, .btn,
// .btn-primary, .btn-ghost and .error-banner classes and the text/password
// input styling; pages append their own page-specific rules after it.
//
// The stylesheet contains literal '%' characters (e.g. width:100%), so it must
// only ever be passed as an ARGUMENT to fmt.Sprintf, never concatenated into a
// format string.
const bifrostPageCSS = `
*,*::before,*::after{box-sizing:border-box;margin:0;padding:0}
body{font-family:"Geist",system-ui,-apple-system,sans-serif;font-size:0.95rem;
line-height:1.5;background:#f4f4f5;color:oklch(0.141 0.005 285.823);
display:flex;align-items:center;justify-content:center;min-height:100vh;
-webkit-font-smoothing:antialiased}
.card{background:#fff;border:1px solid oklch(0.92 0.004 286.32);border-radius:12px;
padding:40px;width:100%;max-width:480px}
h1{font-size:1.25rem;font-weight:600;color:oklch(0.141 0.005 285.823);margin-bottom:6px}
.subtitle{font-size:0.825rem;color:oklch(0.552 0.016 285.938);line-height:1.5;margin-bottom:24px}
label{display:block;font-size:0.825rem;font-weight:500;color:oklch(0.141 0.005 285.823);margin-bottom:5px}
input[type=text],input[type=password]{width:100%;padding:8px 12px;border:1px solid oklch(0.92 0.004 286.32);
border-radius:0.5rem;font-size:0.875rem;outline:none;
transition:border-color .15s,box-shadow .15s;margin-bottom:10px;
background:#fff;color:oklch(0.141 0.005 285.823)}
input[type=text]:focus,input[type=password]:focus{border-color:oklch(0.5081 0.1049 165.61);
box-shadow:0 0 0 3px oklch(0.5081 0.1049 165.61 / 0.15)}
.btn{display:block;width:100%;padding:9px 16px;border-radius:0.5rem;font-size:0.875rem;
font-weight:500;cursor:pointer;border:none;text-align:center;text-decoration:none;
transition:background .15s;font-family:inherit}
.btn-primary{background:oklch(0.5081 0.1049 165.61);color:oklch(0.985 0 0)}
.btn-primary:hover{background:oklch(0.43 0.1049 165.61)}
.btn-ghost{background:transparent;border:1px solid oklch(0.92 0.004 286.32);
color:oklch(0.552 0.016 285.938);display:inline-block;width:auto;padding:8px 16px}
.btn-ghost:hover{background:#f4f4f5}
.error-banner{background:oklch(0.97 0.02 27);border:1px solid oklch(0.88 0.06 27);
border-radius:0.5rem;padding:12px 14px;margin-bottom:18px;
color:oklch(0.50 0.18 27);font-size:0.825rem}
`
// redirectToIdentityPage issues a 302 back to /oauth/consent for the given
// flow, carrying errorMsg in the "error" query parameter. Both values are
// query-escaped.
func redirectToIdentityPage(ctx *fasthttp.RequestCtx, flowID, errorMsg string) {
	target := "/oauth/consent?flow_id=" + url.QueryEscape(flowID) +
		"&error=" + url.QueryEscape(errorMsg)
	ctx.Redirect(target, fasthttp.StatusFound)
}
// strVal returns the string s points to, or "" when s is nil.
func strVal(s *string) string {
	if s != nil {
		return *s
	}
	return ""
}

View File

@@ -0,0 +1,93 @@
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
// This file implements OAuth 2.0 metadata discovery endpoints per RFC 9728
// (Protected Resource Metadata) and RFC 8414 (Authorization Server Metadata).
// These endpoints enable MCP-spec-compliant clients (like Claude Code) to
// automatically discover Bifrost's OAuth configuration and authenticate.
package handlers
import (
"fmt"
"github.com/fasthttp/router"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// OAuthMetadataHandler serves OAuth 2.0 discovery metadata endpoints.
// It provides the Protected Resource Metadata (RFC 9728) and Authorization
// Server Metadata (RFC 8414) that MCP clients use to discover how to
// authenticate with Bifrost's MCP server endpoint.
type OAuthMetadataHandler struct {
// store is the shared Bifrost configuration; the handlers query it for the
// configured per-user OAuth MCP clients and answer 404 when there are none.
store *lib.Config
}
// NewOAuthMetadataHandler returns an OAuthMetadataHandler backed by the given configuration.
func NewOAuthMetadataHandler(store *lib.Config) *OAuthMetadataHandler {
	handler := new(OAuthMetadataHandler)
	handler.store = store
	return handler
}
// RegisterRoutes mounts the two well-known discovery endpoints on r, each
// wrapped in the supplied middlewares.
// These routes do NOT go through auth middleware since they must be
// accessible to unauthenticated clients during OAuth discovery.
func (h *OAuthMetadataHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	// RFC 9728: Protected Resource Metadata
	protectedResource := lib.ChainMiddlewares(h.handleProtectedResourceMetadata, middlewares...)
	r.GET("/.well-known/oauth-protected-resource", protectedResource)

	// RFC 8414: Authorization Server Metadata
	authorizationServer := lib.ChainMiddlewares(h.handleAuthorizationServerMetadata, middlewares...)
	r.GET("/.well-known/oauth-authorization-server", authorizationServer)
}
// handleProtectedResourceMetadata returns the Protected Resource Metadata
// document (RFC 9728) as JSON. MCP clients fetch this after receiving a 401
// response to discover which authorization server(s) protect the MCP resource.
//
// The base URL is derived from the request: the scheme is "https" when the
// connection is TLS or X-Forwarded-Proto equals "https", otherwise "http"; the
// host is the request's Host. Responds 404 when no per-user OAuth MCP clients
// are configured.
//
// GET /.well-known/oauth-protected-resource
func (h *OAuthMetadataHandler) handleProtectedResourceMetadata(ctx *fasthttp.RequestCtx) {
	if len(h.store.GetPerUserOAuthMCPClients()) == 0 {
		sendStringError(ctx, fasthttp.StatusNotFound, "Not Found")
		return
	}

	secure := ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https"
	scheme := "http"
	if secure {
		scheme = "https"
	}
	baseURL := fmt.Sprintf("%s://%s", scheme, string(ctx.Host()))

	metadata := map[string]interface{}{
		"resource":                 baseURL + "/mcp",
		"authorization_servers":    []string{baseURL},
		"scopes_supported":         []string{"mcp:read", "mcp:write"},
		"bearer_methods_supported": []string{"header"},
	}
	SendJSON(ctx, metadata)
}
// handleAuthorizationServerMetadata answers with the RFC 8414 Authorization
// Server Metadata document, which tells MCP clients where Bifrost's authorize,
// token and register endpoints live and which capabilities they support.
// When no per-user OAuth MCP client is configured the endpoint replies 404.
//
// GET /.well-known/oauth-authorization-server
func (h *OAuthMetadataHandler) handleAuthorizationServerMetadata(ctx *fasthttp.RequestCtx) {
	perUserClients := h.store.GetPerUserOAuthMCPClients()
	if len(perUserClients) == 0 {
		sendStringError(ctx, fasthttp.StatusNotFound, "Not Found")
		return
	}

	// https when the connection is TLS or the proxy header says so.
	secure := ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https"
	scheme := "http"
	if secure {
		scheme = "https"
	}
	issuer := fmt.Sprintf("%s://%s", scheme, string(ctx.Host()))
	perUserPrefix := issuer + "/api/oauth/per-user"

	document := map[string]interface{}{
		"issuer":                                issuer,
		"authorization_endpoint":                perUserPrefix + "/authorize",
		"token_endpoint":                        perUserPrefix + "/token",
		"registration_endpoint":                 perUserPrefix + "/register",
		"response_types_supported":              []string{"code"},
		"grant_types_supported":                 []string{"authorization_code"},
		"code_challenge_methods_supported":      []string{"S256"},
		"token_endpoint_auth_methods_supported": []string{"none"},
		"scopes_supported":                      []string{"mcp:read", "mcp:write"},
	}
	SendJSON(ctx, document)
}

View File

@@ -0,0 +1,577 @@
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
// This file implements Bifrost's OAuth 2.1 Authorization Server for per-user MCP
// authentication. It provides Dynamic Client Registration (RFC 7591), Authorization
// Code flow with PKCE, and token issuance. MCP clients (Claude Code, IDEs) use
// these endpoints to authenticate users before accessing Bifrost's /mcp endpoint.
package handlers
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"html"
"net/url"
"strings"
"time"
"github.com/fasthttp/router"
"github.com/google/uuid"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
"github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// PerUserOAuthHandler implements Bifrost's OAuth 2.1 Authorization Server.
// It handles dynamic client registration, authorization code issuance with PKCE,
// and token exchange for MCP per-user authentication.
type PerUserOAuthHandler struct {
// store is the shared configuration. Handlers read its ConfigStore for
// persistence (answering 503 when it is nil) and ask it for the per-user
// OAuth MCP clients.
store *lib.Config
}
// NewPerUserOAuthHandler returns a PerUserOAuthHandler backed by store.
func NewPerUserOAuthHandler(store *lib.Config) *PerUserOAuthHandler {
	handler := new(PerUserOAuthHandler)
	handler.store = store
	return handler
}
// RegisterRoutes mounts the per-user OAuth authorization server endpoints on
// r, wrapping each handler with the supplied middlewares.
// These routes do NOT go through auth middleware: they are the flow that
// unauthenticated clients use to obtain tokens in the first place.
func (h *PerUserOAuthHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	endpoints := []struct {
		register func(path string, handler fasthttp.RequestHandler)
		path     string
		handler  fasthttp.RequestHandler
	}{
		{r.POST, "/api/oauth/per-user/register", h.handleDynamicClientRegistration},
		{r.GET, "/api/oauth/per-user/authorize", h.handleAuthorize},
		{r.POST, "/api/oauth/per-user/token", h.handleToken},
		{r.GET, "/api/oauth/per-user/upstream/authorize", h.handleUpstreamAuthorize},
	}
	for _, endpoint := range endpoints {
		endpoint.register(endpoint.path, lib.ChainMiddlewares(endpoint.handler, middlewares...))
	}
}
// handleDynamicClientRegistration handles OAuth 2.0 Dynamic Client Registration
// per RFC 7591. MCP clients register themselves to obtain a client_id.
//
// The JSON request body must carry at least one redirect_uri. Omitted
// grant_types default to ["authorization_code"] and omitted response_types
// default to ["code"] (the RFC 7591 §2 defaults), so the response always
// carries arrays for both rather than JSON null.
//
// Responses: 201 with the registered metadata on success, 400 for a malformed
// body or missing redirect_uris, 404 when no per-user OAuth MCP client is
// configured, 503 when the config store is disabled, 500 on persistence or
// encoding failures.
//
// POST /api/oauth/per-user/register
func (h *PerUserOAuthHandler) handleDynamicClientRegistration(ctx *fasthttp.RequestCtx) {
	if h.store.ConfigStore == nil {
		SendError(ctx, fasthttp.StatusServiceUnavailable, "OAuth registration unavailable: config store is disabled")
		return
	}
	if len(h.store.GetPerUserOAuthMCPClients()) == 0 {
		sendStringError(ctx, fasthttp.StatusNotFound, "Not found")
		return
	}

	var req struct {
		ClientName              string   `json:"client_name"`
		RedirectURIs            []string `json:"redirect_uris"`
		GrantTypes              []string `json:"grant_types"`
		ResponseTypes           []string `json:"response_types"`
		TokenEndpointAuthMethod string   `json:"token_endpoint_auth_method"`
		Scope                   string   `json:"scope"`
	}
	if err := json.Unmarshal(ctx.PostBody(), &req); err != nil {
		SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid registration request: %v", err))
		return
	}
	if len(req.RedirectURIs) == 0 {
		SendError(ctx, fasthttp.StatusBadRequest, "redirect_uris is required")
		return
	}

	// Apply the RFC 7591 §2 defaults for metadata the client left out.
	grantTypes := req.GrantTypes
	if len(grantTypes) == 0 {
		grantTypes = []string{"authorization_code"}
	}
	responseTypes := req.ResponseTypes
	if len(responseTypes) == 0 {
		responseTypes = []string{"code"}
	}

	// The table stores the string arrays as JSON text.
	redirectURIsJSON, err := json.Marshal(req.RedirectURIs)
	if err != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to encode redirect_uris: %v", err))
		return
	}
	grantTypesJSON, err := json.Marshal(grantTypes)
	if err != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to encode grant_types: %v", err))
		return
	}

	// Generate client_id (distinct from the row's primary key).
	clientID := uuid.New().String()
	client := &tables.TablePerUserOAuthClient{
		ID:           uuid.New().String(),
		ClientID:     clientID,
		ClientName:   req.ClientName,
		RedirectURIs: string(redirectURIsJSON),
		GrantTypes:   string(grantTypesJSON),
	}
	if err := h.store.ConfigStore.CreatePerUserOAuthClient(ctx, client); err != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to register client: %v", err))
		return
	}

	// Return RFC 7591 response
	ctx.SetStatusCode(fasthttp.StatusCreated)
	SendJSON(ctx, map[string]interface{}{
		"client_id":                  clientID,
		"client_name":                req.ClientName,
		"redirect_uris":              req.RedirectURIs,
		"grant_types":                grantTypes,
		"response_types":             responseTypes,
		"token_endpoint_auth_method": "none",
	})
}
// handleAuthorize handles the OAuth 2.1 authorization endpoint.
// Instead of issuing a code immediately, it validates the request parameters,
// creates a PendingFlow record, and redirects the user to the consent screen.
// The code is only issued after the user completes the consent flow (VK + MCP auths).
//
// Validation, in order: response_type must be "code"; client_id and
// redirect_uri must be present; PKCE with code_challenge_method=S256 is
// mandatory; the client must exist and redirect_uri must exactly match one of
// its registered URIs. A stored redirect-URI list that cannot be decoded is
// reported as a 500 rather than being treated as "no URIs registered".
//
// On success the response sets the HttpOnly __bifrost_flow_secret cookie
// (whose SHA-256 hash is stored on the flow) and redirects with 302 to
// /oauth/consent?flow_id=<id>.
//
// GET /api/oauth/per-user/authorize?response_type=code&client_id=xxx&redirect_uri=xxx&code_challenge=xxx&code_challenge_method=S256[&state=xxx]
func (h *PerUserOAuthHandler) handleAuthorize(ctx *fasthttp.RequestCtx) {
	if h.store.ConfigStore == nil {
		SendError(ctx, fasthttp.StatusServiceUnavailable, "OAuth authorization unavailable: config store is disabled")
		return
	}
	if len(h.store.GetPerUserOAuthMCPClients()) == 0 {
		sendStringError(ctx, fasthttp.StatusNotFound, "Not found")
		return
	}

	// Extract parameters
	args := ctx.QueryArgs()
	responseType := string(args.Peek("response_type"))
	clientID := string(args.Peek("client_id"))
	redirectURI := string(args.Peek("redirect_uri"))
	state := string(args.Peek("state"))
	codeChallenge := string(args.Peek("code_challenge"))
	codeChallengeMethod := string(args.Peek("code_challenge_method"))

	// Validate required parameters
	if responseType != "code" {
		SendError(ctx, fasthttp.StatusBadRequest, "response_type must be 'code'")
		return
	}
	if clientID == "" || redirectURI == "" {
		SendError(ctx, fasthttp.StatusBadRequest, "client_id and redirect_uri are required")
		return
	}
	if codeChallenge == "" || codeChallengeMethod != "S256" {
		SendError(ctx, fasthttp.StatusBadRequest, "PKCE is required: code_challenge and code_challenge_method=S256")
		return
	}

	// Validate client exists and redirect_uri is registered
	client, err := h.store.ConfigStore.GetPerUserOAuthClientByClientID(ctx, clientID)
	if err != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to validate client: %v", err))
		return
	}
	if client == nil {
		SendError(ctx, fasthttp.StatusBadRequest, "Unknown client_id")
		return
	}

	// RedirectURIs is stored as a JSON array. A decode failure means the stored
	// record is corrupt, which is a server-side problem, not a bad request.
	var allowedURIs []string
	if client.RedirectURIs != "" {
		if err := json.Unmarshal([]byte(client.RedirectURIs), &allowedURIs); err != nil {
			SendError(ctx, fasthttp.StatusInternalServerError, "Failed to parse registered redirect URIs")
			return
		}
	}
	uriAllowed := false
	for _, allowed := range allowedURIs {
		if allowed == redirectURI {
			uriAllowed = true
			break
		}
	}
	if !uriAllowed {
		SendError(ctx, fasthttp.StatusBadRequest, "redirect_uri not registered for this client")
		return
	}

	// Generate a browser-binding secret so only the initiating browser can resume this flow.
	browserSecret, err := generateOpaqueToken(32)
	if err != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate browser secret")
		return
	}
	// Only the hash is persisted; the raw secret lives in the cookie.
	browserSecretHash := fmt.Sprintf("%x", sha256.Sum256([]byte(browserSecret)))

	// Create a PendingFlow to carry OAuth params through the consent screen.
	flow := &tables.TablePerUserOAuthPendingFlow{
		ID:                uuid.New().String(),
		ClientID:          clientID,
		RedirectURI:       redirectURI,
		CodeChallenge:     codeChallenge,
		State:             state,
		BrowserSecretHash: browserSecretHash,
		ExpiresAt:         time.Now().Add(15 * time.Minute),
	}
	if err := h.store.ConfigStore.CreatePerUserOAuthPendingFlow(ctx, flow); err != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create pending flow: %v", err))
		return
	}
	logger.Debug("[oauth/authorize] PendingFlow created: flow_id=%s client_id=%s", flow.ID, clientID)

	// Set HttpOnly cookie binding this flow to the current browser.
	var cookie fasthttp.Cookie
	cookie.SetKey("__bifrost_flow_secret")
	cookie.SetValue(browserSecret)
	cookie.SetPath("/")
	cookie.SetHTTPOnly(true)
	cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
	isSecure := ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https"
	cookie.SetSecure(isSecure)
	cookie.SetMaxAge(15 * 60) // 15 minutes, matching flow TTL
	ctx.Response.Header.SetCookie(&cookie)

	// Redirect to consent screen with flow_id (relative path — stays on current origin).
	consentURL := fmt.Sprintf("/oauth/consent?flow_id=%s", url.QueryEscape(flow.ID))
	ctx.Redirect(consentURL, fasthttp.StatusFound)
}
// handleToken handles the OAuth 2.1 token endpoint for the authorization_code
// grant. It atomically claims the authorization code, checks expiry, client_id
// and redirect_uri (each only when supplied), verifies the PKCE code_verifier
// against the stored S256 challenge, and returns a Bearer access token.
//
// When the code is linked to a consent-flow session, that session's access
// token is returned; otherwise (legacy path) a new 24-hour session is created.
//
// Error responses use the RFC 6749 §5.2 JSON shape via sendOAuthError. The
// success response is marked "Cache-Control: no-store" / "Pragma: no-cache"
// as RFC 6749 §5.1 requires for any response that carries tokens.
//
// POST /api/oauth/per-user/token
func (h *PerUserOAuthHandler) handleToken(ctx *fasthttp.RequestCtx) {
	if h.store.ConfigStore == nil {
		SendError(ctx, fasthttp.StatusServiceUnavailable, "OAuth token endpoint unavailable: config store is disabled")
		return
	}
	if len(h.store.GetPerUserOAuthMCPClients()) == 0 {
		sendStringError(ctx, fasthttp.StatusNotFound, "Not found")
		return
	}

	// Parse form-encoded body
	grantType := string(ctx.FormValue("grant_type"))
	code := string(ctx.FormValue("code"))
	redirectURI := string(ctx.FormValue("redirect_uri"))
	clientID := string(ctx.FormValue("client_id"))
	codeVerifier := string(ctx.FormValue("code_verifier"))

	if grantType != "authorization_code" {
		sendOAuthError(ctx, fasthttp.StatusBadRequest, "unsupported_grant_type", "Only authorization_code grant is supported")
		return
	}
	if code == "" || codeVerifier == "" {
		sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_request", "code and code_verifier are required")
		return
	}

	// Atomically claim authorization code (prevents concurrent redemption)
	codeRecord, err := h.store.ConfigStore.ClaimPerUserOAuthCode(ctx, code)
	if err != nil {
		sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to validate code")
		return
	}
	if codeRecord == nil {
		sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "Invalid or already used authorization code")
		return
	}

	// Validate code is not expired
	if time.Now().After(codeRecord.ExpiresAt) {
		sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "Authorization code expired")
		return
	}

	// Validate client_id if provided — some public clients omit it (RFC 6749 §4.1.3 allows
	// omitting client_id when the client is not authenticating with the server).
	// The code record already binds the code to the correct client, so this is safe.
	if clientID != "" && codeRecord.ClientID != clientID {
		logger.Debug("[oauth/token] client_id mismatch: code_client=%s request_client=%s", codeRecord.ClientID, clientID)
		sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "client_id mismatch")
		return
	}
	// Use the client_id from the code record as the authoritative value.
	clientID = codeRecord.ClientID

	// Validate redirect_uri matches
	if redirectURI != "" && codeRecord.RedirectURI != redirectURI {
		logger.Debug("[oauth/token] redirect_uri mismatch: code=%s request=%s", codeRecord.RedirectURI, redirectURI)
		sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "redirect_uri mismatch")
		return
	}

	// Validate PKCE: SHA256(code_verifier) must match code_challenge
	verifierHash := sha256.Sum256([]byte(codeVerifier))
	computedChallenge := base64.RawURLEncoding.EncodeToString(verifierHash[:])
	if computedChallenge != codeRecord.CodeChallenge {
		logger.Debug("[oauth/token] PKCE verification failed")
		sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "PKCE verification failed")
		return
	}

	// If the code was issued by the consent flow (handleSubmit), the session already exists
	// with the upstream tokens transferred to it. Reuse that session's access token so the
	// client receives the token that the upstream (Notion, GitHub, etc.) tokens are linked to.
	var accessToken string
	var expiresAt time.Time
	if codeRecord.SessionID != "" {
		existingSession, err := h.store.ConfigStore.GetPerUserOAuthSessionByID(ctx, codeRecord.SessionID)
		if err != nil {
			logger.Info("[oauth/token] Failed to load existing session: session_id=%s err=%v", codeRecord.SessionID, err)
			sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to load session")
			return
		}
		if existingSession == nil {
			logger.Info("[oauth/token] Existing session not found: session_id=%s", codeRecord.SessionID)
			sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Session not found")
			return
		}
		if !existingSession.ExpiresAt.After(time.Now()) {
			sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "Session expired")
			return
		}
		accessToken = existingSession.AccessToken
		expiresAt = existingSession.ExpiresAt
		logger.Debug("[oauth/token] reusing consent session: session_id=%s", existingSession.ID)
	} else {
		// Fallback: no linked session (legacy path) — create a new one.
		var newAccessToken, newRefreshToken string
		newAccessToken, err = generateOpaqueToken(32)
		if err != nil {
			sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to generate access token")
			return
		}
		newRefreshToken, err = generateOpaqueToken(32)
		if err != nil {
			sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to generate refresh token")
			return
		}
		expiresAt = time.Now().Add(24 * time.Hour)
		newSession := &tables.TablePerUserOAuthSession{
			ID:           uuid.New().String(),
			AccessToken:  newAccessToken,
			RefreshToken: newRefreshToken,
			ClientID:     clientID,
			ExpiresAt:    expiresAt,
		}
		if err := h.store.ConfigStore.CreatePerUserOAuthSession(ctx, newSession); err != nil {
			sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to create session")
			return
		}
		accessToken = newAccessToken
		logger.Debug("[oauth/token] created new session (legacy path): session_id=%s", newSession.ID)
	}

	// RFC 6749 §5.1: a response carrying tokens must not be cached by the
	// client or any intermediary.
	ctx.Response.Header.Set("Cache-Control", "no-store")
	ctx.Response.Header.Set("Pragma", "no-cache")

	// Return OAuth token response
	ctx.SetContentType("application/json")
	ctx.SetStatusCode(fasthttp.StatusOK)
	SendJSON(ctx, map[string]interface{}{
		"access_token": accessToken,
		"token_type":   "Bearer",
		"expires_in":   int(time.Until(expiresAt).Seconds()),
		"scope":        codeRecord.Scopes,
	})
}
// sendOAuthError writes an RFC 6749 Section 5.2 error response: a JSON object
// with "error" and "error_description" members, sent with the given HTTP
// status code and an application/json content type.
func sendOAuthError(ctx *fasthttp.RequestCtx, statusCode int, errorCode, description string) {
	payload := struct {
		Error            string `json:"error"`
		ErrorDescription string `json:"error_description"`
	}{
		Error:            errorCode,
		ErrorDescription: description,
	}
	// Encoding a struct of two strings cannot fail, so the error is discarded.
	encoded, _ := json.Marshal(payload)

	ctx.SetStatusCode(statusCode)
	ctx.SetContentType("application/json")
	ctx.SetBody(encoded)
}
// sendStringError writes message as a text/plain response body with the given
// HTTP status code.
func sendStringError(ctx *fasthttp.RequestCtx, statusCode int, message string) {
	resp := &ctx.Response
	resp.SetStatusCode(statusCode)
	resp.Header.SetContentType("text/plain")
	resp.SetBodyString(message)
}
// validateFlowBrowserSecret reports whether the request carries a
// __bifrost_flow_secret cookie whose SHA-256 hex digest equals the hash stored
// on the pending flow. A flow with no stored hash is accepted unconditionally
// (legacy flows created before browser binding existed).
func validateFlowBrowserSecret(ctx *fasthttp.RequestCtx, flow *tables.TablePerUserOAuthPendingFlow) bool {
	expectedHash := flow.BrowserSecretHash
	if expectedHash == "" {
		// Legacy flow without browser binding — allow for backwards compatibility.
		return true
	}

	cookieValue := ctx.Request.Header.Cookie("__bifrost_flow_secret")
	if len(cookieValue) == 0 {
		return false
	}

	digest := sha256.Sum256(cookieValue)
	return fmt.Sprintf("%x", digest) == expectedHash
}
// generateOpaqueToken returns length bytes read from crypto/rand, encoded as
// unpadded URL-safe base64. It returns an empty string and the underlying
// error if the random source cannot be read.
func generateOpaqueToken(length int) (string, error) {
	raw := make([]byte, length)
	_, err := rand.Read(raw)
	if err != nil {
		return "", err
	}
	token := base64.RawURLEncoding.EncodeToString(raw)
	return token, nil
}
// handleUpstreamAuthorize handles the upstream OAuth proxy for per-user OAuth.
// When a user needs to authenticate with an upstream MCP server (e.g., Notion),
// this endpoint redirects them to the upstream provider's OAuth authorize URL.
// After the user authenticates, the callback stores their upstream token linked
// to either their Bifrost session (runtime flow) or a PendingFlow (consent flow).
//
// Steps: resolve the caller's identity from flow_id (consent flow, which also
// requires the browser-binding cookie) or session (runtime flow); load the MCP
// client and its template OAuth config; validate the upstream authorize URL;
// generate a PKCE verifier and state; persist a pending upstream session keyed
// by that state; then redirect (302) to the upstream authorize URL.
//
// The upstream authorize URL is validated BEFORE the pending session is
// persisted, so a misconfigured template does not leave an orphaned session
// row behind. It must be absolute (scheme and host present): a relative URL
// would redirect the browser back to Bifrost instead of the provider.
//
// Runtime flow: GET /api/oauth/per-user/upstream/authorize?mcp_client_id=xxx&session=xxx
// Consent flow: GET /api/oauth/per-user/upstream/authorize?mcp_client_id=xxx&flow_id=xxx
func (h *PerUserOAuthHandler) handleUpstreamAuthorize(ctx *fasthttp.RequestCtx) {
	if h.store.ConfigStore == nil {
		SendError(ctx, fasthttp.StatusServiceUnavailable, "OAuth upstream authorization unavailable: config store is disabled")
		return
	}

	mcpClientID := string(ctx.QueryArgs().Peek("mcp_client_id"))
	sessionID := string(ctx.QueryArgs().Peek("session"))
	flowID := string(ctx.QueryArgs().Peek("flow_id"))
	if mcpClientID == "" || (sessionID == "" && flowID == "") {
		SendError(ctx, fasthttp.StatusBadRequest, "mcp_client_id and either session or flow_id are required")
		return
	}

	// Resolve identity depending on whether this is a runtime session or a consent flow.
	var virtualKeyID, userID, proxySessionToken, gatewaySessionID string
	if flowID != "" {
		// Consent flow: use the pending flow for identity and proxy token.
		flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
		if err != nil || flow == nil || time.Now().After(flow.ExpiresAt) {
			SendError(ctx, fasthttp.StatusUnauthorized, "Invalid or expired consent flow")
			return
		}
		if !validateFlowBrowserSecret(ctx, flow) {
			SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session")
			return
		}
		if strVal(flow.VirtualKeyID) != "" {
			virtualKeyID = *flow.VirtualKeyID
		}
		if strVal(flow.UserID) != "" {
			userID = *flow.UserID
		}
		// Use a prefixed flow token so the callback can detect the consent path.
		// Include mcpClientID to avoid unique constraint violations when multiple
		// MCP services are connected in the same consent flow.
		proxySessionToken = "flow:" + flowID + ":" + mcpClientID
		gatewaySessionID = flowID
	} else {
		// Runtime flow: validate the existing Bifrost session.
		bifrostSession, err := h.store.ConfigStore.GetPerUserOAuthSessionByID(ctx, sessionID)
		if err != nil || bifrostSession == nil {
			SendError(ctx, fasthttp.StatusUnauthorized, "Invalid or expired session")
			return
		}
		if !bifrostSession.ExpiresAt.After(time.Now()) {
			SendError(ctx, fasthttp.StatusUnauthorized, "Invalid or expired session")
			return
		}
		virtualKeyID = strVal(bifrostSession.VirtualKeyID)
		userID = strVal(bifrostSession.UserID)
		proxySessionToken = "runtime:" + sessionID + ":" + mcpClientID
		gatewaySessionID = sessionID
	}

	// Look up the MCP client config to get the template OAuth config.
	mcpClient, err := h.store.ConfigStore.GetMCPClientByID(ctx, mcpClientID)
	if err != nil || mcpClient == nil {
		SendError(ctx, fasthttp.StatusNotFound, "MCP client not found")
		return
	}
	if mcpClient.AuthType != string(schemas.MCPAuthTypePerUserOauth) {
		SendError(ctx, fasthttp.StatusBadRequest, "MCP client does not use per-user OAuth")
		return
	}
	if mcpClient.OauthConfigID == nil || *mcpClient.OauthConfigID == "" {
		SendError(ctx, fasthttp.StatusBadRequest, "MCP client has no OAuth configuration")
		return
	}

	// Load template OAuth config (has upstream authorize_url, client_id, etc.)
	templateConfig, err := h.store.ConfigStore.GetOauthConfigByID(ctx, *mcpClient.OauthConfigID)
	if err != nil || templateConfig == nil {
		SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load OAuth template config")
		return
	}

	// Validate the upstream authorize URL up front, before anything is persisted.
	authorizeURL, err := url.Parse(templateConfig.AuthorizeURL)
	if err != nil || authorizeURL.Scheme == "" || authorizeURL.Host == "" {
		SendError(ctx, fasthttp.StatusInternalServerError, "Invalid upstream authorize URL")
		return
	}

	// Generate PKCE challenge for upstream.
	codeVerifier, err := generateOpaqueToken(32)
	if err != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate PKCE verifier")
		return
	}
	verifierHash := sha256.Sum256([]byte(codeVerifier))
	codeChallenge := base64.RawURLEncoding.EncodeToString(verifierHash[:])

	// Generate state for upstream.
	state, err := generateOpaqueToken(32)
	if err != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate state token")
		return
	}

	// Build redirect URI (Bifrost's callback endpoint).
	scheme := "http"
	if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" {
		scheme = "https"
	}
	host := string(ctx.Host())
	redirectURI := fmt.Sprintf("%s://%s/api/oauth/callback", scheme, host)

	// Empty identities are stored as nil pointers rather than empty strings.
	var vkId *string
	if virtualKeyID != "" {
		vkId = &virtualKeyID
	}
	var uid *string
	if userID != "" {
		uid = &userID
	}

	// Store upstream OAuth session linking state → MCP client + identity.
	upstreamSession := &tables.TableOauthUserSession{
		ID:               uuid.New().String(),
		MCPClientID:      mcpClientID,
		OauthConfigID:    *mcpClient.OauthConfigID,
		State:            state,
		CodeVerifier:     codeVerifier,
		SessionToken:     proxySessionToken, // "runtime:xxx" for runtime flow; "flow:xxx" for consent flow
		GatewaySessionID: gatewaySessionID,
		VirtualKeyID:     vkId,
		UserID:           uid,
		Status:           "pending",
		ExpiresAt:        time.Now().Add(15 * time.Minute),
	}
	logger.Debug("[oauth/upstream-authorize] creating upstream session: mcp_client=%s flow=%s", mcpClientID, proxySessionToken)
	if err := h.store.ConfigStore.CreateOauthUserSession(ctx, upstreamSession); err != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create upstream OAuth session: %v", err))
		return
	}

	// Parse scopes from template config. This stays best-effort: on a decode
	// failure the request is sent without a scope parameter, but the failure
	// is logged instead of being dropped silently.
	var scopes []string
	if templateConfig.Scopes != "" {
		if err := json.Unmarshal([]byte(templateConfig.Scopes), &scopes); err != nil {
			logger.Debug("[oauth/upstream-authorize] failed to parse template scopes: mcp_client=%s err=%v", mcpClientID, err)
		}
	}

	// Build upstream authorize URL with PKCE. Parameters already present on the
	// configured URL are kept unless one of ours has the same key.
	query := authorizeURL.Query()
	query.Set("response_type", "code")
	query.Set("client_id", templateConfig.ClientID)
	query.Set("redirect_uri", redirectURI)
	query.Set("state", state)
	query.Set("code_challenge", codeChallenge)
	query.Set("code_challenge_method", "S256")
	if len(scopes) > 0 {
		query.Set("scope", strings.Join(scopes, " "))
	}
	authorizeURL.RawQuery = query.Encode()

	ctx.Redirect(authorizeURL.String(), fasthttp.StatusFound)
}
// Blank-identifier references that keep the html and configstore imports in
// use. Go rejects an imported package that nothing references, so these two
// declarations exist only to satisfy the compiler.
// NOTE(review): if nothing else in this file needs html or configstore,
// removing the imports together with these lines would be cleaner — confirm.
var _ = html.EscapeString
var _ configstore.ConfigStore

View File

@@ -0,0 +1,491 @@
package handlers
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/fasthttp/router"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/framework/plugins"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// PluginsLoader is the runtime plugin manager that PluginsHandler depends on.
// It is declared here, at the consumer, so the handler can be wired to any
// implementation.
type PluginsLoader interface {
// ReloadPlugin loads (or reloads) the named plugin with the given path,
// config, placement and order. createPlugin calls it after the database row
// has been written, and only for enabled plugins.
ReloadPlugin(ctx context.Context, name string, path *string, pluginConfig any, placement *schemas.PluginPlacement, order *int) error
// RemovePlugin presumably unloads the named plugin from the running
// process — confirm against the implementation.
RemovePlugin(ctx context.Context, name string) error
// GetPluginStatus returns the current status of the known plugins as a map.
// The handlers match entries by the Name field of the status value.
GetPluginStatus(ctx context.Context) map[string]schemas.PluginStatus
}
// PluginsHandler is the handler for the plugins API
type PluginsHandler struct {
// configStore persists plugin definitions. It may be nil: the read handlers
// then fall back to the loader's in-memory status and createPlugin rejects
// the request.
configStore configstore.ConfigStore
// pluginsLoader manages the plugins loaded in the running process.
pluginsLoader PluginsLoader
}
// NewPluginsHandler returns a PluginsHandler wired to the given loader and
// config store.
func NewPluginsHandler(pluginsLoader PluginsLoader, configStore configstore.ConfigStore) *PluginsHandler {
	handler := new(PluginsHandler)
	handler.configStore = configStore
	handler.pluginsLoader = pluginsLoader
	return handler
}
// CreatePluginRequest is the request body for creating a plugin
type CreatePluginRequest struct {
// Name identifies the plugin; createPlugin rejects an empty name.
Name string `json:"name"`
// Enabled controls whether createPlugin loads the plugin after storing it.
Enabled bool `json:"enabled"`
// Config is the plugin configuration, stored and passed to the loader as-is.
Config map[string]any `json:"config"`
// Path is the plugin path. createPlugin normalizes an empty string to nil
// and clears it for built-in plugins.
Path *string `json:"path"`
// Placement, when set and non-empty, must be the pre-builtin or post-builtin
// placement value; an empty string is normalized to nil.
Placement *schemas.PluginPlacement `json:"placement,omitempty"`
// Order is stored and passed to the loader unchanged.
Order *int `json:"order,omitempty"`
}
// UpdatePluginRequest is the request body for updating a plugin.
// It has the same fields as CreatePluginRequest except Name; the plugin to
// update is named by the {name} route parameter of PUT /api/plugins/{name}.
type UpdatePluginRequest struct {
Enabled bool `json:"enabled"`
Path *string `json:"path"`
Config map[string]any `json:"config"`
Placement *schemas.PluginPlacement `json:"placement,omitempty"`
Order *int `json:"order,omitempty"`
}
// RegisterRoutes mounts the plugin CRUD endpoints on r, wrapping every handler
// with the supplied middlewares.
func (h *PluginsHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	endpoints := []struct {
		register func(path string, handler fasthttp.RequestHandler)
		path     string
		handler  fasthttp.RequestHandler
	}{
		{r.GET, "/api/plugins", h.getPlugins},
		{r.GET, "/api/plugins/{name}", h.getPlugin},
		{r.POST, "/api/plugins", h.createPlugin},
		{r.PUT, "/api/plugins/{name}", h.updatePlugin},
		{r.DELETE, "/api/plugins/{name}", h.deletePlugin},
	}
	for _, endpoint := range endpoints {
		endpoint.register(endpoint.path, lib.ChainMiddlewares(endpoint.handler, middlewares...))
	}
}
// PluginResponse is the JSON shape the plugins API returns for one plugin:
// its definition combined with its status.
type PluginResponse struct {
// Name is the stored plugin name when a config store is in use; without a
// config store it is the Name field of the loader's status entry.
Name string `json:"name"`
// ActualName is the Name of the resolved status when a config store is in
// use; without a config store it is the key of the loader's status map.
ActualName string `json:"actualName"`
Enabled bool `json:"enabled"`
Config any `json:"config"`
IsCustom bool `json:"isCustom"`
Path *string `json:"path"`
Placement *schemas.PluginPlacement `json:"placement,omitempty"`
Order *int `json:"order,omitempty"`
// Status is the status matched by name from the loader, or a synthesized
// uninitialized/disabled status when the loader reports none.
Status schemas.PluginStatus `json:"status"`
}
// buildPluginResponse assembles the API representation of a stored plugin.
// A disabled plugin is reported with the disabled status without consulting
// the loader. An enabled plugin takes the loader's status entry whose Name
// matches, and stays "uninitialized" (with empty logs) when there is none.
func (h *PluginsHandler) buildPluginResponse(ctx context.Context, plugin *configstoreTables.TablePlugin) PluginResponse {
	resolved := schemas.PluginStatus{
		Name:   plugin.Name,
		Status: schemas.PluginStatusUninitialized,
		Logs:   []string{},
	}

	if plugin.Enabled {
		for _, candidate := range h.pluginsLoader.GetPluginStatus(ctx) {
			if candidate.Name != plugin.Name {
				continue
			}
			resolved = candidate
			break
		}
	} else {
		resolved.Status = schemas.PluginStatusDisabled
	}

	response := PluginResponse{
		Name:       plugin.Name,
		ActualName: resolved.Name,
		Enabled:    plugin.Enabled,
		Config:     plugin.Config,
		IsCustom:   plugin.IsCustom,
		Path:       plugin.Path,
		Placement:  plugin.Placement,
		Order:      plugin.Order,
		Status:     resolved,
	}
	return response
}
// getPlugins lists every plugin as {"plugins": [...], "count": n}.
//
// Without a config store the list is derived purely from the loader's status
// map, and every entry is reported as an enabled custom plugin with an empty
// config. With a config store the stored plugins are listed, each paired with
// the loader status whose Name matches, or with a synthesized
// uninitialized/disabled status when the loader has none.
func (h *PluginsHandler) getPlugins(ctx *fasthttp.RequestCtx) {
	if h.configStore == nil {
		inMemory := []PluginResponse{}
		for key, live := range h.pluginsLoader.GetPluginStatus(ctx) {
			entry := PluginResponse{
				Name:       live.Name,
				ActualName: key,
				Enabled:    true,
				Config:     map[string]any{},
				IsCustom:   true,
				Path:       nil,
				Status:     live,
			}
			inMemory = append(inMemory, entry)
		}
		SendJSON(ctx, map[string]any{
			"plugins": inMemory,
			"count":   len(inMemory),
		})
		return
	}

	stored, err := h.configStore.GetPlugins(ctx)
	if err != nil {
		logger.Error("failed to get plugins: %v", err)
		SendError(ctx, 500, "Failed to retrieve plugins")
		return
	}

	// Fetch the live statuses once and match them against each stored plugin.
	liveStatuses := h.pluginsLoader.GetPluginStatus(ctx)
	listed := []PluginResponse{}
	for _, plugin := range stored {
		// Default status, used when the loader reports nothing for this plugin.
		initial := schemas.PluginStatusUninitialized
		if !plugin.Enabled {
			initial = schemas.PluginStatusDisabled
		}
		resolved := schemas.PluginStatus{
			Name:   plugin.Name,
			Status: initial,
			Logs:   []string{},
		}
		for _, live := range liveStatuses {
			if live.Name != plugin.Name {
				continue
			}
			resolved = live
			break
		}

		listed = append(listed, PluginResponse{
			Name:       plugin.Name,
			ActualName: resolved.Name,
			Enabled:    plugin.Enabled,
			Config:     plugin.Config,
			IsCustom:   plugin.IsCustom,
			Path:       plugin.Path,
			Placement:  plugin.Placement,
			Order:      plugin.Order,
			Status:     resolved,
		})
	}

	SendJSON(ctx, map[string]any{
		"plugins": listed,
		"count":   len(listed),
	})
}
// getPlugin returns a single plugin identified by the {name} route parameter.
//
// The name is validated first on every path: a missing, non-string or empty
// value yields 400. Without a config store the plugin is looked up in the
// loader's status map (matching on the status Name) and returned as a
// PluginResponse. With a config store the stored plugin record is returned.
// In both modes an unknown name yields 404 "Plugin not found"; other store
// failures yield 500.
func (h *PluginsHandler) getPlugin(ctx *fasthttp.RequestCtx) {
	// Safely validate the "name" parameter
	nameValue := ctx.UserValue("name")
	if nameValue == nil {
		logger.Warn("missing required 'name' parameter in request")
		SendError(ctx, 400, "Missing required 'name' parameter")
		return
	}
	name, ok := nameValue.(string)
	if !ok {
		logger.Warn("invalid 'name' parameter type, expected string but got %T", nameValue)
		SendError(ctx, 400, "Invalid 'name' parameter type, expected string")
		return
	}
	if name == "" {
		logger.Warn("empty 'name' parameter provided")
		SendError(ctx, 400, "Empty 'name' parameter not allowed")
		return
	}

	if h.configStore == nil {
		// No persistence: answer from the loader's in-memory status only.
		for key, status := range h.pluginsLoader.GetPluginStatus(ctx) {
			if status.Name != name {
				continue
			}
			SendJSON(ctx, PluginResponse{
				Name:       status.Name,
				ActualName: key,
				Enabled:    true,
				Config:     map[string]any{},
				IsCustom:   true,
				Path:       nil,
				Status:     status,
			})
			return
		}
		// Report an unknown name as 404, matching the config-store path,
		// rather than a 200 with a zero-value body.
		SendError(ctx, fasthttp.StatusNotFound, "Plugin not found")
		return
	}

	plugin, err := h.configStore.GetPlugin(ctx, name)
	if err != nil {
		if errors.Is(err, configstore.ErrNotFound) {
			SendError(ctx, fasthttp.StatusNotFound, "Plugin not found")
			return
		}
		logger.Error("failed to get plugin: %v", err)
		SendError(ctx, 500, "Failed to retrieve plugin")
		return
	}
	SendJSON(ctx, plugin)
}
// createPlugin creates a new plugin from a CreatePluginRequest JSON body.
//
// Flow: validate the request (name required; placement must be empty,
// pre_builtin or post_builtin); normalize an empty placement/path to nil;
// reject duplicates with 409; persist the plugin; then, if it is enabled,
// load it through the plugins loader. If loading fails the database row is
// deleted again (best effort) and 500 is returned.
//
// A failure of the duplicate check other than "not found" is reported as 500
// instead of being ignored, so a store outage is not mistaken for "plugin does
// not exist yet".
//
// Responses: 201 with the created plugin, 400 for invalid input or when the
// config store is disabled, 409 when the name is taken, 500 on store or load
// failures.
func (h *PluginsHandler) createPlugin(ctx *fasthttp.RequestCtx) {
	if h.configStore == nil {
		SendError(ctx, 400, "Plugins creation is not supported when configstore is disabled")
		return
	}

	var request CreatePluginRequest
	if err := json.Unmarshal(ctx.PostBody(), &request); err != nil {
		logger.Error("failed to unmarshal create plugin request: %v", err)
		SendError(ctx, 400, "Invalid request body")
		return
	}

	// Validate required fields
	if request.Name == "" {
		SendError(ctx, fasthttp.StatusBadRequest, "Plugin name is required")
		return
	}

	// Validate placement value
	if request.Placement != nil && *request.Placement != "" &&
		*request.Placement != schemas.PluginPlacementPreBuiltin &&
		*request.Placement != schemas.PluginPlacementPostBuiltin {
		SendError(ctx, fasthttp.StatusBadRequest, "Invalid placement value. Must be 'pre_builtin' or 'post_builtin'")
		return
	}
	if request.Placement != nil && *request.Placement == "" {
		request.Placement = nil
	}

	// Normalize empty path to nil (treat empty string as built-in plugin)
	if request.Path != nil && *request.Path == "" {
		request.Path = nil
	}

	// Check if plugin already exists. "Not found" is the expected outcome;
	// any other lookup error means the store could not answer.
	existingPlugin, err := h.configStore.GetPlugin(ctx, request.Name)
	if err != nil && !errors.Is(err, configstore.ErrNotFound) {
		logger.Error("failed to check for existing plugin: %v", err)
		SendError(ctx, fasthttp.StatusInternalServerError, "Failed to check for existing plugin")
		return
	}
	if err == nil && existingPlugin != nil {
		SendError(ctx, fasthttp.StatusConflict, "Plugin already exists")
		return
	}

	// Determine if this is a built-in or custom plugin
	isBuiltin := lib.IsBuiltinPlugin(request.Name)
	// Built-in plugins should not have a path
	if isBuiltin && request.Path != nil {
		request.Path = nil
	}

	// Create DB entry first to avoid orphaned in-memory state if DB write fails
	if err := h.configStore.CreatePlugin(ctx, &configstoreTables.TablePlugin{
		Name:      request.Name,
		Enabled:   request.Enabled,
		Config:    request.Config,
		Path:      request.Path,
		IsCustom:  !isBuiltin,
		Placement: request.Placement,
		Order:     request.Order,
	}); err != nil {
		logger.Error("failed to create plugin: %v", err)
		SendError(ctx, 500, "Failed to create plugin")
		return
	}

	// Reload the plugin into memory if it's enabled
	if request.Enabled {
		if err := h.pluginsLoader.ReloadPlugin(ctx, request.Name, request.Path, request.Config, request.Placement, request.Order); err != nil {
			logger.Error("failed to load plugin: %v", err)
			// Roll back the row so the store does not keep a plugin that never loaded.
			if rbErr := h.configStore.DeletePlugin(ctx, request.Name); rbErr != nil {
				logger.Error("failed to rollback plugin creation: %v", rbErr)
			}
			SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Plugin created in database but failed to load: %v", err))
			return
		}
	}

	// Re-read the stored record so the response reflects what was persisted.
	plugin, err := h.configStore.GetPlugin(ctx, request.Name)
	if err != nil {
		logger.Error("failed to get plugin: %v", err)
		SendError(ctx, 500, "Failed to retrieve plugin")
		return
	}

	ctx.SetStatusCode(fasthttp.StatusCreated)
	SendJSON(ctx, map[string]any{
		"message": "Plugin created successfully",
		"plugin":  h.buildPluginResponse(ctx, plugin),
	})
}
// updatePlugin handles requests that update the plugin named by the "name"
// route parameter.
//
// The request body is decoded and validated before the config store is
// touched, so a malformed request never leaves a half-created row behind. If
// the plugin does not exist yet, a disabled placeholder row is created and
// then updated with the requested values. After the store is updated the
// plugin is reloaded when enabled, or removed from the loader when disabled.
//
// Responses: 200 with the stored plugin on success; 400 when the config store
// is disabled or the name, body or placement is invalid; 404 when the plugin
// cannot be read back after the update; 500 on config store or loader
// failures.
func (h *PluginsHandler) updatePlugin(ctx *fasthttp.RequestCtx) {
	if h.configStore == nil {
		SendError(ctx, 400, "Plugins update is not supported when configstore is disabled")
		return
	}
	// Safely validate the "name" parameter
	nameValue := ctx.UserValue("name")
	if nameValue == nil {
		logger.Warn("missing required 'name' parameter in update plugin request")
		SendError(ctx, 400, "Missing required 'name' parameter")
		return
	}
	name, ok := nameValue.(string)
	if !ok {
		logger.Warn("invalid 'name' parameter type in update plugin request, expected string but got %T", nameValue)
		SendError(ctx, 400, "Invalid 'name' parameter type, expected string")
		return
	}
	if name == "" {
		logger.Warn("empty 'name' parameter provided in update plugin request")
		SendError(ctx, 400, "Empty 'name' parameter not allowed")
		return
	}
	// Unmarshalling the request body. This happens before any store write so
	// that an invalid request has no side effects.
	var request UpdatePluginRequest
	if err := json.Unmarshal(ctx.PostBody(), &request); err != nil {
		logger.Error("failed to unmarshal update plugin request: %v", err)
		SendError(ctx, 400, "Invalid request body")
		return
	}
	// Validate placement value
	if request.Placement != nil && *request.Placement != "" &&
		*request.Placement != schemas.PluginPlacementPreBuiltin &&
		*request.Placement != schemas.PluginPlacementPostBuiltin {
		SendError(ctx, fasthttp.StatusBadRequest, "Invalid placement value. Must be 'pre_builtin' or 'post_builtin'")
		return
	}
	// An empty placement string is treated the same as an omitted placement.
	if request.Placement != nil && *request.Placement == "" {
		request.Placement = nil
	}
	// Normalize empty path to nil (treat empty string as built-in plugin)
	if request.Path != nil && *request.Path == "" {
		request.Path = nil
	}
	// Check if plugin exists; if it doesn't, create a disabled placeholder
	// row so the update below has something to act on.
	if _, err := h.configStore.GetPlugin(ctx, name); err != nil {
		if !errors.Is(err, configstore.ErrNotFound) {
			logger.Error("failed to get plugin: %v", err)
			SendError(ctx, 500, "Failed to update plugin")
			return
		}
		placeholder := &configstoreTables.TablePlugin{
			Name:     name,
			Enabled:  false,
			Config:   map[string]any{},
			Path:     nil,
			IsCustom: false,
		}
		if err := h.configStore.CreatePlugin(ctx, placeholder); err != nil {
			logger.Error("failed to create plugin: %v", err)
			SendError(ctx, 500, "Failed to create plugin")
			return
		}
	}
	// Determine if this is a built-in plugin
	isBuiltin := lib.IsBuiltinPlugin(name)
	// Built-in plugins should not have a path
	if isBuiltin && request.Path != nil {
		request.Path = nil
	}
	// Updating the plugin
	if err := h.configStore.UpdatePlugin(ctx, &configstoreTables.TablePlugin{
		Name:      name,
		Enabled:   request.Enabled,
		Config:    request.Config,
		Path:      request.Path,
		IsCustom:  !isBuiltin,
		Placement: request.Placement,
		Order:     request.Order,
	}); err != nil {
		logger.Error("failed to update plugin: %v", err)
		SendError(ctx, 500, "Failed to update plugin")
		return
	}
	// Read the row back so the response reflects what was actually stored.
	plugin, err := h.configStore.GetPlugin(ctx, name)
	if err != nil {
		if errors.Is(err, configstore.ErrNotFound) {
			SendError(ctx, fasthttp.StatusNotFound, "Plugin not found")
			return
		}
		logger.Error("failed to get plugin: %v", err)
		SendError(ctx, 500, "Failed to retrieve plugin")
		return
	}
	// We reload the plugin if its enabled, otherwise we stop it
	if request.Enabled {
		if err := h.pluginsLoader.ReloadPlugin(ctx, name, request.Path, request.Config, request.Placement, request.Order); err != nil {
			logger.Error("failed to load plugin: %v", err)
			SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Plugin updated in database but failed to load: %v", err))
			return
		}
	} else {
		// Flag the request context so downstream code can tell the removal
		// was caused by disabling the plugin.
		ctx.SetUserValue(PluginDisabledKey, true)
		if err := h.pluginsLoader.RemovePlugin(ctx, name); err != nil {
			if !errors.Is(err, plugins.ErrPluginNotFound) {
				SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Plugin updated in database but failed to stop: %v", err))
				return
			}
			// If not found then we don't need to do anything
		}
	}
	SendJSON(ctx, map[string]any{
		"message": "Plugin updated successfully",
		"plugin":  h.buildPluginResponse(ctx, plugin),
	})
}
// deletePlugin handles requests that delete the plugin named by the "name"
// route parameter. The plugin is removed from the config store first and then
// unloaded from the plugins loader; a loader that does not know the plugin is
// not treated as an error.
func (h *PluginsHandler) deletePlugin(ctx *fasthttp.RequestCtx) {
	if h.configStore == nil {
		SendError(ctx, 400, "Plugins deletion is not supported when configstore is disabled")
		return
	}

	// Resolve and validate the route parameter.
	rawName := ctx.UserValue("name")
	if rawName == nil {
		logger.Warn("missing required 'name' parameter in delete plugin request")
		SendError(ctx, 400, "Missing required 'name' parameter")
		return
	}
	name, isString := rawName.(string)
	switch {
	case !isString:
		logger.Warn("invalid 'name' parameter type in delete plugin request, expected string but got %T", rawName)
		SendError(ctx, 400, "Invalid 'name' parameter type, expected string")
		return
	case name == "":
		logger.Warn("empty 'name' parameter provided in delete plugin request")
		SendError(ctx, 400, "Empty 'name' parameter not allowed")
		return
	}

	// Remove the persisted row.
	if storeErr := h.configStore.DeletePlugin(ctx, name); storeErr != nil {
		if errors.Is(storeErr, configstore.ErrNotFound) {
			SendError(ctx, fasthttp.StatusNotFound, "Plugin not found")
		} else {
			logger.Error("failed to delete plugin: %v", storeErr)
			SendError(ctx, 500, "Failed to delete plugin")
		}
		return
	}

	// Unload the running instance; "not loaded" is fine.
	stopErr := h.pluginsLoader.RemovePlugin(ctx, name)
	if stopErr != nil && !errors.Is(stopErr, plugins.ErrPluginNotFound) {
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Plugin deleted in database but failed to stop: %v", stopErr))
		return
	}

	SendJSON(ctx, map[string]interface{}{
		"message": "Plugin deleted successfully",
	})
}

View File

@@ -0,0 +1,149 @@
package handlers
import (
"context"
"encoding/json"
"net"
"os"
"testing"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/framework/modelcatalog"
"github.com/maximhq/bifrost/plugins/governance"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// pricingOverrideTestGovernanceManager is a do-nothing governance manager for
// handler tests: every method ignores its arguments and returns zero values.
type pricingOverrideTestGovernanceManager struct{}

// GetGovernanceData always returns nil.
func (pricingOverrideTestGovernanceManager) GetGovernanceData(_ context.Context) *governance.GovernanceData {
	return nil
}

// ReloadVirtualKey is a no-op.
func (pricingOverrideTestGovernanceManager) ReloadVirtualKey(_ context.Context, _ string) (*configstoreTables.TableVirtualKey, error) {
	return nil, nil
}

// RemoveVirtualKey is a no-op.
func (pricingOverrideTestGovernanceManager) RemoveVirtualKey(_ context.Context, _ string) error {
	return nil
}

// ReloadTeam is a no-op.
func (pricingOverrideTestGovernanceManager) ReloadTeam(_ context.Context, _ string) (*configstoreTables.TableTeam, error) {
	return nil, nil
}

// RemoveTeam is a no-op.
func (pricingOverrideTestGovernanceManager) RemoveTeam(_ context.Context, _ string) error {
	return nil
}

// ReloadCustomer is a no-op.
func (pricingOverrideTestGovernanceManager) ReloadCustomer(_ context.Context, _ string) (*configstoreTables.TableCustomer, error) {
	return nil, nil
}

// RemoveCustomer is a no-op.
func (pricingOverrideTestGovernanceManager) RemoveCustomer(_ context.Context, _ string) error {
	return nil
}

// ReloadModelConfig is a no-op.
func (pricingOverrideTestGovernanceManager) ReloadModelConfig(_ context.Context, _ string) (*configstoreTables.TableModelConfig, error) {
	return nil, nil
}

// RemoveModelConfig is a no-op.
func (pricingOverrideTestGovernanceManager) RemoveModelConfig(_ context.Context, _ string) error {
	return nil
}

// ReloadProvider is a no-op.
func (pricingOverrideTestGovernanceManager) ReloadProvider(_ context.Context, _ schemas.ModelProvider) (*configstoreTables.TableProvider, error) {
	return nil, nil
}

// RemoveProvider is a no-op.
func (pricingOverrideTestGovernanceManager) RemoveProvider(_ context.Context, _ schemas.ModelProvider) error {
	return nil
}

// ReloadRoutingRule is a no-op.
func (pricingOverrideTestGovernanceManager) ReloadRoutingRule(_ context.Context, _ string) error {
	return nil
}

// RemoveRoutingRule is a no-op.
func (pricingOverrideTestGovernanceManager) RemoveRoutingRule(_ context.Context, _ string) error {
	return nil
}

// UpsertPricingOverride is a no-op.
func (pricingOverrideTestGovernanceManager) UpsertPricingOverride(_ context.Context, _ *configstoreTables.TablePricingOverride) error {
	return nil
}

// DeletePricingOverride is a no-op.
func (pricingOverrideTestGovernanceManager) DeletePricingOverride(_ context.Context, _ string) error {
	return nil
}
// setupPricingOverrideHandlerStore opens a SQLite-backed config store whose
// database file lives in a per-test temporary directory, failing the test if
// the store cannot be created.
func setupPricingOverrideHandlerStore(t *testing.T) configstore.ConfigStore {
	t.Helper()

	sqlitePath := t.TempDir() + "/config.db"
	storeConfig := &configstore.Config{
		Enabled: true,
		Type:    configstore.ConfigStoreTypeSQLite,
		Config:  &configstore.SQLiteConfig{Path: sqlitePath},
	}

	store, err := configstore.NewConfigStore(context.Background(), storeConfig, &mockLogger{})
	require.NoError(t, err)

	// Best-effort removal of the database file; the error is deliberately
	// ignored.
	t.Cleanup(func() { _ = os.Remove(sqlitePath) })
	return store
}
// newTestRequestCtx returns a fasthttp request context initialised with the
// given request body and a fixed loopback remote address (127.0.0.1:12345).
func newTestRequestCtx(body string) *fasthttp.RequestCtx {
	request := &fasthttp.Request{}
	request.SetBodyString(body)

	remote := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 12345}
	reqCtx := new(fasthttp.RequestCtx)
	reqCtx.Init(request, remote, nil)
	return reqCtx
}
// TestUpdatePricingOverride_ReplacesFullBody verifies that updatePricingOverride
// stores the request's patch as a full replacement of the existing pricing
// patch rather than merging it into the previously stored one.
func TestUpdatePricingOverride_ReplacesFullBody(t *testing.T) {
SetLogger(&mockLogger{})
store := setupPricingOverrideHandlerStore(t)
handler := &GovernanceHandler{
configStore: store,
governanceManager: pricingOverrideTestGovernanceManager{},
}
// Seed an override whose stored patch contains both input and output costs.
now := time.Now().UTC()
override := configstoreTables.TablePricingOverride{
ID: "override-1",
Name: "Original",
ScopeKind: string(modelcatalog.ScopeKindGlobal),
MatchType: string(modelcatalog.MatchTypeExact),
Pattern: "gpt-4.1",
CreatedAt: now,
UpdatedAt: now,
PricingPatchJSON: `{"input_cost_per_token":1,"output_cost_per_token":2}`,
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
}
require.NoError(t, store.CreatePricingOverride(context.Background(), &override))
// Patch replaces in full: send only input_cost_per_token.
// output_cost_per_token must be absent from the stored patch afterwards,
// confirming full-replace (not merge) semantics.
body := `{
"name":"Updated",
"scope_kind":"global",
"match_type":"exact",
"pattern":"gpt-4.1",
"request_types":["chat_completion"],
"patch":{"input_cost_per_token":1.5}
}`
ctx := newTestRequestCtx(body)
// The handler is invoked directly, so the "id" route parameter is set by hand.
ctx.SetUserValue("id", override.ID)
handler.updatePricingOverride(ctx)
// Include the response body in the failure message to ease debugging.
require.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode(), string(ctx.Response.Body()))
// Re-read from the store to check what was persisted, not what was echoed back.
stored, err := store.GetPricingOverrideByID(context.Background(), override.ID)
require.NoError(t, err)
assert.Equal(t, "Updated", stored.Name)
var patch modelcatalog.PricingOptions
require.NoError(t, json.Unmarshal([]byte(stored.PricingPatchJSON), &patch))
// Sent field must reflect the new value.
require.NotNil(t, patch.InputCostPerToken)
assert.Equal(t, 1.5, *patch.InputCostPerToken)
// Omitted field must be cleared — patch is always fully replaced, not merged.
assert.Nil(t, patch.OutputCostPerToken)
// NOTE(review): presumably ConfigHash is only set for file-sourced entries — confirm.
assert.Empty(t, stored.ConfigHash)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,495 @@
package handlers
import (
"errors"
"fmt"
"net/url"
"github.com/bytedance/sonic"
"github.com/google/uuid"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// ListProviderKeysResponse represents the response for listing keys for a provider.
type ListProviderKeysResponse struct {
// Keys holds the provider's keys; listProviderKeys fills it with the redacted form.
Keys []schemas.Key `json:"keys"`
// Total is the number of entries in Keys.
Total int `json:"total"`
}
// listProviderKeys responds with the redacted keys of the provider named in
// the request, together with their count. It answers 400 for an invalid
// provider, 404 when the provider is unknown and 500 for any other failure.
func (h *ProviderHandler) listProviderKeys(ctx *fasthttp.RequestCtx) {
	provider, err := getProviderFromCtx(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err))
		return
	}

	redacted, err := h.inMemoryStore.GetProviderKeysRedacted(provider)
	switch {
	case err == nil:
		SendJSON(ctx, ListProviderKeysResponse{Keys: redacted, Total: len(redacted)})
	case errors.Is(err, lib.ErrNotFound):
		SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err))
	default:
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider keys: %v", err))
	}
}
// getProviderKey responds with the redacted form of a single provider key,
// identified by the provider and key_id route parameters. It answers 400 for
// invalid parameters, 404 when the key is unknown and 500 for any other
// failure.
func (h *ProviderHandler) getProviderKey(ctx *fasthttp.RequestCtx) {
	provider, err := getProviderFromCtx(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err))
		return
	}

	keyID, err := getKeyIDFromCtx(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusBadRequest, err.Error())
		return
	}

	redacted, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID)
	switch {
	case err == nil:
		SendJSON(ctx, redacted)
	case errors.Is(err, lib.ErrNotFound):
		SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err))
	default:
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err))
	}
}
// createProviderKey adds a key to an existing provider.
//
// The JSON body is decoded into a schemas.Key and validated (non-empty value
// where the base provider requires one, server URL for Ollama/SGL, blacklisted
// models, aliases). A missing ID is replaced by a fresh UUID and a missing
// Enabled flag defaults to true. After the key is stored, model discovery is
// attempted; a discovery failure is only logged. The response is the redacted
// form of the stored key.
func (h *ProviderHandler) createProviderKey(ctx *fasthttp.RequestCtx) {
	provider, err := getProviderFromCtx(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err))
		return
	}

	var newKey schemas.Key
	if decodeErr := sonic.Unmarshal(ctx.PostBody(), &newKey); decodeErr != nil {
		SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid JSON: %v", decodeErr))
		return
	}

	providerConfig, err := h.inMemoryStore.GetProviderConfigRaw(provider)
	if err != nil {
		if errors.Is(err, lib.ErrNotFound) {
			SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err))
		} else {
			SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider config: %v", err))
		}
		return
	}

	custom := providerConfig.CustomProviderConfig
	if custom != nil && custom.IsKeyLess {
		SendError(ctx, fasthttp.StatusBadRequest, "Cannot add keys to a keyless provider")
		return
	}

	// Custom providers are validated against the provider type they are built on.
	baseProvider := provider
	if custom != nil && custom.BaseProviderType != "" {
		baseProvider = custom.BaseProviderType
	}

	// Validation runs in a fixed order; the first failure decides the response.
	if newKey.Value.GetValue() == "" && !bifrost.CanProviderKeyValueBeEmpty(baseProvider) {
		SendError(ctx, fasthttp.StatusBadRequest, "Key value must not be empty")
		return
	}
	if urlErr := validateProviderKeyURL(provider, newKey); urlErr != nil {
		SendError(ctx, fasthttp.StatusBadRequest, urlErr.Error())
		return
	}
	if blErr := newKey.BlacklistedModels.Validate(); blErr != nil {
		SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid blacklisted_models: %v", blErr))
		return
	}
	if aliasErr := newKey.Aliases.Validate(); aliasErr != nil {
		SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid aliases: %v", aliasErr))
		return
	}

	// Fill in defaults for fields the client may omit.
	if newKey.ID == "" {
		newKey.ID = uuid.NewString()
	}
	if newKey.Enabled == nil {
		newKey.Enabled = bifrost.Ptr(true)
	}

	if addErr := h.inMemoryStore.AddProviderKey(ctx, provider, newKey); addErr != nil {
		logger.Warn("Failed to create key for provider %s: %v", provider, addErr)
		switch {
		case errors.Is(addErr, lib.ErrNotFound):
			SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", addErr))
		case errors.Is(addErr, lib.ErrAlreadyExists):
			SendError(ctx, fasthttp.StatusConflict, addErr.Error())
		default:
			SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create provider key: %v", addErr))
		}
		return
	}

	// Discovery is best-effort: the key is already stored, so only log.
	if discErr := h.attemptModelDiscovery(ctx, provider, providerConfig.CustomProviderConfig); discErr != nil {
		logger.Warn("Model discovery failed for provider %s after key create: %v", provider, discErr)
	}

	stored, err := h.inMemoryStore.GetProviderKeyRedacted(provider, newKey.ID)
	if err != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get created provider key: %v", err))
		return
	}
	SendJSON(ctx, stored)
}
// updateProviderKey replaces an existing provider key with the key sent in the
// request body.
//
// Fields that the client sends back in redacted form are restored from the
// stored raw key (see mergeUpdatedKey) before the merged key is validated and
// written. After the update, model discovery is attempted; a discovery failure
// is only logged. The response is the redacted form of the updated key.
func (h *ProviderHandler) updateProviderKey(ctx *fasthttp.RequestCtx) {
	provider, err := getProviderFromCtx(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err))
		return
	}
	keyID, err := getKeyIDFromCtx(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusBadRequest, err.Error())
		return
	}

	var incoming schemas.Key
	if decodeErr := sonic.Unmarshal(ctx.PostBody(), &incoming); decodeErr != nil {
		SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid JSON: %v", decodeErr))
		return
	}

	providerConfig, err := h.inMemoryStore.GetProviderConfigRaw(provider)
	if err != nil {
		if errors.Is(err, lib.ErrNotFound) {
			SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err))
		} else {
			SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider config: %v", err))
		}
		return
	}

	custom := providerConfig.CustomProviderConfig
	if custom != nil && custom.IsKeyLess {
		SendError(ctx, fasthttp.StatusBadRequest, "Cannot update keys on a keyless provider")
		return
	}

	// Both the raw and the redacted view of the stored key are needed to tell
	// "client echoed the redacted value" apart from "client sent a new value".
	storedRaw, err := h.inMemoryStore.GetProviderKeyRaw(provider, keyID)
	if err != nil {
		if errors.Is(err, lib.ErrNotFound) {
			SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err))
		} else {
			SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err))
		}
		return
	}
	storedRedacted, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID)
	if err != nil {
		if errors.Is(err, lib.ErrNotFound) {
			SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err))
		} else {
			SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err))
		}
		return
	}

	// The route parameter is authoritative for the key ID.
	incoming.ID = keyID
	merged := h.mergeUpdatedKey(*storedRaw, *storedRedacted, incoming)

	// Custom providers are validated against the provider type they are built on.
	baseProvider := provider
	if custom != nil && custom.BaseProviderType != "" {
		baseProvider = custom.BaseProviderType
	}

	// Validation runs in a fixed order; the first failure decides the response.
	if merged.Value.GetValue() == "" && !bifrost.CanProviderKeyValueBeEmpty(baseProvider) {
		SendError(ctx, fasthttp.StatusBadRequest, "Key value must not be empty")
		return
	}
	if blErr := merged.BlacklistedModels.Validate(); blErr != nil {
		SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid blacklisted_models: %v", blErr))
		return
	}
	if aliasErr := merged.Aliases.Validate(); aliasErr != nil {
		SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid aliases: %v", aliasErr))
		return
	}
	if urlErr := validateProviderKeyURL(provider, merged); urlErr != nil {
		SendError(ctx, fasthttp.StatusBadRequest, urlErr.Error())
		return
	}

	if updErr := h.inMemoryStore.UpdateProviderKey(ctx, provider, keyID, merged); updErr != nil {
		logger.Warn("Failed to update key %s for provider %s: %v", keyID, provider, updErr)
		if errors.Is(updErr, lib.ErrNotFound) {
			SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", updErr))
		} else {
			SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to update provider key: %v", updErr))
		}
		return
	}

	// Discovery is best-effort: the key is already updated, so only log.
	if discErr := h.attemptModelDiscovery(ctx, provider, providerConfig.CustomProviderConfig); discErr != nil {
		logger.Warn("Model discovery failed for provider %s after key update: %v", provider, discErr)
	}

	updated, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID)
	if err != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get updated provider key: %v", err))
		return
	}
	SendJSON(ctx, updated)
}
// deleteProviderKey removes a key from a provider and responds with the
// redacted form the key had before removal. After the removal, model discovery
// is attempted; a discovery failure is only logged.
func (h *ProviderHandler) deleteProviderKey(ctx *fasthttp.RequestCtx) {
	provider, err := getProviderFromCtx(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err))
		return
	}
	keyID, err := getKeyIDFromCtx(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusBadRequest, err.Error())
		return
	}

	providerConfig, err := h.inMemoryStore.GetProviderConfigRaw(provider)
	if err != nil {
		if errors.Is(err, lib.ErrNotFound) {
			SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err))
		} else {
			SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider config: %v", err))
		}
		return
	}

	if custom := providerConfig.CustomProviderConfig; custom != nil && custom.IsKeyLess {
		SendError(ctx, fasthttp.StatusBadRequest, "Cannot delete keys on a keyless provider")
		return
	}

	// Capture the redacted key up front: it is the response body and will no
	// longer be retrievable once the key is removed.
	snapshot, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID)
	if err != nil {
		if errors.Is(err, lib.ErrNotFound) {
			SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err))
		} else {
			SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err))
		}
		return
	}

	if rmErr := h.inMemoryStore.RemoveProviderKey(ctx, provider, keyID); rmErr != nil {
		logger.Warn("Failed to delete key %s for provider %s: %v", keyID, provider, rmErr)
		if errors.Is(rmErr, lib.ErrNotFound) {
			SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", rmErr))
		} else {
			SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to delete provider key: %v", rmErr))
		}
		return
	}

	// Discovery is best-effort: the key is already removed, so only log.
	if discErr := h.attemptModelDiscovery(ctx, provider, providerConfig.CustomProviderConfig); discErr != nil {
		logger.Warn("Model discovery failed for provider %s after key delete: %v", provider, discErr)
	}

	SendJSON(ctx, snapshot)
}
// mergeUpdatedKey builds the key to persist for an update request.
//
// It starts from updateKey and, for every field that can be redacted, restores
// the stored raw value when the client sent back exactly the redacted value it
// was previously shown (the field is redacted and equals the old redacted
// value). Provider-specific configs are only inspected when the update, the
// old redacted key and the old raw key all carry that config. ConfigHash and
// Status are always taken from the stored raw key.
func (h *ProviderHandler) mergeUpdatedKey(oldRawKey, oldRedactedKey, updateKey schemas.Key) schemas.Key {
	merged := updateKey

	if updateKey.Value.IsRedacted() && updateKey.Value.Equals(&oldRedactedKey.Value) {
		merged.Value = oldRawKey.Value
	}

	// Azure
	if upd, red, raw := updateKey.AzureKeyConfig, oldRedactedKey.AzureKeyConfig, oldRawKey.AzureKeyConfig; upd != nil && red != nil && raw != nil {
		if upd.Endpoint.IsRedacted() && upd.Endpoint.Equals(&red.Endpoint) {
			merged.AzureKeyConfig.Endpoint = raw.Endpoint
		}
		if upd.APIVersion != nil && red.APIVersion != nil &&
			upd.APIVersion.IsRedacted() && upd.APIVersion.Equals(red.APIVersion) {
			merged.AzureKeyConfig.APIVersion = raw.APIVersion
		}
		if upd.ClientID != nil && red.ClientID != nil &&
			upd.ClientID.IsRedacted() && upd.ClientID.Equals(red.ClientID) {
			merged.AzureKeyConfig.ClientID = raw.ClientID
		}
		if upd.ClientSecret != nil && red.ClientSecret != nil &&
			upd.ClientSecret.IsRedacted() && upd.ClientSecret.Equals(red.ClientSecret) {
			merged.AzureKeyConfig.ClientSecret = raw.ClientSecret
		}
		if upd.TenantID != nil && red.TenantID != nil &&
			upd.TenantID.IsRedacted() && upd.TenantID.Equals(red.TenantID) {
			merged.AzureKeyConfig.TenantID = raw.TenantID
		}
	}

	// Vertex
	if upd, red, raw := updateKey.VertexKeyConfig, oldRedactedKey.VertexKeyConfig, oldRawKey.VertexKeyConfig; upd != nil && red != nil && raw != nil {
		if upd.ProjectID.IsRedacted() && upd.ProjectID.Equals(&red.ProjectID) {
			merged.VertexKeyConfig.ProjectID = raw.ProjectID
		}
		if upd.ProjectNumber.IsRedacted() && upd.ProjectNumber.Equals(&red.ProjectNumber) {
			merged.VertexKeyConfig.ProjectNumber = raw.ProjectNumber
		}
		if upd.Region.IsRedacted() && upd.Region.Equals(&red.Region) {
			merged.VertexKeyConfig.Region = raw.Region
		}
		if upd.AuthCredentials.IsRedacted() && upd.AuthCredentials.Equals(&red.AuthCredentials) {
			merged.VertexKeyConfig.AuthCredentials = raw.AuthCredentials
		}
	}

	// Bedrock
	if upd, red, raw := updateKey.BedrockKeyConfig, oldRedactedKey.BedrockKeyConfig, oldRawKey.BedrockKeyConfig; upd != nil && red != nil && raw != nil {
		if upd.AccessKey.IsRedacted() && upd.AccessKey.Equals(&red.AccessKey) {
			merged.BedrockKeyConfig.AccessKey = raw.AccessKey
		}
		if upd.SecretKey.IsRedacted() && upd.SecretKey.Equals(&red.SecretKey) {
			merged.BedrockKeyConfig.SecretKey = raw.SecretKey
		}
		if upd.SessionToken != nil && red.SessionToken != nil &&
			upd.SessionToken.IsRedacted() && upd.SessionToken.Equals(red.SessionToken) {
			merged.BedrockKeyConfig.SessionToken = raw.SessionToken
		}
		if upd.Region != nil && red.Region != nil &&
			upd.Region.IsRedacted() && upd.Region.Equals(red.Region) {
			merged.BedrockKeyConfig.Region = raw.Region
		}
		if upd.ARN != nil && red.ARN != nil &&
			upd.ARN.IsRedacted() && upd.ARN.Equals(red.ARN) {
			merged.BedrockKeyConfig.ARN = raw.ARN
		}
		if upd.RoleARN != nil && red.RoleARN != nil &&
			upd.RoleARN.IsRedacted() && upd.RoleARN.Equals(red.RoleARN) {
			merged.BedrockKeyConfig.RoleARN = raw.RoleARN
		}
		if upd.ExternalID != nil && red.ExternalID != nil &&
			upd.ExternalID.IsRedacted() && upd.ExternalID.Equals(red.ExternalID) {
			merged.BedrockKeyConfig.ExternalID = raw.ExternalID
		}
		if upd.RoleSessionName != nil && red.RoleSessionName != nil &&
			upd.RoleSessionName.IsRedacted() && upd.RoleSessionName.Equals(red.RoleSessionName) {
			merged.BedrockKeyConfig.RoleSessionName = raw.RoleSessionName
		}
	}

	// vLLM
	if upd, red, raw := updateKey.VLLMKeyConfig, oldRedactedKey.VLLMKeyConfig, oldRawKey.VLLMKeyConfig; upd != nil && red != nil && raw != nil {
		if upd.URL.IsRedacted() && upd.URL.Equals(&red.URL) {
			merged.VLLMKeyConfig.URL = raw.URL
		}
	}

	// Replicate: no redacted fields are restored here; when the update omits
	// the config entirely, the stored one is carried over.
	if updateKey.ReplicateKeyConfig == nil && oldRawKey.ReplicateKeyConfig != nil {
		merged.ReplicateKeyConfig = oldRawKey.ReplicateKeyConfig
	}

	// Ollama
	if upd, red, raw := updateKey.OllamaKeyConfig, oldRedactedKey.OllamaKeyConfig, oldRawKey.OllamaKeyConfig; upd != nil && red != nil && raw != nil {
		if upd.URL.IsRedacted() && upd.URL.Equals(&red.URL) {
			merged.OllamaKeyConfig.URL = raw.URL
		}
	}

	// SGL
	if upd, red, raw := updateKey.SGLKeyConfig, oldRedactedKey.SGLKeyConfig, oldRawKey.SGLKeyConfig; upd != nil && red != nil && raw != nil {
		if upd.URL.IsRedacted() && upd.URL.Equals(&red.URL) {
			merged.SGLKeyConfig.URL = raw.URL
		}
	}

	// These are never taken from the request.
	merged.ConfigHash = oldRawKey.ConfigHash
	merged.Status = oldRawKey.Status
	return merged
}
// getKeyIDFromCtx reads the "key_id" route parameter from the request context
// and returns it with URL path escapes decoded.
//
// It returns an error when the parameter is missing, is not a non-empty
// string, or contains an invalid escape sequence. The decoding error is
// wrapped, so callers can still inspect it with errors.Is / errors.As.
func getKeyIDFromCtx(ctx *fasthttp.RequestCtx) (string, error) {
	keyValue := ctx.UserValue("key_id")
	if keyValue == nil {
		return "", errors.New("missing key_id parameter")
	}
	keyID, ok := keyValue.(string)
	if !ok || keyID == "" {
		return "", errors.New("invalid key_id parameter")
	}
	decoded, err := url.PathUnescape(keyID)
	if err != nil {
		return "", fmt.Errorf("invalid key_id parameter encoding: %w", err)
	}
	return decoded, nil
}
// validateProviderKeyURL reports an error when a key for the Ollama or SGL
// provider lacks its provider-specific config or that config's URL is not
// defined. Keys for any other provider always pass.
func validateProviderKeyURL(provider schemas.ModelProvider, key schemas.Key) error {
	if provider == schemas.Ollama {
		if cfg := key.OllamaKeyConfig; cfg == nil || !cfg.URL.IsDefined() {
			return fmt.Errorf("ollama_key_config.url is required for Ollama keys")
		}
		return nil
	}
	if provider == schemas.SGL {
		if cfg := key.SGLKeyConfig; cfg == nil || !cfg.URL.IsDefined() {
			return fmt.Errorf("sgl_key_config.url is required for SGL keys")
		}
		return nil
	}
	return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,550 @@
package handlers
import (
"context"
"encoding/json"
"testing"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/framework/modelcatalog"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// mockModelsManager is a models-manager test double. It serves fixed filtered
// and unfiltered model lists per provider, records every ReloadProvider call,
// and can be told to fail reloads via reloadErr.
type mockModelsManager struct {
	// filtered is what GetModelsForProvider serves.
	filtered map[schemas.ModelProvider][]string
	// unfiltered is what GetUnfilteredModelsForProvider serves.
	unfiltered map[schemas.ModelProvider][]string
	// reloadCalls lists the providers passed to ReloadProvider, in call order.
	reloadCalls []schemas.ModelProvider
	// reloadErr, when non-nil, is returned by every ReloadProvider call.
	reloadErr error
}

// ReloadProvider records the call and returns reloadErr; the table result is
// always nil.
func (m *mockModelsManager) ReloadProvider(_ context.Context, provider schemas.ModelProvider) (*configstoreTables.TableProvider, error) {
	m.reloadCalls = append(m.reloadCalls, provider)
	return nil, m.reloadErr
}

// RemoveProvider is a no-op.
func (m *mockModelsManager) RemoveProvider(_ context.Context, _ schemas.ModelProvider) error {
	return nil
}

// GetModelsForProvider returns a copy of the filtered model list for provider;
// the result is empty (never nil) when none is configured.
func (m *mockModelsManager) GetModelsForProvider(provider schemas.ModelProvider) []string {
	return append([]string{}, m.filtered[provider]...)
}

// GetUnfilteredModelsForProvider returns a copy of the unfiltered model list
// for provider; the result is empty (never nil) when none is configured.
func (m *mockModelsManager) GetUnfilteredModelsForProvider(provider schemas.ModelProvider) []string {
	return append([]string{}, m.unfiltered[provider]...)
}
// providerHandlerForTest returns a ProviderHandler whose in-memory store holds
// a single provider with the given keys and whose models manager serves the
// given filtered and unfiltered model lists for that provider.
func providerHandlerForTest(provider schemas.ModelProvider, keys []schemas.Key, filtered, unfiltered []string) *ProviderHandler {
	store := &lib.Config{
		Providers: map[schemas.ModelProvider]configstore.ProviderConfig{
			provider: {Keys: keys},
		},
	}
	models := &mockModelsManager{
		filtered:   map[schemas.ModelProvider][]string{provider: filtered},
		unfiltered: map[schemas.ModelProvider][]string{provider: unfiltered},
	}
	return &ProviderHandler{
		inMemoryStore: store,
		modelsManager: models,
	}
}
// TestAddProvider_ReloadsRuntimeEvenWhenModelDiscoveryIsSkipped adds a keyless
// custom provider and checks that addProvider answers 200, reloads the
// provider exactly once through the models manager, and registers it in the
// in-memory store.
func TestAddProvider_ReloadsRuntimeEvenWhenModelDiscoveryIsSkipped(t *testing.T) {
	SetLogger(&mockLogger{})
	lib.SetLogger(&mockLogger{})

	manager := &mockModelsManager{}
	handler := &ProviderHandler{
		inMemoryStore: &lib.Config{Providers: map[schemas.ModelProvider]configstore.ProviderConfig{}},
		modelsManager: manager,
	}

	payload := providerCreatePayload{
		Provider: "mock-openai",
		CustomProviderConfig: &schemas.CustomProviderConfig{
			BaseProviderType: schemas.OpenAI,
			IsKeyLess:        true,
		},
	}
	encoded, err := sonic.Marshal(payload)
	if err != nil {
		t.Fatalf("failed to marshal request body: %v", err)
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod(fasthttp.MethodPost)
	reqCtx.Request.SetRequestURI("/api/providers")
	reqCtx.Request.SetBody(encoded)

	handler.addProvider(reqCtx)

	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusOK {
		t.Fatalf("expected 200, got %d: %s", status, string(reqCtx.Response.Body()))
	}
	if len(manager.reloadCalls) != 1 || manager.reloadCalls[0] != "mock-openai" {
		t.Fatalf("expected provider reload for mock-openai, got %#v", manager.reloadCalls)
	}
	if _, added := handler.inMemoryStore.Providers["mock-openai"]; !added {
		t.Fatalf("expected provider to be added to in-memory store")
	}
}
// TestAddProvider_ReturnsErrorWhenRuntimeReloadFails makes the mock models
// manager fail its reload with context.DeadlineExceeded and verifies that
// addProvider answers 500 with the expected message, attempts the reload
// exactly once, and leaves the provider out of the in-memory store.
func TestAddProvider_ReturnsErrorWhenRuntimeReloadFails(t *testing.T) {
	SetLogger(&mockLogger{})
	lib.SetLogger(&mockLogger{})

	manager := &mockModelsManager{reloadErr: context.DeadlineExceeded}
	handler := &ProviderHandler{
		inMemoryStore: &lib.Config{Providers: map[schemas.ModelProvider]configstore.ProviderConfig{}},
		modelsManager: manager,
	}

	payload := providerCreatePayload{
		Provider: "mock-openai",
		CustomProviderConfig: &schemas.CustomProviderConfig{
			BaseProviderType: schemas.OpenAI,
			IsKeyLess:        true,
		},
	}
	encoded, marshalErr := sonic.Marshal(payload)
	if marshalErr != nil {
		t.Fatalf("failed to marshal request body: %v", marshalErr)
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod(fasthttp.MethodPost)
	reqCtx.Request.SetRequestURI("/api/providers")
	reqCtx.Request.SetBody(encoded)

	handler.addProvider(reqCtx)

	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusInternalServerError {
		t.Fatalf("expected 500, got %d: %s", status, string(reqCtx.Response.Body()))
	}
	if calls := manager.reloadCalls; len(calls) != 1 || calls[0] != "mock-openai" {
		t.Fatalf("expected single provider reload for mock-openai, got %#v", calls)
	}

	var decoded schemas.BifrostError
	if decodeErr := json.Unmarshal(reqCtx.Response.Body(), &decoded); decodeErr != nil {
		t.Fatalf("failed to unmarshal error response: %v", decodeErr)
	}
	if decoded.Error == nil || decoded.Error.Message == "" {
		t.Fatalf("expected error message in response, got %#v", decoded)
	}
	if got := decoded.Error.Message; got != "Failed to initialize provider after add: context deadline exceeded" {
		t.Fatalf("unexpected error message: %q", got)
	}
	if _, stored := handler.inMemoryStore.Providers["mock-openai"]; stored {
		t.Fatalf("expected provider rollback after reload failure")
	}
}
// boolPtr returns a pointer to a fresh bool holding v, letting key fixtures set
// pointer-valued fields inline without a shared pointer helper.
func boolPtr(v bool) *bool {
	p := new(bool)
	*p = v
	return p
}
// TestListModels_UnknownKeysDoNotFilter requests models with a keys parameter
// that matches no configured key and expects every model back, none of them
// carrying an accessible_by_keys annotation.
func TestListModels_UnknownKeysDoNotFilter(t *testing.T) {
	SetLogger(&mockLogger{})

	handler := providerHandlerForTest(
		schemas.OpenAI,
		[]schemas.Key{{ID: "key-a"}},
		[]string{"gpt-4o", "gpt-4o-mini"},
		[]string{"gpt-4o", "gpt-4o-mini"},
	)

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("GET")
	reqCtx.Request.SetRequestURI("/api/models?provider=openai&keys=missing")
	handler.listModels(reqCtx)

	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusOK {
		t.Fatalf("expected 200, got %d: %s", status, string(reqCtx.Response.Body()))
	}

	var listing ListModelsResponse
	if decodeErr := json.Unmarshal(reqCtx.Response.Body(), &listing); decodeErr != nil {
		t.Fatalf("failed to unmarshal response: %v", decodeErr)
	}
	if listing.Total != 2 {
		t.Fatalf("expected total=2, got %d", listing.Total)
	}
	if len(listing.Models) != 2 {
		t.Fatalf("expected all models to be returned, got %#v", listing.Models)
	}
	for i := range listing.Models {
		if len(listing.Models[i].AccessibleByKeys) != 0 {
			t.Fatalf("expected no accessible_by_keys annotations, got %#v", listing.Models)
		}
	}
}
// TestListModels_ReturnsExactAccessibleByKeysAndSkipsDisabledKeys asks for
// models filtered by two enabled keys and one disabled key, and expects each
// model to list only the enabled keys whose allowlist contains it.
func TestListModels_ReturnsExactAccessibleByKeysAndSkipsDisabledKeys(t *testing.T) {
	SetLogger(&mockLogger{})

	configuredKeys := []schemas.Key{
		{ID: "key-a", Models: []string{"gpt-4o"}},
		{ID: "key-b", Models: []string{"gpt-4o", "gpt-4o-mini"}},
		{ID: "key-disabled", Enabled: boolPtr(false)},
	}
	handler := providerHandlerForTest(
		schemas.OpenAI,
		configuredKeys,
		[]string{"gpt-4o", "gpt-4o-mini"},
		[]string{"gpt-4o", "gpt-4o-mini"},
	)

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("GET")
	reqCtx.Request.SetRequestURI("/api/models?provider=openai&keys=key-a,key-b,key-disabled")
	handler.listModels(reqCtx)

	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusOK {
		t.Fatalf("expected 200, got %d: %s", status, string(reqCtx.Response.Body()))
	}

	var listing ListModelsResponse
	if decodeErr := json.Unmarshal(reqCtx.Response.Body(), &listing); decodeErr != nil {
		t.Fatalf("failed to unmarshal response: %v", decodeErr)
	}
	if listing.Total != 2 {
		t.Fatalf("expected total=2, got %d", listing.Total)
	}

	// Index the annotations by model name so the assertions do not depend on response order.
	keysByModel := map[string][]string{}
	for _, entry := range listing.Models {
		keysByModel[entry.Name] = entry.AccessibleByKeys
	}
	if full := keysByModel["gpt-4o"]; len(full) != 2 || full[0] != "key-a" || full[1] != "key-b" {
		t.Fatalf("expected gpt-4o to be accessible by [key-a key-b], got %#v", full)
	}
	if mini := keysByModel["gpt-4o-mini"]; len(mini) != 1 || mini[0] != "key-b" {
		t.Fatalf("expected gpt-4o-mini to be accessible by [key-b], got %#v", mini)
	}
}
// TestListModels_AppliesQueryAndLimitAfterFiltering sends query=gpt with
// limit=1 and expects Total to count both matching models while the returned
// list is truncated to the first one, gpt-4o.
func TestListModels_AppliesQueryAndLimitAfterFiltering(t *testing.T) {
	SetLogger(&mockLogger{})

	handler := providerHandlerForTest(
		schemas.OpenAI,
		[]schemas.Key{{ID: "key-a"}},
		[]string{"gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet"},
		[]string{"gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet"},
	)

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("GET")
	reqCtx.Request.SetRequestURI("/api/models?provider=openai&query=gpt&limit=1")
	handler.listModels(reqCtx)

	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusOK {
		t.Fatalf("expected 200, got %d: %s", status, string(reqCtx.Response.Body()))
	}

	var listing ListModelsResponse
	if decodeErr := json.Unmarshal(reqCtx.Response.Body(), &listing); decodeErr != nil {
		t.Fatalf("failed to unmarshal response: %v", decodeErr)
	}
	if listing.Total != 2 {
		t.Fatalf("expected total=2 after query filtering, got %d", listing.Total)
	}
	if len(listing.Models) != 1 {
		t.Fatalf("expected limit to truncate response to 1 model, got %#v", listing.Models)
	}
	if first := listing.Models[0]; first.Name != "gpt-4o" {
		t.Fatalf("expected first filtered model to be gpt-4o, got %#v", first)
	}
}
// TestListModels_UnfilteredIgnoresKeys passes both keys=key-b and
// unfiltered=true and expects the full unfiltered model list with no
// accessible_by_keys annotations.
func TestListModels_UnfilteredIgnoresKeys(t *testing.T) {
	SetLogger(&mockLogger{})

	handler := providerHandlerForTest(
		schemas.OpenAI,
		[]schemas.Key{{ID: "key-b", Models: []string{"gpt-4o-mini"}}},
		[]string{"gpt-4o"},
		[]string{"gpt-4o", "gpt-4o-mini"},
	)

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("GET")
	reqCtx.Request.SetRequestURI("/api/models?provider=openai&keys=key-b&unfiltered=true")
	handler.listModels(reqCtx)

	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusOK {
		t.Fatalf("expected 200, got %d: %s", status, string(reqCtx.Response.Body()))
	}

	var listing ListModelsResponse
	if decodeErr := json.Unmarshal(reqCtx.Response.Body(), &listing); decodeErr != nil {
		t.Fatalf("failed to unmarshal response: %v", decodeErr)
	}
	if listing.Total != 2 || len(listing.Models) != 2 {
		t.Fatalf("expected both unfiltered models, got %#v", listing.Models)
	}
	for i := range listing.Models {
		if len(listing.Models[i].AccessibleByKeys) != 0 {
			t.Fatalf("expected no accessible_by_keys when unfiltered bypasses key filtering, got %#v", listing.Models)
		}
	}
}
// TestListModels_UnfilteredWithoutKeysReturnsAllUnfilteredModels requests
// unfiltered=true with no keys parameter and expects both unfiltered models
// back without accessible_by_keys annotations.
func TestListModels_UnfilteredWithoutKeysReturnsAllUnfilteredModels(t *testing.T) {
	SetLogger(&mockLogger{})

	handler := providerHandlerForTest(
		schemas.OpenAI,
		[]schemas.Key{{ID: "key-b", Models: []string{"gpt-4o-mini"}}},
		[]string{"gpt-4o"},
		[]string{"gpt-4o", "gpt-4o-mini"},
	)

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("GET")
	reqCtx.Request.SetRequestURI("/api/models?provider=openai&unfiltered=true")
	handler.listModels(reqCtx)

	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusOK {
		t.Fatalf("expected 200, got %d: %s", status, string(reqCtx.Response.Body()))
	}

	var listing ListModelsResponse
	if decodeErr := json.Unmarshal(reqCtx.Response.Body(), &listing); decodeErr != nil {
		t.Fatalf("failed to unmarshal response: %v", decodeErr)
	}
	if listing.Total != 2 || len(listing.Models) != 2 {
		t.Fatalf("expected both unfiltered models, got %#v", listing.Models)
	}
	for i := range listing.Models {
		if len(listing.Models[i].AccessibleByKeys) != 0 {
			t.Fatalf("expected no accessible_by_keys when no key filter is requested, got %#v", listing.Models)
		}
	}
}
// TestListModelDetails_ErrorsWhenModelCatalogUnavailable calls listModelDetails
// on a handler whose store has no ModelCatalog set and expects a 500 response.
func TestListModelDetails_ErrorsWhenModelCatalogUnavailable(t *testing.T) {
	SetLogger(&mockLogger{})

	handler := providerHandlerForTest(
		schemas.OpenAI,
		[]schemas.Key{{ID: "key-a"}},
		[]string{"gpt-4o"},
		[]string{"gpt-4o"},
	)

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("GET")
	reqCtx.Request.SetRequestURI("/api/models/details?provider=openai")
	handler.listModelDetails(reqCtx)

	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusInternalServerError {
		t.Fatalf("expected 500, got %d: %s", status, string(reqCtx.Response.Body()))
	}
}
// TestListModelDetails_UnknownKeysDoNotFilter requests model details with a
// keys parameter that matches no configured key and expects all models back.
func TestListModelDetails_UnknownKeysDoNotFilter(t *testing.T) {
	SetLogger(&mockLogger{})

	handler := providerHandlerForTest(
		schemas.OpenAI,
		[]schemas.Key{{ID: "key-a"}},
		[]string{"gpt-4o", "gpt-4o-mini"},
		[]string{"gpt-4o", "gpt-4o-mini"},
	)
	handler.inMemoryStore.ModelCatalog = &modelcatalog.ModelCatalog{}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("GET")
	reqCtx.Request.SetRequestURI("/api/models/details?provider=openai&keys=missing")
	handler.listModelDetails(reqCtx)

	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusOK {
		t.Fatalf("expected 200, got %d: %s", status, string(reqCtx.Response.Body()))
	}

	var details ListModelDetailsResponse
	if decodeErr := json.Unmarshal(reqCtx.Response.Body(), &details); decodeErr != nil {
		t.Fatalf("failed to unmarshal response: %v", decodeErr)
	}
	if details.Total != 2 || len(details.Models) != 2 {
		t.Fatalf("expected all models when keys are unknown, got %#v", details.Models)
	}
}
// TestListModelDetails_SkipsUnknownKeysAndFiltersWithValid mixes one configured
// key with one unknown key ID and expects the result to be filtered by the
// configured key alone, leaving only gpt-4o.
func TestListModelDetails_SkipsUnknownKeysAndFiltersWithValid(t *testing.T) {
	SetLogger(&mockLogger{})

	handler := providerHandlerForTest(
		schemas.OpenAI,
		[]schemas.Key{{ID: "key-a", Models: []string{"gpt-4o"}}},
		[]string{"gpt-4o", "gpt-4o-mini"},
		[]string{"gpt-4o", "gpt-4o-mini"},
	)
	handler.inMemoryStore.ModelCatalog = &modelcatalog.ModelCatalog{}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("GET")
	reqCtx.Request.SetRequestURI("/api/models/details?provider=openai&keys=key-a,missing")
	handler.listModelDetails(reqCtx)

	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusOK {
		t.Fatalf("expected 200, got %d: %s", status, string(reqCtx.Response.Body()))
	}

	var details ListModelDetailsResponse
	if decodeErr := json.Unmarshal(reqCtx.Response.Body(), &details); decodeErr != nil {
		t.Fatalf("failed to unmarshal response: %v", decodeErr)
	}
	if details.Total != 1 || len(details.Models) != 1 {
		t.Fatalf("expected 1 model filtered by valid key, got %#v", details.Models)
	}
	if name := details.Models[0].Name; name != "gpt-4o" {
		t.Fatalf("expected gpt-4o, got %s", name)
	}
}
// TestListModelDetails_SkipsDisabledKeysAndFiltersWithValid mixes one enabled
// key with one disabled key and expects the result to be filtered by the
// enabled key alone, leaving only gpt-4o.
func TestListModelDetails_SkipsDisabledKeysAndFiltersWithValid(t *testing.T) {
	SetLogger(&mockLogger{})

	configuredKeys := []schemas.Key{
		{ID: "key-a", Models: []string{"gpt-4o"}},
		{ID: "key-disabled", Enabled: boolPtr(false)},
	}
	handler := providerHandlerForTest(
		schemas.OpenAI,
		configuredKeys,
		[]string{"gpt-4o", "gpt-4o-mini"},
		[]string{"gpt-4o", "gpt-4o-mini"},
	)
	handler.inMemoryStore.ModelCatalog = &modelcatalog.ModelCatalog{}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("GET")
	reqCtx.Request.SetRequestURI("/api/models/details?provider=openai&keys=key-a,key-disabled")
	handler.listModelDetails(reqCtx)

	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusOK {
		t.Fatalf("expected 200, got %d: %s", status, string(reqCtx.Response.Body()))
	}

	var details ListModelDetailsResponse
	if decodeErr := json.Unmarshal(reqCtx.Response.Body(), &details); decodeErr != nil {
		t.Fatalf("failed to unmarshal response: %v", decodeErr)
	}
	if details.Total != 1 || len(details.Models) != 1 {
		t.Fatalf("expected 1 model filtered by valid key, got %#v", details.Models)
	}
	if name := details.Models[0].Name; name != "gpt-4o" {
		t.Fatalf("expected gpt-4o, got %s", name)
	}
}
// TestListModelDetails_UnfilteredIgnoresKeys passes both keys=key-b and
// unfiltered=true to the details endpoint and expects the full unfiltered
// model list back.
func TestListModelDetails_UnfilteredIgnoresKeys(t *testing.T) {
	SetLogger(&mockLogger{})

	handler := providerHandlerForTest(
		schemas.OpenAI,
		[]schemas.Key{{ID: "key-b", Models: []string{"gpt-4o-mini"}}},
		[]string{"gpt-4o"},
		[]string{"gpt-4o", "gpt-4o-mini"},
	)
	handler.inMemoryStore.ModelCatalog = &modelcatalog.ModelCatalog{}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("GET")
	reqCtx.Request.SetRequestURI("/api/models/details?provider=openai&keys=key-b&unfiltered=true")
	handler.listModelDetails(reqCtx)

	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusOK {
		t.Fatalf("expected 200, got %d: %s", status, string(reqCtx.Response.Body()))
	}

	var details ListModelDetailsResponse
	if decodeErr := json.Unmarshal(reqCtx.Response.Body(), &details); decodeErr != nil {
		t.Fatalf("failed to unmarshal response: %v", decodeErr)
	}
	if details.Total != 2 || len(details.Models) != 2 {
		t.Fatalf("expected all unfiltered models when unfiltered=true, got %#v", details.Models)
	}
}
// TestListModels_UsesCatalogAwareAliasMatchingForKeyAllowlist gives key-a an
// allowlist entry of gpt-4o-2024-08-06, installs a test catalog built from a
// gpt-4o-2024-08-06 -> gpt-4o mapping, and expects gpt-4o to be listed for
// that key.
func TestListModels_UsesCatalogAwareAliasMatchingForKeyAllowlist(t *testing.T) {
	SetLogger(&mockLogger{})

	handler := providerHandlerForTest(
		schemas.OpenAI,
		[]schemas.Key{{ID: "key-a", Models: []string{"gpt-4o-2024-08-06"}}},
		[]string{"gpt-4o"},
		[]string{"gpt-4o"},
	)
	aliasTable := map[string]string{
		"gpt-4o-2024-08-06": "gpt-4o",
	}
	handler.inMemoryStore.ModelCatalog = modelcatalog.NewTestCatalog(aliasTable)

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("GET")
	reqCtx.Request.SetRequestURI("/api/models?provider=openai&keys=key-a")
	handler.listModels(reqCtx)

	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusOK {
		t.Fatalf("expected 200, got %d: %s", status, string(reqCtx.Response.Body()))
	}

	var listing ListModelsResponse
	if decodeErr := json.Unmarshal(reqCtx.Response.Body(), &listing); decodeErr != nil {
		t.Fatalf("failed to unmarshal response: %v", decodeErr)
	}
	if listing.Total != 1 || len(listing.Models) != 1 || listing.Models[0].Name != "gpt-4o" {
		t.Fatalf("expected gpt-4o to be matched through alias allowlist, got %#v", listing.Models)
	}
}

View File

@@ -0,0 +1,419 @@
package handlers
import (
"encoding/json"
"fmt"
"mime"
"strings"
"time"
"github.com/fasthttp/router"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/plugins/governance"
"github.com/maximhq/bifrost/transports/bifrost-http/integrations"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// RealtimeClientSecretsHandler exposes OpenAI-compatible HTTP routes for
// minting short-lived Realtime client secrets.
type RealtimeClientSecretsHandler struct {
// client is the Bifrost core client used to look up providers and select keys.
client *bifrost.Bifrost
// config supplies the loaded base plugins, the header matcher, and the MCP
// header allowlist.
config *lib.Config
// handlerStore answers ShouldAllowDirectKeys and provides the KV store;
// NewRealtimeClientSecretsHandler sets it to the same value as config.
handlerStore lib.HandlerStore
// routeSpecs maps each registered request path to its route spec; it is
// filled in by RegisterRoutes and read by handleRequest.
routeSpecs map[string]schemas.RealtimeSessionRoute
}
// NewRealtimeClientSecretsHandler builds a handler around client and config.
// config also serves as the handler store, and the route table starts empty
// until RegisterRoutes populates it.
func NewRealtimeClientSecretsHandler(client *bifrost.Bifrost, config *lib.Config) *RealtimeClientSecretsHandler {
	h := new(RealtimeClientSecretsHandler)
	h.client = client
	h.config = config
	h.handlerStore = config
	h.routeSpecs = map[string]schemas.RealtimeSessionRoute{}
	return h
}
// RegisterRoutes wraps handleRequest in the given middlewares and registers it
// as the POST handler for every realtime session route, recording each route
// spec by path so handleRequest can look it up later.
func (h *RealtimeClientSecretsHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	chained := lib.ChainMiddlewares(h.handleRequest, middlewares...)
	specs := h.realtimeSessionRoutes()
	for i := range specs {
		spec := specs[i]
		h.routeSpecs[spec.Path] = spec
		r.POST(spec.Path, chained)
	}
}
// findGovernancePlugin returns the first loaded base plugin that implements
// governance.BaseGovernancePlugin, or nil when no plugins are loaded or none
// of them implements it.
func (h *RealtimeClientSecretsHandler) findGovernancePlugin() governance.BaseGovernancePlugin {
	loaded := h.config.BasePlugins.Load()
	if loaded == nil {
		return nil
	}
	for _, candidate := range *loaded {
		governancePlugin, implements := candidate.(governance.BaseGovernancePlugin)
		if !implements {
			continue
		}
		return governancePlugin
	}
	return nil
}
// handleRequest mints a realtime client secret for the route matching the
// request path. It validates the content type, resolves provider and model
// from the JSON body, runs the governance check, selects a provider key,
// forwards the normalized body to the provider, caches the ephemeral-token
// mapping, and relays the provider's response. Every failure is reported to
// the client through SendBifrostError.
func (h *RealtimeClientSecretsHandler) handleRequest(ctx *fasthttp.RequestCtx) {
// Only JSON request bodies are accepted.
if !isJSONContentType(string(ctx.Request.Header.ContentType())) {
SendBifrostError(ctx, newRealtimeClientSecretHandlerError(
fasthttp.StatusBadRequest,
"invalid_request_error",
"Content-Type must be application/json",
nil,
))
return
}
// Work on a private copy of the request body.
body := append([]byte(nil), ctx.Request.Body()...)
// Routes are keyed by exact path; routeSpecs is filled in by RegisterRoutes.
route, ok := h.routeSpecs[string(ctx.Path())]
if !ok {
SendBifrostError(ctx, newRealtimeClientSecretHandlerError(
fasthttp.StatusNotFound,
"invalid_request_error",
"unsupported realtime client secret route",
nil,
))
return
}
// Resolve the target provider and bare model, and get a body whose model
// fields carry the bare model.
providerKey, model, normalizedBody, err := resolveRealtimeClientSecretTarget(route, body)
if err != nil {
SendBifrostError(ctx, err)
return
}
// Build the Bifrost context from the HTTP request; cancel is deferred so it
// runs on every return path below.
bifrostCtx, cancel := lib.ConvertToBifrostContext(
ctx,
h.handlerStore.ShouldAllowDirectKeys(),
h.config.GetHeaderMatcher(),
h.config.GetMCPHeaderCombinedAllowlist(),
)
defer cancel()
bifrostCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest)
// Routes whose default provider is OpenAI are tagged as the openai integration.
if route.DefaultProvider == schemas.OpenAI {
bifrostCtx.SetValue(schemas.BifrostContextKeyIntegrationType, "openai")
}
// Carry non-empty user ID and user name values from the fasthttp context
// over to the Bifrost context.
if governanceUserID, ok := ctx.UserValue(schemas.BifrostContextKeyUserID).(string); ok && governanceUserID != "" {
bifrostCtx.SetValue(schemas.BifrostContextKeyUserID, governanceUserID)
}
if userName, ok := ctx.UserValue(schemas.BifrostContextKeyUserName).(string); ok && userName != "" {
bifrostCtx.SetValue(schemas.BifrostContextKeyUserName, userName)
}
// Governance runs before any provider or key lookup; a non-nil error aborts
// the request.
if bifrostErr := h.evaluateMintingGovernance(bifrostCtx, providerKey, model); bifrostErr != nil {
SendBifrostError(ctx, bifrostErr)
return
}
provider := h.client.GetProviderByKey(providerKey)
if provider == nil {
SendBifrostError(ctx, newRealtimeClientSecretHandlerError(
fasthttp.StatusBadRequest,
"invalid_request_error",
"provider not found: "+string(providerKey),
nil,
))
return
}
// Pick the provider key that will mint the secret for this model.
key, keyErr := h.client.SelectKeyForProviderRequestType(bifrostCtx, schemas.RealtimeRequest, providerKey, model)
if keyErr != nil {
SendBifrostError(ctx, newRealtimeClientSecretHandlerError(
fasthttp.StatusBadRequest,
"invalid_request_error",
keyErr.Error(),
keyErr,
))
return
}
// Resolve model aliases now that the key is selected so the forwarded body
// carries the provider's canonical model, matching wsrealtime/webrtc flows.
if resolved := key.Aliases.Resolve(model); resolved != "" && resolved != model {
model = resolved
reparsed, parseErr := schemas.ParseRealtimeClientSecretBody(normalizedBody)
if parseErr != nil {
SendBifrostError(ctx, parseErr)
return
}
rewritten, normalizeErr := normalizeRealtimeClientSecretBody(reparsed, model)
if normalizeErr != nil {
SendBifrostError(ctx, normalizeErr)
return
}
normalizedBody = rewritten
}
// The provider must implement RealtimeSessionProvider to mint client secrets.
sessionProvider, ok := provider.(schemas.RealtimeSessionProvider)
if !ok {
SendBifrostError(ctx, realtimeSessionNotSupportedError(providerKey, provider))
return
}
resp, bifrostErr := sessionProvider.CreateRealtimeClientSecret(bifrostCtx, key, route.EndpointType, normalizedBody)
if bifrostErr != nil {
SendBifrostError(ctx, bifrostErr)
return
}
// Record which key (and virtual key, if any) minted the returned ephemeral
// token. The helper returns nothing, so a caching failure does not fail the
// request.
cacheRealtimeEphemeralKeyMapping(
h.handlerStore.GetKVStore(),
resp.Body,
key.ID,
bifrost.GetStringFromContext(bifrostCtx, schemas.BifrostContextKeyVirtualKey),
)
writeRealtimeClientSecretResponse(ctx, resp)
}
// evaluateMintingGovernance asks the governance plugin, when one is loaded, to
// evaluate a realtime request for the given provider and model, using the
// virtual key and user ID stored in bifrostCtx. It returns nil when no
// governance plugin is present, otherwise whatever error the plugin reports.
func (h *RealtimeClientSecretsHandler) evaluateMintingGovernance(
	bifrostCtx *schemas.BifrostContext,
	providerKey schemas.ModelProvider,
	model string,
) *schemas.BifrostError {
	plugin := h.findGovernancePlugin()
	if plugin == nil {
		return nil
	}

	request := &governance.EvaluationRequest{
		VirtualKey: bifrost.GetStringFromContext(bifrostCtx, schemas.BifrostContextKeyVirtualKey),
		Provider:   providerKey,
		Model:      model,
		UserID:     bifrost.GetStringFromContext(bifrostCtx, schemas.BifrostContextKeyUserID),
	}
	// Only the error matters here; the evaluation result itself is discarded.
	_, evalErr := plugin.EvaluateGovernanceRequest(bifrostCtx, request, schemas.RealtimeRequest)
	return evalErr
}
// realtimeSessionRoutes lists every route this handler serves: the two
// provider-agnostic /v1 routes, followed by the OpenAI integration paths under
// the /openai prefix, which default to the OpenAI provider.
func (h *RealtimeClientSecretsHandler) realtimeSessionRoutes() []schemas.RealtimeSessionRoute {
	routes := []schemas.RealtimeSessionRoute{
		{Path: "/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets},
		{Path: "/v1/realtime/sessions", EndpointType: schemas.RealtimeSessionEndpointSessions},
	}

	for _, path := range integrations.OpenAIRealtimeClientSecretPaths("/openai") {
		route := schemas.RealtimeSessionRoute{
			Path:            path,
			EndpointType:    schemas.RealtimeSessionEndpointClientSecrets,
			DefaultProvider: schemas.OpenAI,
		}
		// Paths ending in /realtime/sessions use the sessions endpoint type instead.
		if strings.HasSuffix(path, "/realtime/sessions") {
			route.EndpointType = schemas.RealtimeSessionEndpointSessions
		}
		routes = append(routes, route)
	}
	return routes
}
// resolveRealtimeClientSecretTarget parses body, extracts the requested model,
// and splits it into provider and bare model using the route's default
// provider. It returns the provider, the bare model, and a re-encoded body
// whose model fields carry the bare model.
//
// A 400 error is returned when the route has no default provider and the model
// names none, or when either the provider or the model ends up empty. Parse,
// extraction, and normalization errors are returned unchanged.
func resolveRealtimeClientSecretTarget(route schemas.RealtimeSessionRoute, body []byte) (schemas.ModelProvider, string, []byte, *schemas.BifrostError) {
	root, parseErr := schemas.ParseRealtimeClientSecretBody(body)
	if parseErr != nil {
		return "", "", nil, parseErr
	}
	rawModel, extractErr := schemas.ExtractRealtimeClientSecretModel(root)
	if extractErr != nil {
		return "", "", nil, extractErr
	}

	providerKey, model := schemas.ParseModelString(rawModel, route.DefaultProvider)
	switch {
	case route.DefaultProvider == "" && providerKey == "":
		return "", "", nil, newRealtimeClientSecretHandlerError(
			fasthttp.StatusBadRequest,
			"invalid_request_error",
			"session.model must use provider/model on /v1 realtime client secret routes",
			nil,
		)
	case providerKey == "" || model == "":
		return "", "", nil, newRealtimeClientSecretHandlerError(
			fasthttp.StatusBadRequest,
			"invalid_request_error",
			"session.model is required",
			nil,
		)
	}

	// Rewrite the body so the upstream provider sees the bare model without
	// the provider prefix, mirroring resolveRealtimeSDPTarget.
	normalizedBody, normalizeErr := normalizeRealtimeClientSecretBody(root, model)
	if normalizeErr != nil {
		return "", "", nil, normalizeErr
	}
	return providerKey, model, normalizedBody, nil
}
// normalizeRealtimeClientSecretBody overwrites session.model and the top-level
// model in root with bareModel, wherever those fields already exist, and
// returns the re-encoded body. root is modified in place. A session value that
// does not decode as a JSON object is left untouched. Encoding failures are
// reported as 500 server errors.
func normalizeRealtimeClientSecretBody(root map[string]json.RawMessage, bareModel string) ([]byte, *schemas.BifrostError) {
	encodedModel, encodeErr := json.Marshal(bareModel)
	if encodeErr != nil {
		return nil, newRealtimeClientSecretHandlerError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized model", encodeErr)
	}

	// session.model: rewritten only when session decodes as an object that
	// already has a model field.
	if rawSession := root["session"]; len(rawSession) > 0 {
		var sessionFields map[string]json.RawMessage
		if json.Unmarshal(rawSession, &sessionFields) == nil {
			if _, present := sessionFields["model"]; present {
				sessionFields["model"] = encodedModel
				reencoded, sessionErr := json.Marshal(sessionFields)
				if sessionErr != nil {
					return nil, newRealtimeClientSecretHandlerError(fasthttp.StatusInternalServerError, "server_error", "failed to re-encode session", sessionErr)
				}
				root["session"] = reencoded
			}
		}
	}

	// Top-level model: rewritten only when the field is already present.
	if _, present := root["model"]; present {
		root["model"] = encodedModel
	}

	out, bodyErr := json.Marshal(root)
	if bodyErr != nil {
		return nil, newRealtimeClientSecretHandlerError(fasthttp.StatusInternalServerError, "server_error", "failed to re-encode body", bodyErr)
	}
	return out, nil
}
// realtimeEphemeralKeyMappingPrefix is prepended to an ephemeral token to form
// the KV-store key under which its mapping is cached.
const realtimeEphemeralKeyMappingPrefix = "realtime:ephemeral-key:"
// realtimeEphemeralKeyMapping is the JSON payload cached for each ephemeral
// realtime token.
type realtimeEphemeralKeyMapping struct {
// KeyID is the ID of the provider key passed to
// cacheRealtimeEphemeralKeyMapping, with surrounding whitespace trimmed.
KeyID string `json:"key_id,omitempty"`
// VirtualKey is the trimmed virtual key passed alongside it; omitted from
// the JSON when empty.
VirtualKey string `json:"virtual_key,omitempty"`
}
// cacheRealtimeEphemeralKeyMapping stores a key-ID / virtual-key mapping for
// the ephemeral token found in body, with a TTL equal to the token's remaining
// lifetime. It is best effort: it does nothing when kv is nil, body is empty,
// keyID is blank, or no usable token can be parsed, and it only logs a warning
// when encoding or storing fails.
func cacheRealtimeEphemeralKeyMapping(kv schemas.KVStore, body []byte, keyID string, virtualKey string) {
	trimmedKeyID := strings.TrimSpace(keyID)
	if kv == nil || len(body) == 0 || trimmedKeyID == "" {
		return
	}

	token, ttl, parsed := parseRealtimeEphemeralKeyMapping(body)
	if !parsed || ttl <= 0 || strings.TrimSpace(token) == "" {
		return
	}

	entry := realtimeEphemeralKeyMapping{
		KeyID:      trimmedKeyID,
		VirtualKey: strings.TrimSpace(virtualKey),
	}
	encoded, encodeErr := json.Marshal(entry)
	if encodeErr != nil {
		logger.Warn("failed to encode realtime ephemeral key mapping for key_id=%s: %v", keyID, encodeErr)
		return
	}

	storeErr := kv.SetWithTTL(buildRealtimeEphemeralKeyMappingKey(token), encoded, ttl)
	if storeErr != nil {
		logger.Warn("failed to cache realtime ephemeral key mapping for key_id=%s: %v", keyID, storeErr)
	}
}
// parseRealtimeEphemeralKeyMapping extracts the ephemeral token and its
// remaining lifetime from a client-secret response body.
//
// The top-level "value" / "expires_at" pair is tried first, which is the shape
// OpenAI client_secrets responses use. When that pair is missing or unusable,
// the nested "client_secret" object is tried instead so wrapped variants keep
// working. The two shapes are never mixed: the nested object must supply both
// fields itself.
//
// It returns ok=false when body is not a JSON object, when neither shape
// yields a non-blank value with a positive expires_at, or when expires_at
// (Unix seconds) is not in the future. The returned token is not trimmed.
func parseRealtimeEphemeralKeyMapping(body []byte) (string, time.Duration, bool) {
	type secretFields struct {
		Value     string `json:"value"`
		ExpiresAt int64  `json:"expires_at"`
	}
	usable := func(s secretFields) bool {
		return strings.TrimSpace(s.Value) != "" && s.ExpiresAt > 0
	}

	var root map[string]json.RawMessage
	if err := json.Unmarshal(body, &root); err != nil {
		return "", 0, false
	}

	var secret secretFields
	if err := json.Unmarshal(body, &secret); err != nil || !usable(secret) {
		nested, ok := root["client_secret"]
		if !ok || len(nested) == 0 || string(nested) == "null" {
			return "", 0, false
		}
		// Start from a zero value: json.Unmarshal leaves fields absent from the
		// nested object untouched, so without the reset a stale top-level value
		// or expiry could be paired with a partial nested object.
		secret = secretFields{}
		if err := json.Unmarshal(nested, &secret); err != nil {
			return "", 0, false
		}
	}
	if !usable(secret) {
		return "", 0, false
	}

	ttl := time.Until(time.Unix(secret.ExpiresAt, 0))
	if ttl <= 0 {
		return "", 0, false
	}
	return secret.Value, ttl, true
}
// buildRealtimeEphemeralKeyMappingKey returns the KV-store key for token: the
// mapping prefix followed by the token with surrounding whitespace removed.
func buildRealtimeEphemeralKeyMappingKey(token string) string {
	trimmed := strings.TrimSpace(token)
	return realtimeEphemeralKeyMappingPrefix + trimmed
}
// realtimeSessionNotSupportedError builds the 400 error returned when a
// provider cannot mint realtime client secrets. The message distinguishes
// providers that report realtime API support from those that do not.
func realtimeSessionNotSupportedError(providerKey schemas.ModelProvider, provider schemas.Provider) *schemas.BifrostError {
	message := fmt.Sprintf("provider %s does not support realtime client secret creation", providerKey)
	if rtProvider, ok := provider.(schemas.RealtimeProvider); ok && rtProvider.SupportsRealtimeAPI() {
		message = fmt.Sprintf("provider %s supports realtime websocket connections but not realtime client secret creation", providerKey)
	}
	return newRealtimeClientSecretHandlerError(
		fasthttp.StatusBadRequest,
		"invalid_request_error",
		message,
		nil,
	)
}
// newRealtimeClientSecretHandlerError builds a BifrostError carrying the given
// HTTP status, error type, message, and underlying error. IsBifrostError is
// false and the request type is tagged as RealtimeRequest.
func newRealtimeClientSecretHandlerError(status int, errorType, message string, err error) *schemas.BifrostError {
	details := &schemas.ErrorField{
		Type:    schemas.Ptr(errorType),
		Message: message,
		Error:   err,
	}
	result := &schemas.BifrostError{
		IsBifrostError: false,
		StatusCode:     schemas.Ptr(status),
		Error:          details,
	}
	result.ExtraFields.RequestType = schemas.RealtimeRequest
	return result
}
// writeRealtimeClientSecretResponse relays a provider passthrough response to
// the client by copying its headers, status code, and body. A nil response is
// reported as a 500 error. Content-Type falls back to application/json when
// the copied headers did not set one.
func writeRealtimeClientSecretResponse(ctx *fasthttp.RequestCtx, resp *schemas.BifrostPassthroughResponse) {
	if resp == nil {
		missing := newRealtimeClientSecretHandlerError(
			fasthttp.StatusInternalServerError,
			"server_error",
			"provider returned an empty realtime client secret response",
			nil,
		)
		SendBifrostError(ctx, missing)
		return
	}

	for name, value := range resp.Headers {
		ctx.Response.Header.Set(name, value)
	}
	if contentType := ctx.Response.Header.ContentType(); len(contentType) == 0 {
		ctx.SetContentType("application/json")
	}
	ctx.SetStatusCode(resp.StatusCode)
	ctx.SetBody(resp.Body)
}
// isJSONContentType reports whether contentType names a JSON media type: either
// application/json or any type with a +json suffix, ignoring parameters and
// letter case. Values that fail to parse as a media type are not JSON.
func isJSONContentType(contentType string) bool {
	parsed, _, parseErr := mime.ParseMediaType(contentType)
	if parseErr != nil {
		return false
	}
	lowered := strings.ToLower(parsed)
	if lowered == "application/json" {
		return true
	}
	return strings.HasSuffix(lowered, "+json")
}

View File

@@ -0,0 +1,414 @@
package handlers
import (
"context"
"encoding/json"
"fmt"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/kvstore"
"github.com/maximhq/bifrost/plugins/governance"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// TestResolveRealtimeClientSecretTarget table-tests provider/model resolution:
// prefixed models on the base routes, bare models on the OpenAI alias route,
// and the error cases for a bare model on a base route and a missing model.
func TestResolveRealtimeClientSecretTarget(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name         string
		route        schemas.RealtimeSessionRoute
		body         []byte
		wantProvider schemas.ModelProvider
		wantModel    string
		wantErr      bool
	}
	cases := []testCase{
		{
			name:         "base route with session model",
			route:        schemas.RealtimeSessionRoute{Path: "/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets},
			body:         []byte(`{"session":{"model":"openai/gpt-4o-realtime-preview"}}`),
			wantProvider: schemas.OpenAI,
			wantModel:    "gpt-4o-realtime-preview",
		},
		{
			name:         "base route with top level model",
			route:        schemas.RealtimeSessionRoute{Path: "/v1/realtime/sessions", EndpointType: schemas.RealtimeSessionEndpointSessions},
			body:         []byte(`{"model":"openai/gpt-4o-realtime-preview"}`),
			wantProvider: schemas.OpenAI,
			wantModel:    "gpt-4o-realtime-preview",
		},
		{
			name:         "openai alias uses bare model",
			route:        schemas.RealtimeSessionRoute{Path: "/openai/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets, DefaultProvider: schemas.OpenAI},
			body:         []byte(`{"session":{"model":"gpt-4o-realtime-preview"}}`),
			wantProvider: schemas.OpenAI,
			wantModel:    "gpt-4o-realtime-preview",
		},
		{
			name:    "base route rejects bare model",
			route:   schemas.RealtimeSessionRoute{Path: "/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets},
			body:    []byte(`{"session":{"model":"gpt-4o-realtime-preview"}}`),
			wantErr: true,
		},
		{
			name:    "missing model",
			route:   schemas.RealtimeSessionRoute{Path: "/openai/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets, DefaultProvider: schemas.OpenAI},
			body:    []byte(`{"session":{}}`),
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			provider, model, _, resolveErr := resolveRealtimeClientSecretTarget(tc.route, tc.body)
			switch {
			case tc.wantErr && resolveErr == nil:
				t.Fatal("expected error, got nil")
			case tc.wantErr:
				// An error was expected and received; nothing more to check.
				return
			case resolveErr != nil:
				t.Fatalf("resolveRealtimeClientSecretTarget() error = %v", resolveErr)
			}

			if provider != tc.wantProvider {
				t.Fatalf("provider = %q, want %q", provider, tc.wantProvider)
			}
			if model != tc.wantModel {
				t.Fatalf("model = %q, want %q", model, tc.wantModel)
			}
		})
	}
}
// TestResolveRealtimeClientSecretTarget_NormalizesModel checks the body
// returned by resolveRealtimeClientSecretTarget: whichever of session.model
// and the top-level model is present must carry the bare model name.
func TestResolveRealtimeClientSecretTarget_NormalizesModel(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name      string
		route     schemas.RealtimeSessionRoute
		body      string
		wantModel string // bare model expected in normalized body
	}
	cases := []testCase{
		{
			name:      "session.model provider prefix stripped",
			route:     schemas.RealtimeSessionRoute{Path: "/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets},
			body:      `{"session":{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}}`,
			wantModel: "gpt-4o-realtime-preview",
		},
		{
			name:      "top-level model provider prefix stripped",
			route:     schemas.RealtimeSessionRoute{Path: "/v1/realtime/sessions", EndpointType: schemas.RealtimeSessionEndpointSessions},
			body:      `{"model":"openai/gpt-4o-realtime-preview"}`,
			wantModel: "gpt-4o-realtime-preview",
		},
		{
			name:      "bare model unchanged on alias route",
			route:     schemas.RealtimeSessionRoute{Path: "/openai/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets, DefaultProvider: schemas.OpenAI},
			body:      `{"session":{"model":"gpt-4o-realtime-preview"}}`,
			wantModel: "gpt-4o-realtime-preview",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			_, _, normalized, resolveErr := resolveRealtimeClientSecretTarget(tc.route, []byte(tc.body))
			if resolveErr != nil {
				t.Fatalf("unexpected error: %v", resolveErr)
			}

			var top map[string]json.RawMessage
			if decodeErr := json.Unmarshal(normalized, &top); decodeErr != nil {
				t.Fatalf("failed to unmarshal normalized body: %v", decodeErr)
			}

			// session.model, when the body has a session carrying a model.
			if rawSession, hasSession := top["session"]; hasSession {
				var sessionFields map[string]json.RawMessage
				if decodeErr := json.Unmarshal(rawSession, &sessionFields); decodeErr != nil {
					t.Fatalf("failed to unmarshal session: %v", decodeErr)
				}
				if rawModel, hasModel := sessionFields["model"]; hasModel {
					var got string
					if decodeErr := json.Unmarshal(rawModel, &got); decodeErr != nil {
						t.Fatalf("failed to unmarshal session.model: %v", decodeErr)
					}
					if got != tc.wantModel {
						t.Fatalf("session.model = %q, want %q", got, tc.wantModel)
					}
				}
			}

			// Top-level model, when present.
			rawModel, hasModel := top["model"]
			if !hasModel {
				return
			}
			var got string
			if decodeErr := json.Unmarshal(rawModel, &got); decodeErr != nil {
				t.Fatalf("failed to unmarshal model: %v", decodeErr)
			}
			if got != tc.wantModel {
				t.Fatalf("model = %q, want %q", got, tc.wantModel)
			}
		})
	}
}
// TestParseRealtimeEphemeralKeyMapping checks that a body carrying top-level
// "value" and "expires_at" fields yields that token with a positive TTL. The
// expiry used, 4102444800, is 2100-01-01 UTC in Unix seconds.
func TestParseRealtimeEphemeralKeyMapping(t *testing.T) {
t.Parallel()
token, ttl, ok := parseRealtimeEphemeralKeyMapping([]byte(`{
"value": "ek_test_123",
"expires_at": 4102444800
}`))
if !ok {
t.Fatal("expected ephemeral mapping to be parsed")
}
if token != "ek_test_123" {
t.Fatalf("token = %q, want %q", token, "ek_test_123")
}
if ttl <= 0 {
t.Fatalf("ttl = %v, want > 0", ttl)
}
}
// TestParseRealtimeEphemeralKeyMapping_NestedFallback checks that a body with
// no top-level secret fields is still parsed from its nested "client_secret"
// object, yielding that object's token with a positive TTL.
func TestParseRealtimeEphemeralKeyMapping_NestedFallback(t *testing.T) {
t.Parallel()
token, ttl, ok := parseRealtimeEphemeralKeyMapping([]byte(`{
"client_secret": {
"value": "ek_test_nested",
"expires_at": 4102444800
}
}`))
if !ok {
t.Fatal("expected nested ephemeral mapping to be parsed")
}
if token != "ek_test_nested" {
t.Fatalf("token = %q, want %q", token, "ek_test_nested")
}
if ttl <= 0 {
t.Fatalf("ttl = %v, want > 0", ttl)
}
}
// TestCacheRealtimeEphemeralKeyMappingStoresKeyID caches a mapping for an
// unexpired token in a fresh kvstore and checks that the stored value is a
// []byte JSON payload carrying the key ID and virtual key that were passed in.
func TestCacheRealtimeEphemeralKeyMappingStoresKeyID(t *testing.T) {
t.Parallel()
store, err := kvstore.New(kvstore.Config{})
if err != nil {
t.Fatalf("kvstore.New() error = %v", err)
}
defer store.Close()
body := []byte(`{
"value": "ek_test_456",
"expires_at": ` + "4102444800" + `
}`)
cacheRealtimeEphemeralKeyMapping(store, body, "key_123", "sk-bf-test")
// The entry is looked up under the same key the cache helper derives from the token.
raw, err := store.Get(buildRealtimeEphemeralKeyMappingKey("ek_test_456"))
if err != nil {
t.Fatalf("store.Get() error = %v", err)
}
value, ok := raw.([]byte)
if !ok {
t.Fatalf("cached value type = %T, want []byte", raw)
}
var mapping realtimeEphemeralKeyMapping
if err := json.Unmarshal(value, &mapping); err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
if mapping.KeyID != "key_123" {
t.Fatalf("mapping.KeyID = %q, want %q", mapping.KeyID, "key_123")
}
if mapping.VirtualKey != "sk-bf-test" {
t.Fatalf("mapping.VirtualKey = %q, want %q", mapping.VirtualKey, "sk-bf-test")
}
}
// TestCacheRealtimeEphemeralKeyMappingSkipsExpiredSecrets feeds the cache a
// payload whose expires_at is one minute in the past and expects the later
// lookup to fail, i.e. nothing was stored for the expired token.
func TestCacheRealtimeEphemeralKeyMappingSkipsExpiredSecrets(t *testing.T) {
	t.Parallel()

	kv, newErr := kvstore.New(kvstore.Config{})
	if newErr != nil {
		t.Fatalf("kvstore.New() error = %v", newErr)
	}
	defer kv.Close()

	expiredAt := time.Now().Add(-time.Minute).Unix()
	payload := fmt.Appendf(nil, `{
"value": "ek_expired",
"expires_at": %d
}`, expiredAt)
	cacheRealtimeEphemeralKeyMapping(kv, payload, "key_123", "")

	// A nil error from Get would mean an entry exists, which is the failure.
	_, lookupErr := kv.Get(buildRealtimeEphemeralKeyMappingKey("ek_expired"))
	if lookupErr == nil {
		t.Fatal("expected no cached mapping for expired token")
	}
}
// TestIsJSONContentType checks isJSONContentType against a plain JSON media
// type with parameters, a "+json" structured suffix, and a non-JSON type.
func TestIsJSONContentType(t *testing.T) {
	t.Parallel()

	cases := []struct {
		contentType string
		want        bool
		failure     string
	}{
		{"application/json; charset=utf-8", true, "expected application/json content type to pass"},
		{"application/vnd.openai+json", true, "expected +json content type to pass"},
		{"text/plain", false, "expected text/plain content type to fail"},
	}
	for _, tc := range cases {
		if isJSONContentType(tc.contentType) != tc.want {
			t.Fatal(tc.failure)
		}
	}
}
// mockRealtimeMintingGovernancePlugin is a test double for the governance
// plugin. It records what EvaluateGovernanceRequest was called with and can be
// configured to reject every evaluation with a fixed error.
type mockRealtimeMintingGovernancePlugin struct {
// err, when non-nil, is returned by EvaluateGovernanceRequest instead of an allow decision.
err *schemas.BifrostError
// seen* fields hold the values observed on the most recent evaluation call.
seenUserID string
seenVirtualKey string
seenProvider schemas.ModelProvider
seenModel string
// evaluateCalls counts how many times EvaluateGovernanceRequest was invoked.
evaluateCalls int
}
// GetName returns governance.PluginName.
// NOTE(review): presumably this lets the handler find the mock where it looks
// up the governance plugin by name — confirm against the lookup code.
func (m *mockRealtimeMintingGovernancePlugin) GetName() string {
return governance.PluginName
}
// EvaluateGovernanceRequest records the call and the values it carried, then
// returns either the configured error or an allow decision.
//
// The recorded fields are reset on every call so they only ever describe the
// latest evaluation. When the request carries no virtual key, the one stored in
// ctx (if any) is recorded instead.
func (m *mockRealtimeMintingGovernancePlugin) EvaluateGovernanceRequest(ctx *schemas.BifrostContext, evaluationRequest *governance.EvaluationRequest, _ schemas.RequestType) (*governance.EvaluationResult, *schemas.BifrostError) {
	m.evaluateCalls++

	// Clear whatever a previous call left behind.
	m.seenUserID, m.seenVirtualKey, m.seenProvider, m.seenModel = "", "", "", ""

	if req := evaluationRequest; req != nil {
		m.seenUserID = req.UserID
		m.seenVirtualKey = req.VirtualKey
		m.seenProvider = req.Provider
		m.seenModel = req.Model
	}

	// Fall back to the context only when the request supplied no virtual key.
	if m.seenVirtualKey == "" && ctx != nil {
		m.seenVirtualKey = bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey)
	}

	if m.err == nil {
		return &governance.EvaluationResult{Decision: governance.DecisionAllow}, nil
	}
	return nil, m.err
}
// HTTPTransportPreHook is a no-op: it returns no response and no error.
func (m *mockRealtimeMintingGovernancePlugin) HTTPTransportPreHook(_ *schemas.BifrostContext, _ *schemas.HTTPRequest) (*schemas.HTTPResponse, error) {
return nil, nil
}
// HTTPTransportPostHook is a no-op that always returns nil.
func (m *mockRealtimeMintingGovernancePlugin) HTTPTransportPostHook(_ *schemas.BifrostContext, _ *schemas.HTTPRequest, _ *schemas.HTTPResponse) error {
return nil
}
// PreLLMHook passes the request through unchanged, with no short-circuit and no error.
func (m *mockRealtimeMintingGovernancePlugin) PreLLMHook(_ *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) {
return req, nil, nil
}
// PostLLMHook returns the given result and Bifrost error untouched, with a nil hook error.
func (m *mockRealtimeMintingGovernancePlugin) PostLLMHook(_ *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) {
return result, bifrostErr, nil
}
// PreMCPHook passes the MCP request through unchanged, with no short-circuit and no error.
func (m *mockRealtimeMintingGovernancePlugin) PreMCPHook(_ *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, error) {
return req, nil, nil
}
// PostMCPHook returns the given MCP response and Bifrost error untouched, with a nil hook error.
func (m *mockRealtimeMintingGovernancePlugin) PostMCPHook(_ *schemas.BifrostContext, resp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostMCPResponse, *schemas.BifrostError, error) {
return resp, bifrostErr, nil
}
// Cleanup is a no-op that always returns nil; the mock holds no resources.
func (m *mockRealtimeMintingGovernancePlugin) Cleanup() error {
return nil
}
// GetGovernanceStore always returns nil; the mock has no backing store.
func (m *mockRealtimeMintingGovernancePlugin) GetGovernanceStore() governance.GovernanceStore {
return nil
}
func TestRealtimeClientSecretsEvaluateMintingGovernance_RequiresAccess(t *testing.T) {
t.Parallel()
config := &lib.Config{}
plugin := &mockRealtimeMintingGovernancePlugin{
err: &schemas.BifrostError{
Type: schemas.Ptr("virtual_key_required"),
StatusCode: schemas.Ptr(401),
Error: &schemas.ErrorField{
Message: "virtual key is required. Provide a virtual key via the x-bf-vk header.",
},
},
}
plugins := []schemas.BasePlugin{plugin}
config.BasePlugins.Store(&plugins)
handler := NewRealtimeClientSecretsHandler(nil, config)
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
defer bifrostCtx.Done()
err := handler.evaluateMintingGovernance(bifrostCtx, schemas.OpenAI, "gpt-realtime")
if err == nil {
t.Fatal("expected governance error")
}
if err.StatusCode == nil {
t.Fatal("expected status code")
}
if got, want := *err.StatusCode, fasthttp.StatusUnauthorized; got != want {
t.Fatalf("status = %d, want %d", got, want)
}
}
func TestRealtimeClientSecretsEvaluateMintingGovernance_PassesContext(t *testing.T) {
t.Parallel()
config := &lib.Config{}
plugin := &mockRealtimeMintingGovernancePlugin{}
plugins := []schemas.BasePlugin{
plugin,
}
config.BasePlugins.Store(&plugins)
handler := NewRealtimeClientSecretsHandler(nil, config)
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
defer bifrostCtx.Done()
bifrostCtx.SetValue(schemas.BifrostContextKeyUserID, "user_123")
bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, "sk-bf-123")
if err := handler.evaluateMintingGovernance(bifrostCtx, schemas.OpenAI, "gpt-realtime"); err != nil {
t.Fatalf("unexpected governance error: %v", err)
}
if plugin.evaluateCalls != 1 {
t.Fatalf("evaluate calls = %d, want 1", plugin.evaluateCalls)
}
if plugin.seenUserID != "user_123" {
t.Fatalf("governance user id = %q, want %q", plugin.seenUserID, "user_123")
}
if plugin.seenVirtualKey != "sk-bf-123" {
t.Fatalf("virtual key = %q, want %q", plugin.seenVirtualKey, "sk-bf-123")
}
if plugin.seenProvider != schemas.OpenAI {
t.Fatalf("provider = %q, want %q", plugin.seenProvider, schemas.OpenAI)
}
if plugin.seenModel != "gpt-realtime" {
t.Fatalf("model = %q, want %q", plugin.seenModel, "gpt-realtime")
}
}
// TestRealtimeClientSecretsEvaluateMintingGovernance_ContinuesWithoutGovernance
// checks that evaluateMintingGovernance returns no error when the config has
// no plugins registered at all.
func TestRealtimeClientSecretsEvaluateMintingGovernance_ContinuesWithoutGovernance(t *testing.T) {
	t.Parallel()

	h := NewRealtimeClientSecretsHandler(nil, &lib.Config{})

	bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
	// NOTE(review): if Done() mirrors context.Context.Done it only returns a
	// channel, so deferring it releases nothing — confirm whether a cancel
	// call was intended here.
	defer bifrostCtx.Done()

	govErr := h.evaluateMintingGovernance(bifrostCtx, schemas.OpenAI, "gpt-realtime")
	if govErr != nil {
		t.Fatalf("unexpected governance error without plugin: %v", govErr)
	}
}

View File

@@ -0,0 +1,441 @@
package handlers
import (
	"encoding/json"
	"sort"
	"strings"

	"github.com/bytedance/sonic"
	"github.com/maximhq/bifrost/core/schemas"
	bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
)
// realtimeTurnSource is a string tag identifying where a realtime turn came from.
type realtimeTurnSource string
// Known turn sources.
// NOTE(review): the meaning of the "ei" and "lm" abbreviations is not visible
// in this file — document them once confirmed.
const (
realtimeTurnSourceEI realtimeTurnSource = "ei"
realtimeTurnSourceLM realtimeTurnSource = "lm"
)
const (
// realtimeMissingTranscriptText is the placeholder summary used for user
// audio input that finished without a usable transcript.
realtimeMissingTranscriptText = "[Audio transcription unavailable]"
)
// extractRealtimeTurnSummary picks a short text summary for a realtime event.
//
// A non-blank contentOverride always wins. Otherwise the first non-blank
// candidate is returned, in this order: the event's error message, the delta
// text, the delta transcript, the item summary, the session instructions, and
// finally the raw event payload. Every result is whitespace-trimmed; an empty
// string means nothing usable was found.
func extractRealtimeTurnSummary(event *schemas.BifrostRealtimeEvent, contentOverride string) string {
	if override := strings.TrimSpace(contentOverride); override != "" {
		return override
	}
	if event == nil {
		return ""
	}

	if event.Error != nil {
		if message := strings.TrimSpace(event.Error.Message); message != "" {
			return message
		}
	}

	if delta := event.Delta; delta != nil {
		// Text is preferred over the transcript when both are present.
		for _, candidate := range [...]string{delta.Text, delta.Transcript} {
			if trimmed := strings.TrimSpace(candidate); trimmed != "" {
				return trimmed
			}
		}
	}

	if event.Item != nil {
		if fromItem := extractRealtimeItemSummary(event.Item); fromItem != "" {
			return fromItem
		}
	}

	if event.Session != nil {
		if instructions := strings.TrimSpace(event.Session.Instructions); instructions != "" {
			return instructions
		}
	}

	// Last resort: fall back to the raw payload itself.
	if len(event.RawData) == 0 {
		return ""
	}
	return strings.TrimSpace(string(event.RawData))
}
// extractRealtimeItemSummary summarizes a realtime item. Text gathered from the
// item's content is preferred; otherwise the first non-blank value among
// Output, Arguments and Name (in that order) is returned, trimmed. A nil item
// or an item with nothing usable yields "".
func extractRealtimeItemSummary(item *schemas.RealtimeItem) string {
	if item == nil {
		return ""
	}
	if fromContent := extractRealtimeContentSummary(item.Content); fromContent != "" {
		return fromContent
	}
	for _, candidate := range [...]string{item.Output, item.Arguments, item.Name} {
		if trimmed := strings.TrimSpace(candidate); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
// extractRealtimeContentSummary turns a raw JSON content payload into a single
// space-joined string of the text fragments found inside it. Empty input gives
// "". Input that is not valid JSON is returned as-is, trimmed.
func extractRealtimeContentSummary(raw []byte) string {
	if len(raw) == 0 {
		return ""
	}

	var tree any
	if decodeErr := sonic.Unmarshal(raw, &tree); decodeErr != nil {
		// Not JSON: treat the payload itself as the summary.
		return strings.TrimSpace(string(raw))
	}

	var fragments []string
	collectRealtimeTextFragments(tree, &fragments)
	return strings.Join(fragments, " ")
}
// collectRealtimeTextFragments walks a decoded JSON value (maps, slices and
// scalars as produced by unmarshalling into `any`) and appends every non-blank
// text fragment it finds to *parts.
//
// A fragment is a string stored under one of the keys "text", "transcript",
// "input_text", "output_text", "output" or "arguments"; it is trimmed before
// being appended. When one of those keys holds a non-string value, that value
// is descended into like any other field. Slice elements are visited in order.
//
// Map keys are visited in sorted order. Go randomizes map iteration, so without
// sorting, an object holding several text-bearing keys would yield its
// fragments in a different order from one call to the next.
func collectRealtimeTextFragments(value any, parts *[]string) {
	switch v := value.(type) {
	case map[string]any:
		keys := make([]string, 0, len(v))
		for key := range v {
			keys = append(keys, key)
		}
		sort.Strings(keys)

		for _, key := range keys {
			field := v[key]
			switch key {
			case "text", "transcript", "input_text", "output_text", "output", "arguments":
				if text, ok := field.(string); ok {
					if text = strings.TrimSpace(text); text != "" {
						*parts = append(*parts, text)
					}
					// A string under a text key is a leaf; nothing to descend into.
					continue
				}
			}
			collectRealtimeTextFragments(field, parts)
		}
	case []any:
		for _, item := range v {
			collectRealtimeTextFragments(item, parts)
		}
	}
}
// finalizedRealtimeInputSummary returns the final user-input summary carried by
// an event, or "" when the event does not represent user input.
//
//   - Input-audio-transcription-completed events yield their "transcript" extra
//     param, or the missing-transcript placeholder when it is blank or absent.
//   - Other user-input events yield the item summary. For a "conversation item
//     done" event whose item has no summary but contains untranscribed input
//     audio, the missing-transcript placeholder is returned instead.
func finalizedRealtimeInputSummary(event *schemas.BifrostRealtimeEvent) string {
	if event == nil {
		return ""
	}

	if event.Type == schemas.RTEventInputAudioTransCompleted {
		if transcript := extractRealtimeExtraParamString(event, "transcript"); transcript != "" {
			return transcript
		}
		return realtimeMissingTranscriptText
	}

	if !schemas.IsRealtimeUserInputEvent(event) {
		return ""
	}

	summary := extractRealtimeItemSummary(event.Item)
	// Only a completed item gets the placeholder; earlier events for the same
	// item may still be followed by a transcript.
	if summary == "" &&
		event.Type == schemas.RTEventConversationItemDone &&
		realtimeItemHasMissingAudioTranscript(event.Item) {
		return realtimeMissingTranscriptText
	}
	return summary
}
// pendingRealtimeInputUpdate returns the (item ID, summary) pair to record for
// a pending user input, or two empty strings when the event carries none.
//
// Retrieved-item events are always ignored. Transcription-completed events and
// user-input events produce the event's item ID and finalized input summary.
func pendingRealtimeInputUpdate(event *schemas.BifrostRealtimeEvent) (string, string) {
	if event == nil || event.Type == schemas.RTEventConversationItemRetrieved {
		return "", ""
	}
	if event.Type == schemas.RTEventInputAudioTransCompleted || schemas.IsRealtimeUserInputEvent(event) {
		return realtimeEventItemID(event), finalizedRealtimeInputSummary(event)
	}
	return "", ""
}
// realtimeItemHasMissingAudioTranscript reports whether the item's content has
// at least one "input_audio" part whose transcript is absent, null, or a blank
// string. It returns false for a nil item, empty content, or content that does
// not decode as a JSON array of objects.
func realtimeItemHasMissingAudioTranscript(item *schemas.RealtimeItem) bool {
	if item == nil || len(item.Content) == 0 {
		return false
	}

	var contentParts []map[string]any
	if decodeErr := sonic.Unmarshal(item.Content, &contentParts); decodeErr != nil {
		return false
	}

	for _, part := range contentParts {
		if kind, _ := part["type"].(string); kind != "input_audio" {
			continue
		}
		switch transcript := part["transcript"].(type) {
		case nil:
			// Covers both a missing key and an explicit JSON null.
			return true
		case string:
			if strings.TrimSpace(transcript) == "" {
				return true
			}
		}
	}
	return false
}
// finalizedRealtimeToolOutputSummary returns the item summary of a tool-output
// event, and "" for any event that schemas.IsRealtimeToolOutputEvent rejects.
func finalizedRealtimeToolOutputSummary(event *schemas.BifrostRealtimeEvent) string {
	if schemas.IsRealtimeToolOutputEvent(event) {
		return extractRealtimeItemSummary(event.Item)
	}
	return ""
}
// pendingRealtimeToolOutputUpdate returns the (item ID, summary) pair to record
// for a pending tool output. It yields two empty strings for a nil event, a
// retrieved-item event, or anything that is not a tool-output event.
func pendingRealtimeToolOutputUpdate(event *schemas.BifrostRealtimeEvent) (string, string) {
	// Case expressions are evaluated left to right, so the nil check guards
	// the field access that follows it.
	switch {
	case event == nil,
		event.Type == schemas.RTEventConversationItemRetrieved,
		!schemas.IsRealtimeToolOutputEvent(event):
		return "", ""
	}
	return realtimeEventItemID(event), finalizedRealtimeToolOutputSummary(event)
}
// extractRealtimeExtraParamString reads event.ExtraParams[key] as a JSON string
// and returns it trimmed. It returns "" when the event is nil, the key is
// missing or empty, or the stored value does not decode into a string.
func extractRealtimeExtraParamString(event *schemas.BifrostRealtimeEvent, key string) string {
	if event == nil {
		return ""
	}

	// Indexing a nil map is safe and yields an empty value, so a nil
	// ExtraParams and a missing key are handled by the same length check.
	encoded := event.ExtraParams[key]
	if len(encoded) == 0 {
		return ""
	}

	var decoded string
	if decodeErr := json.Unmarshal(encoded, &decoded); decodeErr != nil {
		return ""
	}
	return strings.TrimSpace(decoded)
}
// realtimeEventItemID resolves the conversation item ID an event refers to.
// It tries, in order, the item's own ID, the delta's item ID, and the "item_id"
// extra param, returning the first non-blank one trimmed ("" if none).
func realtimeEventItemID(event *schemas.BifrostRealtimeEvent) string {
	if event == nil {
		return ""
	}
	if item := event.Item; item != nil {
		if id := strings.TrimSpace(item.ID); id != "" {
			return id
		}
	}
	if delta := event.Delta; delta != nil {
		if id := strings.TrimSpace(delta.ItemID); id != "" {
			return id
		}
	}
	return extractRealtimeExtraParamString(event, "item_id")
}
// combineRealtimeInputRaw concatenates the trimmed Raw payload of every turn
// input, skipping blank ones, with a blank line between consecutive entries.
func combineRealtimeInputRaw(turnInputs []bfws.RealtimeTurnInput) string {
	var rawPayloads []string
	for i := range turnInputs {
		payload := strings.TrimSpace(turnInputs[i].Raw)
		if payload == "" {
			continue
		}
		rawPayloads = append(rawPayloads, payload)
	}
	return strings.Join(rawPayloads, "\n\n")
}
// realtimeResponseDoneEnvelope is the subset of a realtime "response.done"
// payload this file decodes: the response's output items and its token usage.
type realtimeResponseDoneEnvelope struct {
Response struct {
Output []realtimeResponseDoneOutput `json:"output"`
// Usage is a pointer so a payload without usage can be told apart from zero counts.
Usage *realtimeResponseDoneUsage `json:"usage"`
} `json:"response"`
}
// realtimeResponseDoneOutput is one entry of response.output. In this file,
// entries with Type "message" supply assistant text through Content, and
// entries with Type "function_call" supply tool calls through Name, CallID
// (falling back to ID) and Arguments.
type realtimeResponseDoneOutput struct {
ID string `json:"id"`
Type string `json:"type"`
Name string `json:"name"`
CallID string `json:"call_id"`
Arguments string `json:"arguments"`
Content []realtimeResponseDoneContent `json:"content"`
}
// realtimeResponseDoneContent is one content block of a message output item.
// Text, Transcript and Refusal are the fields read when building the assistant
// text for a turn.
type realtimeResponseDoneContent struct {
Type string `json:"type"`
Text string `json:"text"`
Transcript string `json:"transcript"`
Refusal string `json:"refusal"`
}
// realtimeResponseDoneUsage mirrors response.usage: overall token counts plus
// optional per-modality breakdowns for input and output.
type realtimeResponseDoneUsage struct {
TotalTokens int `json:"total_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
// The detail blocks are pointers; they stay nil when the payload omits them.
InputTokenDetails *realtimeResponseDoneInputTokenUsage `json:"input_token_details"`
OutputTokenDetails *realtimeResponseDoneOutputTokenUsage `json:"output_token_details"`
}
// realtimeResponseDoneInputTokenUsage mirrors usage.input_token_details: the
// input token counts split by text, audio, image and cached tokens.
type realtimeResponseDoneInputTokenUsage struct {
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
ImageTokens int `json:"image_tokens"`
CachedTokens int `json:"cached_tokens"`
}
// realtimeResponseDoneOutputTokenUsage mirrors usage.output_token_details: the
// output token counts split by category.
type realtimeResponseDoneOutputTokenUsage struct {
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
ReasoningTokens int `json:"reasoning_tokens"`
// Pointer-typed counts stay nil when the payload omits the field.
ImageTokens *int `json:"image_tokens"`
CitationTokens *int `json:"citation_tokens"`
NumSearchQueries *int `json:"num_search_queries"`
AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
}
// extractRealtimeTurnUsage returns the token usage for a realtime turn. A
// provider that implements schemas.RealtimeUsageExtractor is asked first; when
// it is not an extractor, or returns nil, the usage is parsed from rawMessage
// as a generic response.done payload.
func extractRealtimeTurnUsage(provider schemas.RealtimeProvider, rawMessage []byte) *schemas.BifrostLLMUsage {
	extractor, isExtractor := provider.(schemas.RealtimeUsageExtractor)
	if isExtractor {
		if fromProvider := extractor.ExtractRealtimeTurnUsage(rawMessage); fromProvider != nil {
			return fromProvider
		}
	}
	return extractRealtimeResponseDoneUsage(rawMessage)
}
// extractRealtimeTurnOutputMessage builds the assistant message for a realtime
// turn. A provider that implements schemas.RealtimeUsageExtractor is asked
// first; if its message has no non-blank string content and contentSummary is
// non-blank, the trimmed summary is used as the content. When the provider is
// not an extractor, or returns nil, the message is built from rawMessage and
// contentSummary by buildRealtimeAssistantLogMessage.
func extractRealtimeTurnOutputMessage(provider schemas.RealtimeProvider, rawMessage []byte, contentSummary string) *schemas.ChatMessage {
	extractor, isExtractor := provider.(schemas.RealtimeUsageExtractor)
	if !isExtractor {
		return buildRealtimeAssistantLogMessage(rawMessage, contentSummary)
	}

	message := extractor.ExtractRealtimeTurnOutput(rawMessage)
	if message == nil {
		return buildRealtimeAssistantLogMessage(rawMessage, contentSummary)
	}

	hasText := message.Content != nil &&
		message.Content.ContentStr != nil &&
		strings.TrimSpace(*message.Content.ContentStr) != ""
	if fallback := strings.TrimSpace(contentSummary); fallback != "" && !hasText {
		message.Content = &schemas.ChatMessageContent{ContentStr: schemas.Ptr(fallback)}
	}
	return message
}
// buildRealtimeAssistantLogMessage assembles an assistant chat message from a
// raw response.done payload and an optional content summary.
//
// When rawMessage decodes, the message text is the trimmed contentSummary or,
// if that is blank, the text extracted from the response's message outputs;
// any function-call outputs are attached as tool calls. When rawMessage is
// empty or does not decode, only the summary is used. The result is nil when
// there is neither text nor a tool call.
func buildRealtimeAssistantLogMessage(rawMessage []byte, contentSummary string) *schemas.ChatMessage {
	summary := strings.TrimSpace(contentSummary)

	var envelope realtimeResponseDoneEnvelope
	decoded := len(rawMessage) > 0 && sonic.Unmarshal(rawMessage, &envelope) == nil
	if decoded {
		outputs := envelope.Response.Output
		if summary == "" {
			summary = extractRealtimeResponseDoneAssistantText(outputs)
		}

		assistant := &schemas.ChatMessage{Role: schemas.ChatMessageRoleAssistant}
		if summary != "" {
			assistant.Content = &schemas.ChatMessageContent{ContentStr: schemas.Ptr(summary)}
		}
		if calls := extractRealtimeResponseDoneToolCalls(outputs); len(calls) > 0 {
			assistant.ChatAssistantMessage = &schemas.ChatAssistantMessage{
				ToolCalls: calls,
			}
		}
		if assistant.Content != nil || assistant.ChatAssistantMessage != nil {
			return assistant
		}
	}

	// Reached when the payload did not decode, or decoded to nothing useful.
	if summary == "" {
		return nil
	}
	return &schemas.ChatMessage{
		Role:    schemas.ChatMessageRoleAssistant,
		Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr(summary)},
	}
}
// extractRealtimeResponseDoneAssistantText joins, with single spaces, one text
// fragment per content block of every "message" output. For each block the
// first non-blank value among Text, Transcript and Refusal is used, trimmed;
// blocks with none of them contribute nothing.
func extractRealtimeResponseDoneAssistantText(outputs []realtimeResponseDoneOutput) string {
	var fragments []string
	for i := range outputs {
		if outputs[i].Type != "message" {
			continue
		}
		for _, block := range outputs[i].Content {
			for _, candidate := range [...]string{block.Text, block.Transcript, block.Refusal} {
				if trimmed := strings.TrimSpace(candidate); trimmed != "" {
					fragments = append(fragments, trimmed)
					break // at most one fragment per block
				}
			}
		}
	}
	return strings.Join(fragments, " ")
}
// extractRealtimeResponseDoneToolCalls converts every "function_call" output
// with a non-blank name into a chat tool call. The call ID is CallID, falling
// back to ID, and is left unset when both are blank. Index counts only the
// calls actually emitted. The returned slice is never nil.
func extractRealtimeResponseDoneToolCalls(outputs []realtimeResponseDoneOutput) []schemas.ChatAssistantMessageToolCall {
	calls := make([]schemas.ChatAssistantMessageToolCall, 0)
	for i := range outputs {
		out := &outputs[i]
		if out.Type != "function_call" {
			continue
		}
		fnName := strings.TrimSpace(out.Name)
		if fnName == "" {
			continue
		}

		callID := strings.TrimSpace(out.CallID)
		if callID == "" {
			callID = strings.TrimSpace(out.ID)
		}

		// Declared per iteration so each call points at its own string.
		kind := "function"
		call := schemas.ChatAssistantMessageToolCall{
			Index: uint16(len(calls)),
			Type:  &kind,
			Function: schemas.ChatAssistantMessageToolCallFunction{
				Name:      schemas.Ptr(fnName),
				Arguments: out.Arguments,
			},
		}
		if callID != "" {
			call.ID = schemas.Ptr(callID)
		}
		calls = append(calls, call)
	}
	return calls
}
// extractRealtimeResponseDoneUsage parses the usage section of a response.done
// payload into a BifrostLLMUsage. It returns nil when rawMessage is empty, does
// not decode, or carries no usage. A zero total_tokens is replaced by the sum
// of input and output tokens when either is positive. Token detail blocks are
// copied over only when present in the payload.
func extractRealtimeResponseDoneUsage(rawMessage []byte) *schemas.BifrostLLMUsage {
	if len(rawMessage) == 0 {
		return nil
	}

	var envelope realtimeResponseDoneEnvelope
	if decodeErr := sonic.Unmarshal(rawMessage, &envelope); decodeErr != nil {
		return nil
	}
	reported := envelope.Response.Usage
	if reported == nil {
		return nil
	}

	total := reported.TotalTokens
	if total == 0 && (reported.InputTokens > 0 || reported.OutputTokens > 0) {
		// The payload omitted the total; derive it from its parts.
		total = reported.InputTokens + reported.OutputTokens
	}

	result := &schemas.BifrostLLMUsage{
		PromptTokens:     reported.InputTokens,
		CompletionTokens: reported.OutputTokens,
		TotalTokens:      total,
	}

	if in := reported.InputTokenDetails; in != nil {
		result.PromptTokensDetails = &schemas.ChatPromptTokensDetails{
			TextTokens:       in.TextTokens,
			AudioTokens:      in.AudioTokens,
			ImageTokens:      in.ImageTokens,
			CachedReadTokens: in.CachedTokens,
		}
	}

	if out := reported.OutputTokenDetails; out != nil {
		result.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{
			TextTokens:               out.TextTokens,
			AudioTokens:              out.AudioTokens,
			ReasoningTokens:          out.ReasoningTokens,
			ImageTokens:              out.ImageTokens,
			CitationTokens:           out.CitationTokens,
			NumSearchQueries:         out.NumSearchQueries,
			AcceptedPredictionTokens: out.AcceptedPredictionTokens,
			RejectedPredictionTokens: out.RejectedPredictionTokens,
		}
	}
	return result
}

View File

@@ -0,0 +1,435 @@
package handlers
import (
"encoding/json"
"testing"
"time"
"github.com/maximhq/bifrost/core/providers/openai"
"github.com/maximhq/bifrost/core/schemas"
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
)
// TestShouldAccumulateRealtimeOutput checks which realtime event types the
// OpenAI provider treats as assistant output to accumulate: response text and
// response audio-transcript deltas do, input audio transcription deltas do not.
func TestShouldAccumulateRealtimeOutput(t *testing.T) {
	accumulates := (&openai.OpenAIProvider{}).ShouldAccumulateRealtimeOutput

	if !accumulates(schemas.RTEventResponseTextDelta) {
		t.Fatal("expected response.text.delta to accumulate output text")
	}
	if !accumulates(schemas.RTEventResponseAudioTransDelta) {
		t.Fatal("expected response.audio_transcript.delta to accumulate output transcript")
	}
	if accumulates(schemas.RTEventInputAudioTransDelta) {
		t.Fatal("did not expect input audio transcription delta to accumulate assistant output")
	}
}
// TestExtractRealtimeTurnSummary checks that, with no override, the summary of
// an event is taken from the text inside its item content.
func TestExtractRealtimeTurnSummary(t *testing.T) {
	const want = "hello from realtime"

	item := &schemas.RealtimeItem{
		Content: []byte(`[{"type":"input_text","text":"hello from realtime"}]`),
	}
	got := extractRealtimeTurnSummary(&schemas.BifrostRealtimeEvent{
		Type: schemas.RTEventConversationItemCreate,
		Item: item,
	}, "")

	if got != want {
		t.Fatalf("extractRealtimeTurnSummary() = %q, want %q", got, want)
	}
}
func TestFinalizedRealtimeInputSummary(t *testing.T) {
userCreate := &schemas.BifrostRealtimeEvent{
Type: schemas.RTEventConversationItemCreate,
Item: &schemas.RealtimeItem{
Role: "user",
Content: []byte(`[{"type":"input_text","text":"hello from browser"}]`),
},
}
if got := finalizedRealtimeInputSummary(userCreate); got != "hello from browser" {
t.Fatalf("finalizedRealtimeInputSummary(user create) = %q, want %q", got, "hello from browser")
}
userRetrieved := &schemas.BifrostRealtimeEvent{
Type: schemas.RTEventConversationItemRetrieved,
Item: &schemas.RealtimeItem{
Role: "user",
Content: []byte(`[{"type":"input_text","text":"hello from retrieved item"}]`),
},
}
if got := finalizedRealtimeInputSummary(userRetrieved); got != "hello from retrieved item" {
t.Fatalf("finalizedRealtimeInputSummary(user retrieved) = %q, want %q", got, "hello from retrieved item")
}
userCreated := &schemas.BifrostRealtimeEvent{
Type: schemas.RTEventConversationItemCreated,
Item: &schemas.RealtimeItem{
Role: "user",
Content: []byte(`[{"type":"input_text","text":"hello from provider created item"}]`),
},
}
if got := finalizedRealtimeInputSummary(userCreated); got != "hello from provider created item" {
t.Fatalf("finalizedRealtimeInputSummary(user created) = %q, want %q", got, "hello from provider created item")
}
userAdded := &schemas.BifrostRealtimeEvent{
Type: schemas.RTEventConversationItemAdded,
Item: &schemas.RealtimeItem{
Role: "user",
Content: []byte(`[{"type":"input_text","text":"hello from provider added item"}]`),
},
}
if got := finalizedRealtimeInputSummary(userAdded); got != "hello from provider added item" {
t.Fatalf("finalizedRealtimeInputSummary(user added) = %q, want %q", got, "hello from provider added item")
}
userCreatedWithoutTranscript := &schemas.BifrostRealtimeEvent{
Type: schemas.RTEventConversationItemCreated,
Item: &schemas.RealtimeItem{
Role: "user",
Type: "message",
Content: []byte(`[{"type":"input_audio","audio":null,"transcript":null}]`),
},
RawData: []byte(`{"type":"conversation.item.created","item":{"type":"message","role":"user","content":[{"type":"input_audio","audio":null,"transcript":null}]}}`),
}
if got := finalizedRealtimeInputSummary(userCreatedWithoutTranscript); got != "" {
t.Fatalf("finalizedRealtimeInputSummary(user created without transcript) = %q, want empty", got)
}
userDoneWithoutTranscript := &schemas.BifrostRealtimeEvent{
Type: schemas.RTEventConversationItemDone,
Item: &schemas.RealtimeItem{
Role: "user",
Type: "message",
Status: "completed",
Content: []byte(`[{"type":"input_audio","audio":null,"transcript":null}]`),
},
RawData: []byte(`{"type":"conversation.item.done","item":{"type":"message","role":"user","status":"completed","content":[{"type":"input_audio","audio":null,"transcript":null}]}}`),
}
if got := finalizedRealtimeInputSummary(userDoneWithoutTranscript); got != realtimeMissingTranscriptText {
t.Fatalf("finalizedRealtimeInputSummary(user done without transcript) = %q, want %q", got, realtimeMissingTranscriptText)
}
inputTranscript := &schemas.BifrostRealtimeEvent{
Type: schemas.RTEventInputAudioTransCompleted,
ExtraParams: map[string]json.RawMessage{
"transcript": json.RawMessage(`"spoken user turn"`),
},
}
if got := finalizedRealtimeInputSummary(inputTranscript); got != "spoken user turn" {
t.Fatalf("finalizedRealtimeInputSummary(input transcript) = %q, want %q", got, "spoken user turn")
}
emptyInputTranscript := &schemas.BifrostRealtimeEvent{
Type: schemas.RTEventInputAudioTransCompleted,
ExtraParams: map[string]json.RawMessage{
"transcript": json.RawMessage(`""`),
},
RawData: []byte(`{"type":"conversation.item.input_audio_transcription.completed","transcript":"","usage":{"total_tokens":11}}`),
}
if got := finalizedRealtimeInputSummary(emptyInputTranscript); got != realtimeMissingTranscriptText {
t.Fatalf("finalizedRealtimeInputSummary(empty input transcript) = %q, want %q", got, realtimeMissingTranscriptText)
}
missingInputTranscript := &schemas.BifrostRealtimeEvent{
Type: schemas.RTEventInputAudioTransCompleted,
RawData: []byte(`{"type":"conversation.item.input_audio_transcription.completed","usage":{"total_tokens":11}}`),
}
if got := finalizedRealtimeInputSummary(missingInputTranscript); got != realtimeMissingTranscriptText {
t.Fatalf("finalizedRealtimeInputSummary(missing input transcript) = %q, want %q", got, realtimeMissingTranscriptText)
}
assistantCreate := &schemas.BifrostRealtimeEvent{
Type: schemas.RTEventConversationItemCreate,
Item: &schemas.RealtimeItem{
Role: "assistant",
Content: []byte(`[{"type":"text","text":"assistant text"}]`),
},
}
if got := finalizedRealtimeInputSummary(assistantCreate); got != "" {
t.Fatalf("finalizedRealtimeInputSummary(assistant create) = %q, want empty", got)
}
}
// TestFinalizedRealtimeToolOutputSummary checks that a function_call_output
// item's Output is returned verbatim as the summary for create, retrieved,
// created and added events alike.
func TestFinalizedRealtimeToolOutputSummary(t *testing.T) {
	toolOutputItem := func(output string) *schemas.RealtimeItem {
		return &schemas.RealtimeItem{
			Type:   "function_call_output",
			Output: output,
		}
	}

	const wantCreate = `{"nextResponse":"tool result"}`
	got := finalizedRealtimeToolOutputSummary(&schemas.BifrostRealtimeEvent{
		Type: schemas.RTEventConversationItemCreate,
		Item: toolOutputItem(wantCreate),
	})
	if got != wantCreate {
		t.Fatalf("finalizedRealtimeToolOutputSummary() = %q, want %q", got, wantCreate)
	}

	const wantRetrieved = `{"nextResponse":"tool result from retrieved"}`
	got = finalizedRealtimeToolOutputSummary(&schemas.BifrostRealtimeEvent{
		Type: schemas.RTEventConversationItemRetrieved,
		Item: toolOutputItem(wantRetrieved),
	})
	if got != wantRetrieved {
		t.Fatalf("finalizedRealtimeToolOutputSummary(retrieved) = %q, want %q", got, wantRetrieved)
	}

	const wantCreated = `{"nextResponse":"tool result from created"}`
	got = finalizedRealtimeToolOutputSummary(&schemas.BifrostRealtimeEvent{
		Type: schemas.RTEventConversationItemCreated,
		Item: toolOutputItem(wantCreated),
	})
	if got != wantCreated {
		t.Fatalf("finalizedRealtimeToolOutputSummary(created) = %q, want %q", got, wantCreated)
	}

	const wantAdded = `{"nextResponse":"tool result from added"}`
	got = finalizedRealtimeToolOutputSummary(&schemas.BifrostRealtimeEvent{
		Type: schemas.RTEventConversationItemAdded,
		Item: toolOutputItem(wantAdded),
	})
	if got != wantAdded {
		t.Fatalf("finalizedRealtimeToolOutputSummary(added) = %q, want %q", got, wantAdded)
	}
}
func TestPendingRealtimeInputUpdate(t *testing.T) {
t.Parallel()
transcriptEvent := &schemas.BifrostRealtimeEvent{
Type: schemas.RTEventInputAudioTransCompleted,
ExtraParams: map[string]json.RawMessage{
"item_id": json.RawMessage(`"item_123"`),
"transcript": json.RawMessage(`"Hello."`),
},
}
itemID, summary := pendingRealtimeInputUpdate(transcriptEvent)
if itemID != "item_123" || summary != "Hello." {
t.Fatalf("pendingRealtimeInputUpdate(transcript) = (%q, %q), want (%q, %q)", itemID, summary, "item_123", "Hello.")
}
retrievedEvent := &schemas.BifrostRealtimeEvent{
Type: schemas.RTEventConversationItemRetrieved,
Item: &schemas.RealtimeItem{
ID: "item_123",
Role: "user",
Content: []byte(`[{"type":"input_text","text":"historical hello"}]`),
},
}
itemID, summary = pendingRealtimeInputUpdate(retrievedEvent)
if itemID != "" || summary != "" {
t.Fatalf("pendingRealtimeInputUpdate(retrieved) = (%q, %q), want empty", itemID, summary)
}
}
func TestPendingRealtimeToolOutputUpdate(t *testing.T) {
t.Parallel()
toolOutputEvent := &schemas.BifrostRealtimeEvent{
Type: schemas.RTEventConversationItemDone,
Item: &schemas.RealtimeItem{
ID: "item_tool_123",
Type: "function_call_output",
Output: `{"nextResponse":"tool result"}`,
},
}
itemID, summary := pendingRealtimeToolOutputUpdate(toolOutputEvent)
if itemID != "item_tool_123" || summary != `{"nextResponse":"tool result"}` {
t.Fatalf("pendingRealtimeToolOutputUpdate(done) = (%q, %q), want (%q, %q)", itemID, summary, "item_tool_123", `{"nextResponse":"tool result"}`)
}
retrievedToolOutputEvent := &schemas.BifrostRealtimeEvent{
Type: schemas.RTEventConversationItemRetrieved,
Item: &schemas.RealtimeItem{
ID: "item_tool_123",
Type: "function_call_output",
Output: `{"nextResponse":"historical tool result"}`,
},
}
itemID, summary = pendingRealtimeToolOutputUpdate(retrievedToolOutputEvent)
if itemID != "" || summary != "" {
t.Fatalf("pendingRealtimeToolOutputUpdate(retrieved) = (%q, %q), want empty", itemID, summary)
}
}
// TestBuildRealtimeTurnPostResponseUsesFullResponseDonePayload builds a post
// response from a complete response.done payload with no content override and
// checks that latency, token usage, the assistant transcript and the raw
// request/response strings all land on the resulting ResponsesResponse.
func TestBuildRealtimeTurnPostResponseUsesFullResponseDonePayload(t *testing.T) {
	const wantLatency = 4321

	rawRequest := `{"type":"conversation.item.input_audio_transcription.completed","transcript":""}`
	rawResponse := []byte(`{
"type":"response.done",
"response":{
"output":[
{
"id":"item_message_123",
"type":"message",
"content":[
{
"type":"audio",
"transcript":"assistant turn text"
}
]
}
],
"usage":{
"total_tokens":26,
"input_tokens":17,
"output_tokens":9,
"input_token_details":{
"text_tokens":12,
"audio_tokens":5,
"image_tokens":0,
"cached_tokens":4
},
"output_token_details":{
"text_tokens":7,
"audio_tokens":2
}
}
}
}`)

	resp := buildRealtimeTurnPostResponse(&openai.OpenAIProvider{}, schemas.OpenAI, "gpt-4o-realtime-preview-2025-06-03", rawRequest, rawResponse, "", wantLatency)
	if resp == nil || resp.ResponsesResponse == nil {
		t.Fatal("expected realtime post response to be built")
	}
	built := resp.ResponsesResponse

	if built.ExtraFields.Latency != wantLatency {
		t.Fatalf("Latency = %d, want %d", built.ExtraFields.Latency, wantLatency)
	}

	usage := built.Usage
	if usage == nil || usage.InputTokens != 17 || usage.OutputTokens != 9 || usage.TotalTokens != 26 {
		t.Fatalf("Usage = %+v, want input=17 output=9 total=26", usage)
	}

	if len(built.Output) != 1 {
		t.Fatalf("len(Output) = %d, want 1", len(built.Output))
	}
	// The only text in the payload is the audio transcript.
	content := built.Output[0].Content
	if content == nil || content.ContentStr == nil || *content.ContentStr != "assistant turn text" {
		t.Fatalf("Output[0].Content = %+v, want assistant turn text", content)
	}

	if got, isString := built.ExtraFields.RawRequest.(string); !isString || got != rawRequest {
		t.Fatalf("RawRequest = %#v, want %q", built.ExtraFields.RawRequest, rawRequest)
	}
	if got, isString := built.ExtraFields.RawResponse.(string); !isString || got == "" {
		t.Fatalf("RawResponse = %#v, want raw response string", built.ExtraFields.RawResponse)
	}
}
// TestFinalizeRealtimeTurnHooksWithErrorCompletesActiveHooks checks that
// finalizing a turn with an error hands the error (and no response) to the
// active post-hook, stamps realtime metadata onto that error, drains the
// buffered turn state, and runs the hook cleanup.
func TestFinalizeRealtimeTurnHooksWithErrorCompletesActiveHooks(t *testing.T) {
	t.Parallel()

	const model = "gpt-realtime"

	// Session with one buffered input, some partial output, and active hooks.
	sess := bfws.NewSession(nil)
	sess.SetProviderSessionID("sess_provider_123")
	sess.AddRealtimeInput("hello from user", `{"type":"conversation.item.added"}`)
	sess.AppendRealtimeOutputText("partial assistant output")

	var hookResp *schemas.BifrostResponse
	var hookErr *schemas.BifrostError
	cleanupRan := false
	sess.SetRealtimeTurnHooks(&bfws.RealtimeTurnPluginState{
		RequestID: "req_realtime_123",
		StartedAt: time.Now().Add(-time.Second),
		PreHookValues: map[any]any{
			schemas.BifrostContextKeyGovernanceVirtualKeyID: "vk_123",
		},
		PostHookRunner: func(_ *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
			hookResp, hookErr = result, err
			return result, nil
		},
		Cleanup: func() { cleanupRan = true },
	})

	wirePayload := []byte(`{"type":"error","error":{"type":"server_error","message":"Virtual key is required."}}`)
	wireErr := newRealtimeWireBifrostError(401, "server_error", "Virtual key is required.")
	result := finalizeRealtimeTurnHooksWithError(nil, nil, sess, schemas.OpenAI, model, nil, schemas.RTEventError, wirePayload, wireErr)

	// What the post-hook observed. Cases are evaluated in order, so the nil
	// check on hookErr guards the field accesses below it.
	switch {
	case result != nil:
		t.Fatalf("finalizeRealtimeTurnHooksWithError() post error = %v, want nil", result)
	case hookResp != nil:
		t.Fatalf("captured response = %#v, want nil", hookResp)
	case hookErr == nil:
		t.Fatal("expected captured error")
	case hookErr.ExtraFields.RequestType != schemas.RealtimeRequest:
		t.Fatalf("request type = %q, want %q", hookErr.ExtraFields.RequestType, schemas.RealtimeRequest)
	case hookErr.ExtraFields.Provider != schemas.OpenAI:
		t.Fatalf("provider = %q, want %q", hookErr.ExtraFields.Provider, schemas.OpenAI)
	case hookErr.ExtraFields.OriginalModelRequested != model:
		t.Fatalf("model requested = %q, want %q", hookErr.ExtraFields.OriginalModelRequested, model)
	}

	if reqRaw, isString := hookErr.ExtraFields.RawRequest.(string); !isString || reqRaw == "" {
		t.Fatalf("raw request = %#v, want non-empty string", hookErr.ExtraFields.RawRequest)
	}
	if respRaw, isRaw := hookErr.ExtraFields.RawResponse.(json.RawMessage); !isRaw || string(respRaw) != string(wirePayload) {
		t.Fatalf("raw response = %#v, want %s", hookErr.ExtraFields.RawResponse, string(wirePayload))
	}

	// The session must be left with no hooks and no buffered turn data.
	if sess.PeekRealtimeTurnHooks() != nil {
		t.Fatal("expected active hooks to be cleared")
	}
	if leftover := sess.ConsumeRealtimeTurnInputs(); len(leftover) != 0 {
		t.Fatalf("remaining turn inputs = %d, want 0", len(leftover))
	}
	if leftover := sess.ConsumeRealtimeOutputText(); leftover != "" {
		t.Fatalf("remaining output text = %q, want empty", leftover)
	}
	if !cleanupRan {
		t.Fatal("expected realtime hook cleanup to run")
	}
}
func TestNewBifrostErrorFromRealtimeErrorCarriesRealtimeMetadata(t *testing.T) {
t.Parallel()
rawResponse := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request_error","message":"bad request","param":"session.type"}}`)
bifrostErr := newBifrostErrorFromRealtimeError(
schemas.OpenAI,
"gpt-realtime",
rawResponse,
&schemas.RealtimeError{
Type: "invalid_request_error",
Code: "invalid_request_error",
Message: "bad request",
Param: "session.type",
},
)
if bifrostErr == nil {
t.Fatal("expected bifrost error")
}
if bifrostErr.StatusCode == nil || *bifrostErr.StatusCode != 400 {
t.Fatalf("status code = %#v, want 400", bifrostErr.StatusCode)
}
if bifrostErr.ExtraFields.RequestType != schemas.RealtimeRequest {
t.Fatalf("request type = %q, want %q", bifrostErr.ExtraFields.RequestType, schemas.RealtimeRequest)
}
if bifrostErr.ExtraFields.Provider != schemas.OpenAI {
t.Fatalf("provider = %q, want %q", bifrostErr.ExtraFields.Provider, schemas.OpenAI)
}
if bifrostErr.ExtraFields.OriginalModelRequested != "gpt-realtime" {
t.Fatalf("model requested = %q, want %q", bifrostErr.ExtraFields.OriginalModelRequested, "gpt-realtime")
}
rawResp, ok := bifrostErr.ExtraFields.RawResponse.(json.RawMessage)
if !ok || string(rawResp) != string(rawResponse) {
t.Fatalf("raw response = %#v, want %s", bifrostErr.ExtraFields.RawResponse, string(rawResponse))
}
if bifrostErr.Error == nil || bifrostErr.Error.Param != "session.type" {
t.Fatalf("error param = %#v, want session.type", bifrostErr.Error)
}
}

View File

@@ -0,0 +1,798 @@
package handlers
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/google/uuid"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
)
// newRealtimeTurnContext creates the BifrostContext used for a single realtime
// turn hook invocation.
//
// The context starts from context.Background() with no deadline, inherits every
// non-nil user value from baseCtx, and is then stamped with the realtime request
// type, a request ID (generated when requestID is empty), session identifiers,
// the turn source and event type, and the selected key's ID and name when set.
//
// The parent request ID is resolved in this order: a non-blank parent request ID
// already present on baseCtx, then providerSessionID, then sessionID.
func newRealtimeTurnContext(
	baseCtx *schemas.BifrostContext,
	requestID string,
	sessionID string,
	providerSessionID string,
	source realtimeTurnSource,
	eventType schemas.RealtimeEventType,
	key *schemas.Key,
) *schemas.BifrostContext {
	turnCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

	// Realtime post-hook contexts must preserve plugin-private values written in
	// pre-hooks (for example telemetry start timestamps), not just public keys.
	if baseCtx != nil {
		for k, v := range baseCtx.GetUserValues() {
			if v == nil {
				continue
			}
			turnCtx.SetValue(k, v)
		}
	}

	turnCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest)

	if requestID == "" {
		requestID = uuid.NewString()
	}
	turnCtx.SetValue(schemas.BifrostContextKeyRequestID, requestID)

	// Resolve the parent request ID: provider session, then local session, with
	// an externally supplied parent ID on baseCtx taking precedence over both.
	providerID := strings.TrimSpace(providerSessionID)
	parentID := providerID
	if parentID == "" {
		parentID = strings.TrimSpace(sessionID)
	}
	if baseCtx != nil {
		external, _ := baseCtx.Value(schemas.BifrostContextKeyParentRequestID).(string)
		if trimmed := strings.TrimSpace(external); trimmed != "" {
			parentID = trimmed
		}
	}
	if parentID != "" {
		turnCtx.SetValue(schemas.BifrostContextKeyParentRequestID, parentID)
	}

	// Both realtime session keys receive the provider session ID as passed in
	// (not the trimmed copy).
	if providerID != "" {
		turnCtx.SetValue(schemas.BifrostContextKeyRealtimeSessionID, providerSessionID)
		turnCtx.SetValue(schemas.BifrostContextKeyRealtimeProviderSessionID, providerSessionID)
	}

	if source != "" {
		turnCtx.SetValue(schemas.BifrostContextKeyRealtimeSource, string(source))
	}
	if eventType != "" {
		turnCtx.SetValue(schemas.BifrostContextKeyRealtimeEventType, string(eventType))
	}

	if key == nil {
		return turnCtx
	}
	if strings.TrimSpace(key.ID) != "" {
		turnCtx.SetValue(schemas.BifrostContextKeySelectedKeyID, key.ID)
	}
	if strings.TrimSpace(key.Name) != "" {
		turnCtx.SetValue(schemas.BifrostContextKeySelectedKeyName, key.Name)
	}
	return turnCtx
}
// applyRealtimeTurnContextValues copies non-nil entries from values onto ctx,
// skipping the keys that the turn context sets itself (request/parent IDs,
// realtime session/source/event-type keys, and the stream start/end markers).
// It does nothing when ctx is nil or values is empty.
func applyRealtimeTurnContextValues(ctx *schemas.BifrostContext, values map[any]any) {
	if ctx == nil || len(values) == 0 {
		return
	}

	// turnOwned reports whether k is one of the keys that must not be overwritten.
	turnOwned := func(k any) bool {
		switch k {
		case schemas.BifrostContextKeyRequestID,
			schemas.BifrostContextKeyParentRequestID,
			schemas.BifrostContextKeyRealtimeSessionID,
			schemas.BifrostContextKeyRealtimeProviderSessionID,
			schemas.BifrostContextKeyRealtimeSource,
			schemas.BifrostContextKeyRealtimeEventType,
			schemas.BifrostContextKeyStreamStartTime,
			schemas.BifrostContextKeyStreamEndIndicator:
			return true
		}
		return false
	}

	for k, v := range values {
		if v == nil || turnOwned(k) {
			continue
		}
		ctx.SetValue(k, v)
	}
}
// setRealtimeTurnStreamContext records the stream start time on ctx and, when
// isFinal is true, also sets the stream-end indicator. A zero startedAt is
// replaced with the current time. A nil ctx is ignored.
func setRealtimeTurnStreamContext(ctx *schemas.BifrostContext, startedAt time.Time, isFinal bool) {
	if ctx == nil {
		return
	}

	streamStart := startedAt
	if streamStart.IsZero() {
		streamStart = time.Now()
	}
	ctx.SetValue(schemas.BifrostContextKeyStreamStartTime, streamStart)

	if !isFinal {
		return
	}
	ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
}
// buildRealtimeTurnPreRequest converts the buffered turn inputs into the
// BifrostRequest passed to the realtime pre-hooks.
//
// Inputs with a blank summary are dropped. Tool-role inputs become
// function-call-output items, user-role inputs become user messages, and inputs
// with any other role are ignored.
func buildRealtimeTurnPreRequest(provider schemas.ModelProvider, model string, turnInputs []bfws.RealtimeTurnInput) *schemas.BifrostRequest {
	messages := make([]schemas.ResponsesMessage, 0, len(turnInputs))

	for i := range turnInputs {
		text := strings.TrimSpace(turnInputs[i].Summary)
		if text == "" {
			continue
		}

		if turnInputs[i].Role == string(schemas.ChatMessageRoleTool) {
			msgType := schemas.ResponsesMessageTypeFunctionCallOutput
			toolMessage := &schemas.ResponsesToolMessage{
				Output: &schemas.ResponsesToolMessageOutputStruct{
					ResponsesToolCallOutputStr: schemas.Ptr(text),
				},
			}
			messages = append(messages, schemas.ResponsesMessage{
				Type:                 &msgType,
				ResponsesToolMessage: toolMessage,
			})
		} else if turnInputs[i].Role == string(schemas.ChatMessageRoleUser) {
			msgType := schemas.ResponsesMessageTypeMessage
			userRole := schemas.ResponsesInputMessageRoleUser
			messages = append(messages, schemas.ResponsesMessage{
				Type:    &msgType,
				Role:    &userRole,
				Content: &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(text)},
			})
		}
	}

	responsesReq := &schemas.BifrostResponsesRequest{
		Provider: provider,
		Model:    model,
		Input:    messages,
	}
	return &schemas.BifrostRequest{
		RequestType:      schemas.RealtimeRequest,
		ResponsesRequest: responsesReq,
	}
}
// buildRealtimeTurnPostResponse assembles the BifrostResponse handed to the
// realtime post-hooks for a completed turn.
//
// Output messages are built from rawResponse (with contentOverride taking
// precedence for text), usage is attached when it can be extracted from
// rawResponse, and the raw request/response are stored as strings in the extra
// fields when non-empty. latency is stored as given.
func buildRealtimeTurnPostResponse(
	rtProvider schemas.RealtimeProvider,
	provider schemas.ModelProvider,
	model string,
	rawRequest string,
	rawResponse []byte,
	contentOverride string,
	latency int64,
) *schemas.BifrostResponse {
	messages := buildRealtimeTurnOutputMessages(rtProvider, rawResponse, contentOverride)

	extra := schemas.BifrostResponseExtraFields{
		RequestType:            schemas.RealtimeRequest,
		Provider:               provider,
		OriginalModelRequested: model,
		Latency:                latency,
	}
	if strings.TrimSpace(rawRequest) != "" {
		extra.RawRequest = rawRequest
	}
	if len(rawResponse) > 0 {
		extra.RawResponse = string(rawResponse)
	}

	responses := &schemas.BifrostResponsesResponse{
		Object:      "response",
		Model:       model,
		Output:      messages,
		ExtraFields: extra,
	}
	if turnUsage := extractRealtimeTurnUsage(rtProvider, rawResponse); turnUsage != nil {
		responses.Usage = buildRealtimeResponsesUsage(turnUsage)
	}

	return &schemas.BifrostResponse{ResponsesResponse: responses}
}
// buildRealtimeTurnOutputMessages derives the Responses-style output items for
// a finished realtime turn. It tries three sources in order:
//
//  1. the provider-extracted chat message, converted via
//     buildRealtimeResponsesMessagesFromChat;
//  2. the "message" and "function_call" items of rawResponse parsed as a
//     realtimeResponseDoneEnvelope;
//  3. a single assistant message holding contentOverride.
//
// A non-blank contentOverride replaces message text in every path. The result
// is never nil; it is an empty slice when nothing could be built.
func buildRealtimeTurnOutputMessages(rtProvider schemas.RealtimeProvider, rawResponse []byte, contentOverride string) []schemas.ResponsesMessage {
	// Source 1: provider-specific extraction.
	if chatMessage := extractRealtimeTurnOutputMessage(rtProvider, rawResponse, contentOverride); chatMessage != nil {
		if converted := buildRealtimeResponsesMessagesFromChat(chatMessage, contentOverride); len(converted) > 0 {
			return converted
		}
	}

	override := strings.TrimSpace(contentOverride)
	messages := make([]schemas.ResponsesMessage, 0)

	// Source 2: generic parse of the response-done envelope.
	var envelope realtimeResponseDoneEnvelope
	if len(rawResponse) > 0 && schemas.Unmarshal(rawResponse, &envelope) == nil {
		for _, item := range envelope.Response.Output {
			itemID := strings.TrimSpace(item.ID)

			if item.Type == "message" {
				text := override
				if text == "" {
					text = extractRealtimeResponseDoneContentText(item.Content)
				}
				msgType := schemas.ResponsesMessageTypeMessage
				assistant := schemas.ResponsesInputMessageRoleAssistant
				entry := schemas.ResponsesMessage{
					Type:   &msgType,
					Role:   &assistant,
					Status: schemas.Ptr("completed"),
				}
				if itemID != "" {
					entry.ID = schemas.Ptr(itemID)
				}
				if text != "" {
					entry.Content = &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(text)}
				}
				messages = append(messages, entry)
				continue
			}

			if item.Type == "function_call" {
				msgType := schemas.ResponsesMessageTypeFunctionCall
				entry := schemas.ResponsesMessage{
					Type:   &msgType,
					Status: schemas.Ptr("completed"),
					ResponsesToolMessage: &schemas.ResponsesToolMessage{
						Name:      schemas.Ptr(strings.TrimSpace(item.Name)),
						Arguments: schemas.Ptr(item.Arguments),
					},
				}
				if itemID != "" {
					entry.ID = schemas.Ptr(itemID)
				}
				if callID := strings.TrimSpace(item.CallID); callID != "" {
					entry.CallID = schemas.Ptr(callID)
				}
				messages = append(messages, entry)
			}
		}
	}

	if len(messages) > 0 || override == "" {
		return messages
	}

	// Source 3: nothing parsed, but the caller supplied text.
	msgType := schemas.ResponsesMessageTypeMessage
	assistant := schemas.ResponsesInputMessageRoleAssistant
	return append(messages, schemas.ResponsesMessage{
		Type:    &msgType,
		Role:    &assistant,
		Status:  schemas.Ptr("completed"),
		Content: &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(override)},
	})
}
// buildRealtimeResponsesMessagesFromChat converts a chat-style assistant
// message into Responses-style output items: at most one completed assistant
// text message followed by one completed function-call item per tool call.
//
// A non-blank contentOverride replaces the message's own text. A nil message
// yields nil; otherwise the result is non-nil and may be empty.
func buildRealtimeResponsesMessagesFromChat(message *schemas.ChatMessage, contentOverride string) []schemas.ResponsesMessage {
	if message == nil {
		return nil
	}

	items := make([]schemas.ResponsesMessage, 0, 1)

	text := strings.TrimSpace(contentOverride)
	if text == "" {
		if body := message.Content; body != nil && body.ContentStr != nil {
			text = strings.TrimSpace(*body.ContentStr)
		}
	}
	if text != "" {
		msgType := schemas.ResponsesMessageTypeMessage
		assistant := schemas.ResponsesInputMessageRoleAssistant
		items = append(items, schemas.ResponsesMessage{
			Type:    &msgType,
			Role:    &assistant,
			Status:  schemas.Ptr("completed"),
			Content: &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(text)},
		})
	}

	if assistantPart := message.ChatAssistantMessage; assistantPart != nil {
		for _, call := range assistantPart.ToolCalls {
			msgType := schemas.ResponsesMessageTypeFunctionCall
			toolMessage := &schemas.ResponsesToolMessage{
				Arguments: schemas.Ptr(call.Function.Arguments),
			}
			if call.Function.Name != nil {
				toolMessage.Name = schemas.Ptr(strings.TrimSpace(*call.Function.Name))
			}

			entry := schemas.ResponsesMessage{
				Type:                 &msgType,
				Status:               schemas.Ptr("completed"),
				ResponsesToolMessage: toolMessage,
			}
			// The tool call ID doubles as both the call ID and the item ID.
			if call.ID != nil {
				entry.CallID = schemas.Ptr(strings.TrimSpace(*call.ID))
				entry.ID = schemas.Ptr(strings.TrimSpace(*call.ID))
			}
			items = append(items, entry)
		}
	}

	return items
}
// extractRealtimeResponseDoneContentText returns the first non-blank text found
// in content, trimmed. Within each block the fields are tried in the order
// Text, Transcript, Refusal. It returns "" when no block has any text.
func extractRealtimeResponseDoneContentText(content []realtimeResponseDoneContent) string {
	for _, block := range content {
		for _, candidate := range [...]string{block.Text, block.Transcript, block.Refusal} {
			if trimmed := strings.TrimSpace(candidate); trimmed != "" {
				return trimmed
			}
		}
	}
	return ""
}
// buildRealtimeResponsesUsage maps chat-style token usage onto the
// Responses-style usage struct: prompt tokens become input tokens, completion
// tokens become output tokens, and the detail breakdowns are copied field by
// field when present. A nil usage yields nil.
func buildRealtimeResponsesUsage(usage *schemas.BifrostLLMUsage) *schemas.ResponsesResponseUsage {
	if usage == nil {
		return nil
	}

	converted := new(schemas.ResponsesResponseUsage)
	converted.InputTokens = usage.PromptTokens
	converted.OutputTokens = usage.CompletionTokens
	converted.TotalTokens = usage.TotalTokens

	if in := usage.PromptTokensDetails; in != nil {
		converted.InputTokensDetails = &schemas.ResponsesResponseInputTokens{
			TextTokens:        in.TextTokens,
			AudioTokens:       in.AudioTokens,
			ImageTokens:       in.ImageTokens,
			CachedReadTokens:  in.CachedReadTokens,
			CachedWriteTokens: in.CachedWriteTokens,
		}
	}

	if out := usage.CompletionTokensDetails; out != nil {
		converted.OutputTokensDetails = &schemas.ResponsesResponseOutputTokens{
			TextTokens:               out.TextTokens,
			AcceptedPredictionTokens: out.AcceptedPredictionTokens,
			AudioTokens:              out.AudioTokens,
			ImageTokens:              out.ImageTokens,
			ReasoningTokens:          out.ReasoningTokens,
			RejectedPredictionTokens: out.RejectedPredictionTokens,
			CitationTokens:           out.CitationTokens,
			NumSearchQueries:         out.NumSearchQueries,
		}
	}

	return converted
}
// newRealtimeTurnErrorEventPayload serializes bifrostErr as a realtime "error"
// event. A generic server_error payload is returned when bifrostErr is nil or
// when marshalling fails.
func newRealtimeTurnErrorEventPayload(bifrostErr *schemas.BifrostError) []byte {
	const fallback = `{"type":"error","error":{"type":"server_error","message":"internal server error"}}`

	if bifrostErr == nil {
		return []byte(fallback)
	}

	wireType, wireCode, wireMessage, wireParam := mapRealtimeWireErrorFields(bifrostErr)
	encoded, err := schemas.Marshal(schemas.BifrostRealtimeEvent{
		Type: schemas.RTEventError,
		Error: &schemas.RealtimeError{
			Type:    wireType,
			Code:    wireCode,
			Message: wireMessage,
			Param:   wireParam,
		},
	})
	if err != nil {
		return []byte(fallback)
	}
	return encoded
}
// isBudgetOrBillingError reports whether lower contains one of the known
// budget or billing exhaustion markers. The caller is expected to pass an
// already-lowercased string; matching here is case-sensitive.
// Quota/rate-limit patterns (quota_exceeded, quota exceeded, etc.) are already
// covered by bifrost.IsRateLimitErrorMessage.
func isBudgetOrBillingError(lower string) bool {
	for _, marker := range [...]string{
		"budget_exceeded",
		"budget exceeded",
		"insufficient_quota",
		"hard limit reached",
		"billing hard limit",
	} {
		if strings.Contains(lower, marker) {
			return true
		}
	}
	return false
}
// mapRealtimeWireErrorFields derives the (type, code, message, param) values
// placed on a realtime wire error event.
//
// The message is the trimmed error message, or "internal server error" when
// none is set; param is the trimmed string form of the error param. Type and
// code are chosen by scanning, in order, the top-level type, the error type,
// the error code, and the message; the first value that matches a known
// category wins:
//   - contains "invalid_request_error" -> invalid_request_error
//   - budget/billing marker            -> insufficient_quota
//   - rate-limit message               -> rate_limit_exceeded
//
// With no match (or a nil bifrostErr) both type and code are "server_error".
func mapRealtimeWireErrorFields(bifrostErr *schemas.BifrostError) (string, string, string, string) {
	const fallbackKind = "server_error"
	message, param := "internal server error", ""

	if bifrostErr == nil {
		return fallbackKind, fallbackKind, message, param
	}

	candidates := make([]string, 0, 4)
	if bifrostErr.Type != nil {
		candidates = append(candidates, strings.TrimSpace(*bifrostErr.Type))
	}
	if detail := bifrostErr.Error; detail != nil {
		if detail.Type != nil {
			candidates = append(candidates, strings.TrimSpace(*detail.Type))
		}
		if detail.Code != nil {
			candidates = append(candidates, strings.TrimSpace(*detail.Code))
		}
		if trimmed := strings.TrimSpace(detail.Message); trimmed != "" {
			message = trimmed
			candidates = append(candidates, trimmed)
		}
		if detail.Param != nil {
			param = strings.TrimSpace(fmt.Sprint(detail.Param))
		}
	}

	for _, candidate := range candidates {
		lowered := strings.ToLower(candidate)
		if lowered == "" {
			continue
		}
		if strings.Contains(lowered, "invalid_request_error") {
			return "invalid_request_error", "invalid_request_error", message, param
		}
		if isBudgetOrBillingError(lowered) {
			return "insufficient_quota", "insufficient_quota", message, param
		}
		if bifrost.IsRateLimitErrorMessage(lowered) {
			return "rate_limit_exceeded", "rate_limit_exceeded", message, param
		}
	}

	return fallbackKind, fallbackKind, message, param
}
// shouldGracefullyDisconnectRealtime reports whether bifrostErr describes a
// budget/billing exhaustion or a rate limit. The top-level type, the error
// type, the error code, and the error message are each checked
// case-insensitively. A nil error yields false.
func shouldGracefullyDisconnectRealtime(bifrostErr *schemas.BifrostError) bool {
	if bifrostErr == nil {
		return false
	}

	// triggers reports whether a single field value matches either category.
	triggers := func(value string) bool {
		lowered := strings.ToLower(strings.TrimSpace(value))
		if lowered == "" {
			return false
		}
		return isBudgetOrBillingError(lowered) || bifrost.IsRateLimitErrorMessage(lowered)
	}

	if bifrostErr.Type != nil && triggers(*bifrostErr.Type) {
		return true
	}

	detail := bifrostErr.Error
	if detail == nil {
		return false
	}
	if detail.Type != nil && triggers(*detail.Type) {
		return true
	}
	if detail.Code != nil && triggers(*detail.Code) {
		return true
	}
	return triggers(detail.Message)
}
// startRealtimeTurnHooks runs the realtime pre-hooks for a new turn and, on
// success, stores the resulting post-hook runner, cleanup function, request ID,
// start time, and context values on the session so the turn can be finalized
// later.
//
// It returns:
//   - a 500 server_error when client or session is nil;
//   - a 400 invalid_request_error when the session refuses to begin a new turn
//     (TryBeginRealtimeTurnHooks returned false);
//   - the pre-hook error, after discarding the buffered turn inputs and output
//     text, when the pre-hooks fail;
//   - nil once the hook state has been stored on the session.
//
// rtProvider is accepted but not referenced in this function's body.
func startRealtimeTurnHooks(
client *bifrost.Bifrost,
baseCtx *schemas.BifrostContext,
session *bfws.Session,
rtProvider schemas.RealtimeProvider,
provider schemas.ModelProvider,
model string,
key *schemas.Key,
startEventType schemas.RealtimeEventType,
) *schemas.BifrostError {
if client == nil || session == nil {
return &schemas.BifrostError{
Type: schemas.Ptr("server_error"),
StatusCode: schemas.Ptr(500),
Error: &schemas.ErrorField{
Type: schemas.Ptr("server_error"),
Message: "realtime turn pipeline is unavailable",
},
}
}
// Claim the turn slot on the session; a false result is reported to the
// caller as "already has an active response".
if !session.TryBeginRealtimeTurnHooks() {
return &schemas.BifrostError{
Type: schemas.Ptr("invalid_request_error"),
StatusCode: schemas.Ptr(400),
Error: &schemas.ErrorField{
Type: schemas.Ptr("invalid_request_error"),
Message: "Conversation already has an active response in progress.",
},
}
}
// Any return before committed is set aborts the turn that was just begun.
committed := false
defer func() {
if !committed {
session.AbortRealtimeTurnHooks()
}
}()
startedAt := time.Now()
// An empty request ID makes newRealtimeTurnContext generate a fresh one.
turnCtx := newRealtimeTurnContext(baseCtx, "", session.ID(), session.ProviderSessionID(), realtimeTurnSourceEI, startEventType, key)
setRealtimeTurnStreamContext(turnCtx, startedAt, false)
// Inputs are peeked, not consumed, so they stay buffered on the session.
req := buildRealtimeTurnPreRequest(provider, model, session.PeekRealtimeTurnInputs())
hooks, bifrostErr := client.RunRealtimeTurnPreHooks(turnCtx, req)
if bifrostErr != nil {
// RunRealtimeTurnPreHooks already executed post-hooks and flushed the trace
// for this turn-start failure. Clear buffered turn state so transport-close
// fallback finalization does not emit the same error a second time.
session.ConsumeRealtimeTurnInputs()
session.ConsumeRealtimeOutputText()
return bifrostErr
}
requestID, _ := turnCtx.Value(schemas.BifrostContextKeyRequestID).(string)
// Persist everything the finalizer needs to run the post-hooks with the same
// request ID, start time, and context values.
session.SetRealtimeTurnHooks(&bfws.RealtimeTurnPluginState{
PostHookRunner: hooks.PostHookRunner,
Cleanup: hooks.Cleanup,
RequestID: requestID,
StartedAt: startedAt,
PreHookValues: turnCtx.GetUserValues(),
})
committed = true
return nil
}
// finalizeRealtimeTurnHooks runs the realtime post-hooks for a successfully
// completed turn and returns whatever error the post-hook runner reports.
//
// The buffered turn inputs are always consumed. If the session holds active
// hooks (stored by an earlier turn start), those are consumed and used: the
// post-hook receives a response built from rawResponse/contentOverride with
// latency measured from the stored start time. Otherwise the pre-hooks are run
// here first and their post-hook runner is invoked immediately afterwards; a
// pre-hook failure is returned as-is.
//
// In both paths the hook cleanup (when non-nil) is deferred and
// completeRealtimeTurnTrace is called on the post-hook context. It returns nil
// without doing anything when client or session is nil.
func finalizeRealtimeTurnHooks(
client *bifrost.Bifrost,
baseCtx *schemas.BifrostContext,
session *bfws.Session,
rtProvider schemas.RealtimeProvider,
provider schemas.ModelProvider,
model string,
key *schemas.Key,
rawResponse []byte,
contentOverride string,
) *schemas.BifrostError {
if client == nil || session == nil {
return nil
}
turnInputs := session.ConsumeRealtimeTurnInputs()
rawRequest := combineRealtimeInputRaw(turnInputs)
// Path 1: hooks were started earlier for this turn; finish them.
if activeHooks := session.ConsumeRealtimeTurnHooks(); activeHooks != nil {
defer func() {
if activeHooks.Cleanup != nil {
activeHooks.Cleanup()
}
}()
postResponse := buildRealtimeTurnPostResponse(
rtProvider,
provider,
model,
rawRequest,
rawResponse,
contentOverride,
time.Since(activeHooks.StartedAt).Milliseconds(),
)
// Reuse the stored request ID, then layer the stored pre-hook values on top
// of the fresh context and mark it as the final stream chunk.
postCtx := newRealtimeTurnContext(baseCtx, activeHooks.RequestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, rtProvider.RealtimeTurnFinalEvent(), key)
applyRealtimeTurnContextValues(postCtx, activeHooks.PreHookValues)
setRealtimeTurnStreamContext(postCtx, activeHooks.StartedAt, true)
_, bifrostErr := activeHooks.PostHookRunner(postCtx, postResponse, nil)
completeRealtimeTurnTrace(postCtx)
return bifrostErr
}
// Path 2: no hooks were started for this turn; run pre- and post-hooks
// back to back, measuring latency from now.
startedAt := time.Now()
preCtx := newRealtimeTurnContext(baseCtx, "", session.ID(), session.ProviderSessionID(), realtimeTurnSourceEI, "", key)
setRealtimeTurnStreamContext(preCtx, startedAt, false)
preReq := buildRealtimeTurnPreRequest(provider, model, turnInputs)
hooks, bifrostErr := client.RunRealtimeTurnPreHooks(preCtx, preReq)
if bifrostErr != nil {
return bifrostErr
}
if hooks.Cleanup != nil {
defer hooks.Cleanup()
}
// Carry the request ID chosen for the pre-hook context into the post-hook context.
requestID, _ := preCtx.Value(schemas.BifrostContextKeyRequestID).(string)
postResponse := buildRealtimeTurnPostResponse(
rtProvider,
provider,
model,
rawRequest,
rawResponse,
contentOverride,
time.Since(startedAt).Milliseconds(),
)
postCtx := newRealtimeTurnContext(baseCtx, requestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, rtProvider.RealtimeTurnFinalEvent(), key)
applyRealtimeTurnContextValues(postCtx, preCtx.GetUserValues())
setRealtimeTurnStreamContext(postCtx, startedAt, true)
_, bifrostErr = hooks.PostHookRunner(postCtx, postResponse, nil)
completeRealtimeTurnTrace(postCtx)
return bifrostErr
}
// finalizeRealtimeTurnHooksWithError runs the realtime post-hooks for a turn
// that ended in bifrostErr and returns whatever error the post-hook runner
// reports.
//
// The buffered turn inputs and output text are always consumed. If the session
// holds active hooks, they are consumed and their post-hook runner receives a
// nil response plus a copy of bifrostErr enriched by
// buildRealtimeTurnPostError; this path does not need client. Otherwise, when
// there were buffered inputs and client is non-nil, the pre-hooks are run here
// and their post-hook runner is invoked with the same enriched error; a
// pre-hook failure is returned as-is.
//
// It returns nil without running any hook when session or bifrostErr is nil,
// or when there are neither active hooks nor buffered inputs, or when hooks
// would have to be started but client is nil.
func finalizeRealtimeTurnHooksWithError(
client *bifrost.Bifrost,
baseCtx *schemas.BifrostContext,
session *bfws.Session,
provider schemas.ModelProvider,
model string,
key *schemas.Key,
eventType schemas.RealtimeEventType,
rawResponse []byte,
bifrostErr *schemas.BifrostError,
) *schemas.BifrostError {
if session == nil || bifrostErr == nil {
return nil
}
turnInputs := session.ConsumeRealtimeTurnInputs()
rawRequest := combineRealtimeInputRaw(turnInputs)
// Discard any partial output text buffered for the failed turn.
session.ConsumeRealtimeOutputText()
// Path 1: hooks were started earlier for this turn; finish them with the error.
if activeHooks := session.ConsumeRealtimeTurnHooks(); activeHooks != nil {
defer func() {
if activeHooks.Cleanup != nil {
activeHooks.Cleanup()
}
}()
postErr := buildRealtimeTurnPostError(
provider,
model,
rawRequest,
rawResponse,
bifrostErr,
)
postCtx := newRealtimeTurnContext(baseCtx, activeHooks.RequestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, eventType, key)
applyRealtimeTurnContextValues(postCtx, activeHooks.PreHookValues)
setRealtimeTurnStreamContext(postCtx, activeHooks.StartedAt, true)
_, hookErr := activeHooks.PostHookRunner(postCtx, nil, postErr)
completeRealtimeTurnTrace(postCtx)
return hookErr
}
// Without active hooks there is nothing to report unless inputs were buffered.
if len(turnInputs) == 0 {
return nil
}
if client == nil {
return nil
}
// Path 2: run pre- and post-hooks back to back for the buffered inputs.
startedAt := time.Now()
preCtx := newRealtimeTurnContext(baseCtx, "", session.ID(), session.ProviderSessionID(), realtimeTurnSourceEI, "", key)
setRealtimeTurnStreamContext(preCtx, startedAt, false)
preReq := buildRealtimeTurnPreRequest(provider, model, turnInputs)
hooks, hookPreErr := client.RunRealtimeTurnPreHooks(preCtx, preReq)
if hookPreErr != nil {
return hookPreErr
}
if hooks.Cleanup != nil {
defer hooks.Cleanup()
}
// Carry the request ID chosen for the pre-hook context into the post-hook context.
requestID, _ := preCtx.Value(schemas.BifrostContextKeyRequestID).(string)
postErr := buildRealtimeTurnPostError(
provider,
model,
rawRequest,
rawResponse,
bifrostErr,
)
postCtx := newRealtimeTurnContext(baseCtx, requestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, eventType, key)
applyRealtimeTurnContextValues(postCtx, preCtx.GetUserValues())
setRealtimeTurnStreamContext(postCtx, startedAt, true)
_, hookErr := hooks.PostHookRunner(postCtx, nil, postErr)
completeRealtimeTurnTrace(postCtx)
return hookErr
}
// buildRealtimeTurnPostError returns a copy of bifrostErr prepared for the
// realtime post-hooks. The copy (including its own copy of the nested Error
// field) is tagged as a realtime request, and provider, model, raw request,
// and raw response are filled in only where the original left them unset.
// The raw response is stored as a json.RawMessage backed by a fresh byte
// slice. A nil bifrostErr yields nil; the input is never modified.
func buildRealtimeTurnPostError(
	provider schemas.ModelProvider,
	model string,
	rawRequest string,
	rawResponse []byte,
	bifrostErr *schemas.BifrostError,
) *schemas.BifrostError {
	if bifrostErr == nil {
		return nil
	}

	clone := new(schemas.BifrostError)
	*clone = *bifrostErr
	if src := bifrostErr.Error; src != nil {
		detail := *src
		clone.Error = &detail
	}

	extra := &clone.ExtraFields
	extra.RequestType = schemas.RealtimeRequest
	if extra.Provider == "" {
		extra.Provider = provider
	}
	if strings.TrimSpace(extra.OriginalModelRequested) == "" {
		extra.OriginalModelRequested = model
	}
	if extra.RawRequest == nil && strings.TrimSpace(rawRequest) != "" {
		extra.RawRequest = rawRequest
	}
	if extra.RawResponse == nil && len(rawResponse) > 0 {
		extra.RawResponse = json.RawMessage(append([]byte(nil), rawResponse...))
	}

	return clone
}
// newBifrostErrorFromRealtimeError converts a provider realtime error event
// into a BifrostError tagged as a realtime request.
//
// The HTTP status is inferred by scanning the error's type, code, and message
// in that order; the first field that matches a known category decides:
//   - contains "invalid_request_error"    -> 400
//   - budget/billing or rate-limit marker -> 429
//   - no match                            -> 500
//
// First-match-wins keeps the structured type/code authoritative over free-text
// in the message, and matches the precedence mapRealtimeWireErrorFields uses
// when choosing the wire error type, so status code and wire type agree.
//
// Blank fields are defaulted: type to "server_error", code to the type, and
// message to "realtime turn failed". A non-empty rawResponse is stored as a
// json.RawMessage backed by a fresh byte slice. A nil realtimeErr yields nil.
func newBifrostErrorFromRealtimeError(
	provider schemas.ModelProvider,
	model string,
	rawResponse []byte,
	realtimeErr *schemas.RealtimeError,
) *schemas.BifrostError {
	if realtimeErr == nil {
		return nil
	}

	errType := strings.TrimSpace(realtimeErr.Type)
	errCode := strings.TrimSpace(realtimeErr.Code)
	message := strings.TrimSpace(realtimeErr.Message)

	// Classify on the provider-supplied values, before any defaults are applied.
	statusCode := 500
classify:
	for _, value := range [...]string{errType, errCode, message} {
		lower := strings.ToLower(value)
		switch {
		case lower == "":
			continue
		case strings.Contains(lower, "invalid_request_error"):
			statusCode = 400
			break classify
		case isBudgetOrBillingError(lower), bifrost.IsRateLimitErrorMessage(lower):
			statusCode = 429
			break classify
		}
	}

	if errType == "" {
		errType = "server_error"
	}
	if errCode == "" {
		errCode = errType
	}
	if message == "" {
		message = "realtime turn failed"
	}

	bifrostErr := &schemas.BifrostError{
		IsBifrostError: true,
		StatusCode:     schemas.Ptr(statusCode),
		Type:           schemas.Ptr(errType),
		Error: &schemas.ErrorField{
			Type:    schemas.Ptr(errType),
			Code:    schemas.Ptr(errCode),
			Message: message,
		},
		ExtraFields: schemas.BifrostErrorExtraFields{
			Provider:               provider,
			OriginalModelRequested: model,
			RequestType:            schemas.RealtimeRequest,
		},
	}
	// Param is copied as provided; only the emptiness test ignores whitespace.
	if strings.TrimSpace(realtimeErr.Param) != "" {
		bifrostErr.Error.Param = realtimeErr.Param
	}
	if len(rawResponse) > 0 {
		bifrostErr.ExtraFields.RawResponse = json.RawMessage(append([]byte(nil), rawResponse...))
	}
	return bifrostErr
}
// completeRealtimeTurnTrace completes and flushes the trace referenced by ctx.
// It is a no-op when ctx is nil, when ctx carries no non-blank trace ID, or
// when ctx carries no tracer.
func completeRealtimeTurnTrace(ctx *schemas.BifrostContext) {
	if ctx == nil {
		return
	}

	rawID, _ := ctx.Value(schemas.BifrostContextKeyTraceID).(string)
	traceID := strings.TrimSpace(rawID)
	if traceID == "" {
		return
	}

	if tracer, ok := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer); ok && tracer != nil {
		tracer.CompleteAndFlushTrace(traceID)
	}
}
// finalizeRealtimeTurnHooksOnTransportError finalizes the current turn with an
// error built from the given status, code, and message. It delegates to
// finalizeRealtimeTurnHooksWithError using the realtime error event type and
// no raw response payload.
func finalizeRealtimeTurnHooksOnTransportError(
	client *bifrost.Bifrost,
	baseCtx *schemas.BifrostContext,
	session *bfws.Session,
	provider schemas.ModelProvider,
	model string,
	key *schemas.Key,
	status int,
	code string,
	message string,
) *schemas.BifrostError {
	transportErr := newRealtimeWireBifrostError(status, code, message)

	// This path passes no raw response payload along with the error.
	var noRawResponse []byte

	return finalizeRealtimeTurnHooksWithError(client, baseCtx, session, provider, model, key, schemas.RTEventError, noRawResponse, transportErr)
}

View File

@@ -0,0 +1,230 @@
package handlers
import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/fasthttp/router"
"github.com/google/uuid"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
"github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/framework/encrypt"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// SessionHandler manages HTTP requests for session operations
// (login, logout, auth status, and WebSocket ticket issuance).
type SessionHandler struct {
// configStore provides the auth config and session lookup/create/delete.
// Handlers treat a nil store as "authentication is not enabled".
configStore configstore.ConfigStore
// wsTicketStore backs the ws-ticket endpoint; issueWSTicket responds with
// 503 Service Unavailable when it is nil.
wsTicketStore *WSTicketStore
}
// NewSessionHandler returns a SessionHandler wired to the given config store
// and WebSocket ticket store. Both arguments are stored as given.
func NewSessionHandler(configStore configstore.ConfigStore, wsTicketStore *WSTicketStore) *SessionHandler {
	handler := new(SessionHandler)
	handler.configStore = configStore
	handler.wsTicketStore = wsTicketStore
	return handler
}
// RegisterRoutes mounts the session endpoints on r, wrapping every handler
// with the supplied middlewares:
//
//	POST /api/session/login
//	POST /api/session/logout
//	GET  /api/session/is-auth-enabled
//	POST /api/session/ws-ticket
func (h *SessionHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	// chain applies the shared middleware stack to a single handler.
	chain := func(handler fasthttp.RequestHandler) fasthttp.RequestHandler {
		return lib.ChainMiddlewares(handler, middlewares...)
	}

	r.POST("/api/session/login", chain(h.login))
	r.POST("/api/session/logout", chain(h.logout))
	r.GET("/api/session/is-auth-enabled", chain(h.isAuthEnabled))
	r.POST("/api/session/ws-ticket", chain(h.issueWSTicket))
}
// isAuthEnabled handles GET /api/session/is-auth-enabled.
//
// It reports whether authentication is enabled and whether the caller already
// presents a valid, unexpired session token. The token is read from a
// "Bearer " Authorization header first and from the "token" cookie otherwise.
// When there is no config store or no auth config, only
// {"is_auth_enabled": false} is returned. A failure to read the auth config
// yields a 500.
func (h *SessionHandler) isAuthEnabled(ctx *fasthttp.RequestCtx) {
	// replyDisabled sends the response used when auth is not configured at all.
	replyDisabled := func() {
		SendJSON(ctx, map[string]any{
			"is_auth_enabled": false,
		})
	}

	if h.configStore == nil {
		replyDisabled()
		return
	}

	authConfig, err := h.configStore.GetAuthConfig(ctx)
	switch {
	case err != nil:
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get auth config: %v", err))
		return
	case authConfig == nil:
		replyDisabled()
		return
	}

	// Prefer a Bearer token from the Authorization header, then the cookie.
	var token string
	if authHeader := string(ctx.Request.Header.Peek("Authorization")); strings.HasPrefix(authHeader, "Bearer ") {
		token = authHeader[len("Bearer "):]
	}
	if token == "" {
		token = string(ctx.Request.Header.Cookie("token"))
	}

	// A token counts as valid only if its session exists and has not expired.
	tokenValid := false
	if token != "" {
		if sess, lookupErr := h.configStore.GetSession(ctx, token); lookupErr == nil && sess != nil {
			tokenValid = sess.ExpiresAt.After(time.Now())
		}
	}

	SendJSON(ctx, map[string]any{
		"is_auth_enabled": authConfig.IsEnabled,
		"has_valid_token": tokenValid,
	})
}
// login handles POST /api/session/login.
//
// It validates the submitted username and password against the stored admin
// credentials, persists a new session that expires in 30 days, and returns
// the session token in an HttpOnly, SameSite=Lax "token" cookie.
//
// Responses:
//   - 400 when the body is not valid JSON
//   - 401 when the credentials do not match or the hash comparison fails
//   - 403 when there is no config store or authentication is disabled
//   - 500 when the auth config cannot be read or the session cannot be stored
func (h *SessionHandler) login(ctx *fasthttp.RequestCtx) {
	if h.configStore == nil {
		SendError(ctx, fasthttp.StatusForbidden, "Authentication is not enabled")
		return
	}

	payload := struct {
		Username string `json:"username"`
		Password string `json:"password"`
	}{}
	if err := json.Unmarshal(ctx.PostBody(), &payload); err != nil {
		SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err))
		return
	}

	// Get auth config
	authConfig, err := h.configStore.GetAuthConfig(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get auth config: %v", err))
		return
	}
	// Check if auth is enabled
	if authConfig == nil || !authConfig.IsEnabled {
		SendError(ctx, fasthttp.StatusForbidden, "Authentication is not enabled")
		return
	}

	// Verify credentials. Username and password failures share one message so
	// the response does not reveal which of the two was wrong.
	if payload.Username != authConfig.AdminUserName.GetValue() {
		SendError(ctx, fasthttp.StatusUnauthorized, "Invalid username or password")
		return
	}
	compare, err := encrypt.CompareHash(authConfig.AdminPassword.GetValue(), payload.Password)
	if err != nil {
		SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized")
		return
	}
	if !compare {
		SendError(ctx, fasthttp.StatusUnauthorized, "Invalid username or password")
		return
	}

	// Take one timestamp so the stored session and the cookie expire at exactly
	// the same instant.
	now := time.Now()
	expiresAt := now.Add(30 * 24 * time.Hour) // 30 days

	// Creating a new session
	token := uuid.New().String()
	session := &tables.SessionsTable{
		Token:     token,
		ExpiresAt: expiresAt,
		CreatedAt: now,
		UpdatedAt: now,
	}
	if err = h.configStore.CreateSession(ctx, session); err != nil {
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create session: %v", err))
		return
	}

	// Setting cookies
	cookie := fasthttp.AcquireCookie()
	defer fasthttp.ReleaseCookie(cookie)
	cookie.SetKey("token")
	cookie.SetValue(token)
	cookie.SetExpire(expiresAt)
	cookie.SetPath("/")
	cookie.SetHTTPOnly(true)
	cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
	// Mark the cookie Secure when the request arrived over HTTPS: either this
	// server terminated TLS itself, or a proxy in front reports https (header
	// value compared case-insensitively).
	forwardedProto := strings.TrimSpace(string(ctx.Request.Header.Peek("X-Forwarded-Proto")))
	if ctx.IsTLS() || strings.EqualFold(forwardedProto, "https") {
		cookie.SetSecure(true)
	}
	ctx.Response.Header.SetCookie(cookie)

	SendJSON(ctx, map[string]any{
		"message": "Login successful",
	})
}
// logout handles POST /api/session/logout.
//
// It always instructs the client to drop the "token" cookie and, when a
// session token was supplied (Authorization header first, cookie second),
// deletes the matching session from the config store. A session that no longer
// exists is not treated as an error.
func (h *SessionHandler) logout(ctx *fasthttp.RequestCtx) {
	if h.configStore == nil {
		SendError(ctx, fasthttp.StatusForbidden, "Authentication is not enabled")
		return
	}

	// Resolve the session token: the Authorization header wins, the cookie is the fallback.
	sessionToken := strings.TrimPrefix(string(ctx.Request.Header.Peek("Authorization")), "Bearer ")
	if sessionToken == "" {
		sessionToken = string(ctx.Request.Header.Cookie("token"))
	}

	// Overwrite the client's cookie with an empty value that expired 30 days ago.
	expired := fasthttp.AcquireCookie()
	defer fasthttp.ReleaseCookie(expired)
	expired.SetKey("token")
	expired.SetValue("")
	expired.SetExpire(time.Now().Add(-time.Hour * 24 * 30))
	expired.SetPath("/")
	expired.SetHTTPOnly(true)
	expired.SetSameSite(fasthttp.CookieSameSiteLaxMode)
	// Mark the cookie Secure only when X-Forwarded-Proto says the request arrived over HTTPS.
	if proto := ctx.Request.Header.Peek("X-Forwarded-Proto"); string(proto) == "https" {
		expired.SetSecure(true)
	}
	ctx.Response.Header.SetCookie(expired)

	// Remove the server-side session, if the request identified one.
	if sessionToken != "" {
		if err := h.configStore.DeleteSession(ctx, sessionToken); err != nil && !errors.Is(err, configstore.ErrNotFound) {
			logger.Error("failed to delete session during logout: %v", err)
			SendError(ctx, fasthttp.StatusInternalServerError, "Failed to invalidate session. Please try again.")
			return
		}
	}

	SendJSON(ctx, map[string]any{
		"message": "Logout successful",
	})
}
// issueWSTicket handles POST /api/session/ws-ticket.
//
// It issues a short-lived, one-time-use ticket for WebSocket auth; the frontend
// passes it as ?ticket= when opening the WebSocket. The caller must already be
// authenticated (via cookie or Authorization header): the session token is read
// from the request's user values rather than from the request itself.
func (h *SessionHandler) issueWSTicket(ctx *fasthttp.RequestCtx) {
	if h.wsTicketStore == nil {
		SendError(ctx, fasthttp.StatusServiceUnavailable, "WebSocket tickets are not available")
		return
	}

	token, isString := ctx.UserValue(schemas.BifrostContextKeySessionToken).(string)
	switch {
	case !isString:
		// No session token value was attached to the request at all.
		SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized")
		return
	case token == "":
		// This is the case where auth is not configured or not enabled.
		token = "dummy-session"
	}

	ticket, err := h.wsTicketStore.Issue(token)
	if err != nil {
		logger.Error("failed to issue WS ticket: %v", err)
		SendError(ctx, fasthttp.StatusInternalServerError, "Failed to issue WebSocket ticket")
		return
	}

	SendJSON(ctx, map[string]any{
		"ticket": ticket,
	})
}

View File

@@ -0,0 +1,127 @@
package handlers
import (
"bufio"
"fmt"
"net"
"strings"
"testing"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// TestSSEStreamReaderNoEventBatching verifies that SSE events are delivered
// individually through fasthttp's chunked transfer encoding, not batched
// into larger TCP segments. This is the core regression test for the
// fasthttputil.PipeConns batching fix.
//
// The test serves a single request over an in-memory net.Pipe, then parses the
// raw HTTP response by hand so it can observe chunk boundaries directly: every
// HTTP chunk must carry exactly one of the events that were sent.
func TestSSEStreamReaderNoEventBatching(t *testing.T) {
const numEvents = 20
// Build expected events
events := make([]string, numEvents)
for i := range events {
events[i] = fmt.Sprintf("data: {\"index\":%d,\"content\":\"chunk-%d\"}\n\n", i, i)
}
// The handler streams all events from a separate goroutine through the SSE reader.
handler := func(ctx *fasthttp.RequestCtx) {
ctx.SetContentType("text/event-stream")
ctx.Response.Header.Set("Cache-Control", "no-cache")
reader := lib.NewSSEStreamReader()
go func() {
defer reader.Done()
for _, event := range events {
// Stop producing as soon as Send reports failure.
if !reader.Send([]byte(event)) {
return
}
}
}()
// A negative body size means the length is not known up front; the test
// below expects the body to arrive chunk-encoded.
ctx.Response.SetBodyStream(reader, -1)
}
// Use net.Pipe for deterministic in-process testing
serverConn, clientConn := net.Pipe()
defer clientConn.Close()
// Run fasthttp server on one end of the pipe
go func() {
_ = fasthttp.ServeConn(serverConn, handler)
}()
// Send HTTP request through the pipe
_, err := clientConn.Write([]byte("GET /stream HTTP/1.1\r\nHost: test\r\n\r\n"))
if err != nil {
t.Fatalf("failed to write request: %v", err)
}
// Read response using bufio to parse chunked encoding
br := bufio.NewReader(clientConn)
// Read and skip HTTP response headers
for {
line, err := br.ReadString('\n')
if err != nil {
t.Fatalf("failed to read response header: %v", err)
}
if strings.TrimSpace(line) == "" {
break // End of headers
}
}
// Read chunked transfer-encoded body.
// Each HTTP chunk should contain exactly one SSE event.
var receivedEvents []string
for {
// Read chunk size line (hex size + CRLF)
sizeLine, err := br.ReadString('\n')
if err != nil {
t.Fatalf("failed to read chunk size: %v", err)
}
sizeLine = strings.TrimSpace(sizeLine)
var chunkSize int
_, err = fmt.Sscanf(sizeLine, "%x", &chunkSize)
if err != nil {
t.Fatalf("failed to parse chunk size %q: %v", sizeLine, err)
}
if chunkSize == 0 {
break // Terminal chunk
}
// Read exactly chunkSize bytes + trailing CRLF
chunkData := make([]byte, chunkSize+2) // +2 for CRLF
n := 0
// br.Read may return fewer bytes than requested, so keep reading until the buffer is full.
for n < len(chunkData) {
nn, err := br.Read(chunkData[n:])
if err != nil {
t.Fatalf("failed to read chunk data: %v", err)
}
n += nn
}
// Drop the trailing CRLF; only the payload is compared against the sent event.
chunk := string(chunkData[:chunkSize])
receivedEvents = append(receivedEvents, chunk)
}
// Verify each chunk contains exactly one SSE event
if len(receivedEvents) != numEvents {
t.Errorf("expected %d individual chunks, got %d (events were batched)", numEvents, len(receivedEvents))
// Log how many events ended up in each chunk to make the batching visible.
for i, chunk := range receivedEvents {
eventCount := strings.Count(chunk, "\n\n")
t.Logf(" chunk %d: %d SSE events, %d bytes", i, eventCount, len(chunk))
}
}
// Compare chunk payloads with the events in send order (bounded by the number sent).
for i, chunk := range receivedEvents {
if i >= len(events) {
break
}
if chunk != events[i] {
t.Errorf("chunk %d: got %q, want %q", i, chunk, events[i])
}
}
}

View File

@@ -0,0 +1,136 @@
package handlers
import (
"embed"
"mime"
"path"
"path/filepath"
"strings"
"github.com/fasthttp/router"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// UIHandler handles UI routes.
// It serves the dashboard out of an embedded filesystem.
type UIHandler struct {
// uiContent is the embedded filesystem the dashboard files are read from;
// serveDashboard looks files up under its "ui/" prefix.
uiContent embed.FS
}
// NewUIHandler builds a UIHandler that serves files from uiContent.
func NewUIHandler(uiContent embed.FS) *UIHandler {
	handler := new(UIHandler)
	handler.uiContent = uiContent
	return handler
}
// RegisterRoutes wires the dashboard handler into the router for the root path
// and for every other GET path (catch-all), each wrapped in the given middlewares.
func (h *UIHandler) RegisterRoutes(router *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	for _, pattern := range []string{"/", "/{filepath:*}"} {
		router.GET(pattern, lib.ChainMiddlewares(h.serveDashboard, middlewares...))
	}
}
// serveDashboard serves the embedded dashboard UI.
//
// Resolution steps for a request path:
//  1. The path is cleaned; "/{page}.txt" is rewritten to "/{page}/index.txt".
//  2. The path is mapped into the embedded "ui/" tree ("/" becomes "ui/index.html").
//  3. Any path segment starting with "." and the files package.json and
//     package-lock.json are answered with 404.
//  4. If the file cannot be read: paths whose last segment contains a "." get a
//     404; all other paths try "{path}/index.html" and then fall back to
//     "ui/index.html" (SPA routing).
//
// The content type is derived from the extension of the file actually served,
// and Cache-Control depends on whether it is under ui/assets/, an .html file,
// or anything else.
func (h *UIHandler) serveDashboard(ctx *fasthttp.RequestCtx) {
	// notFound writes a plain 404 response with the given body.
	notFound := func(body string) {
		ctx.SetStatusCode(fasthttp.StatusNotFound)
		ctx.SetBodyString(body)
	}

	// Get the request path and clean it to prevent directory traversal.
	requestPath := string(ctx.Path())
	cleanPath := path.Clean(requestPath)

	// Handle .txt files - map from /{page}.txt to /{page}/index.txt.
	if strings.HasSuffix(cleanPath, ".txt") {
		basePath := strings.TrimSuffix(cleanPath, ".txt")
		if basePath == "/" || basePath == "" {
			basePath = "/index"
		}
		cleanPath = basePath + "/index.txt"
	}

	// Map the URL path into the embedded tree.
	if cleanPath == "/" {
		cleanPath = "ui/index.html"
	} else {
		cleanPath = "ui" + cleanPath
	}

	// Block hidden directories and files (any path segment starting with ".").
	for _, segment := range strings.Split(cleanPath, "/") {
		if strings.HasPrefix(segment, ".") {
			notFound("404 - Not found")
			return
		}
	}

	// Block sensitive files.
	baseName := filepath.Base(cleanPath)
	switch baseName {
	case "package.json", "package-lock.json":
		notFound("404 - Not found")
		return
	}

	// A "." in the last segment marks the request as a static asset request.
	hasExtension := strings.Contains(baseName, ".")

	// Try to read the file from the embedded filesystem.
	data, err := h.uiContent.ReadFile(cleanPath)
	if err != nil {
		// A missing static asset is a plain 404; it never falls back to the SPA shell.
		if hasExtension {
			notFound("404 - Static asset not found: " + requestPath)
			return
		}

		// Extension-less route (SPA routing): try {path}/index.html first,
		// then serve the root index.html as the fallback.
		indexPath := cleanPath + "/index.html"
		if data, err = h.uiContent.ReadFile(indexPath); err == nil {
			cleanPath = indexPath
		} else {
			data, err = h.uiContent.ReadFile("ui/index.html")
			if err != nil {
				notFound("404 - File not found")
				return
			}
			cleanPath = "ui/index.html"
		}
	}

	// Set content type based on the extension of the file that is being served.
	ext := filepath.Ext(cleanPath)
	contentType := mime.TypeByExtension(ext)
	if contentType == "" {
		contentType = "application/octet-stream"
	}
	ctx.SetContentType(contentType)

	// Set cache headers: long-lived for ui/assets/, revalidate for HTML, one hour otherwise.
	switch {
	case strings.HasPrefix(cleanPath, "ui/assets/"):
		ctx.Response.Header.Set("Cache-Control", "public, max-age=31536000, immutable")
	case ext == ".html":
		ctx.Response.Header.Set("Cache-Control", "no-cache")
	default:
		ctx.Response.Header.Set("Cache-Control", "public, max-age=3600")
	}

	// Send the file content.
	ctx.SetBody(data)
}

View File

@@ -0,0 +1,225 @@
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
// This file contains common utility functions used across all handlers.
package handlers
import (
"encoding/json"
"fmt"
"regexp"
"strings"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// pluginDisabledKey is a dedicated context key type for marking a plugin as disabled
// rather than removed. Using a named type instead of a raw string follows Go best practices:
// a key of an unexported type cannot collide with keys defined by other packages.
type pluginDisabledKey struct{}
// PluginDisabledKey is the context key used to indicate a plugin is being disabled.
// It is the zero value of pluginDisabledKey.
var PluginDisabledKey pluginDisabledKey
// badRequestError marks an error as caused by invalid client input, so that
// outer handlers can tell it apart from internal server errors and answer
// with HTTP 400.
type badRequestError struct {
	err error
}

// Error returns the message of the wrapped error unchanged.
func (b *badRequestError) Error() string {
	return b.err.Error()
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (b *badRequestError) Unwrap() error {
	return b.err
}
// SendJSON writes data as a JSON response body. It does not set a status code
// itself, so the response keeps whatever status the context already has.
// If encoding fails, the failure is logged and reported to the client as a 500 error.
func SendJSON(ctx *fasthttp.RequestCtx, data interface{}) {
	ctx.SetContentType("application/json")

	encodeErr := json.NewEncoder(ctx).Encode(data)
	if encodeErr == nil {
		return
	}

	logger.Warn(fmt.Sprintf("Failed to encode JSON response: %v", encodeErr))
	SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to encode response: %v", encodeErr))
}
// SendJSONWithStatus writes data as a JSON response body with the given status code.
// If encoding fails, the failure is logged and reported to the client as a 500 error.
func SendJSONWithStatus(ctx *fasthttp.RequestCtx, data interface{}, statusCode int) {
	ctx.SetContentType("application/json")
	ctx.SetStatusCode(statusCode)

	encodeErr := json.NewEncoder(ctx).Encode(data)
	if encodeErr == nil {
		return
	}

	logger.Warn(fmt.Sprintf("Failed to encode JSON response: %v", encodeErr))
	SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to encode response: %v", encodeErr))
}
// SendError wraps message and statusCode in a BifrostError (flagged as not
// originating from Bifrost itself) and sends it via SendBifrostError.
func SendError(ctx *fasthttp.RequestCtx, statusCode int, message string) {
	details := &schemas.ErrorField{Message: message}
	SendBifrostError(ctx, &schemas.BifrostError{
		IsBifrostError: false,
		StatusCode:     &statusCode,
		Error:          details,
	})
}
// SendBifrostError writes bifrostErr as a JSON error response.
//
// The HTTP status is the error's own StatusCode when set; otherwise it is
// 400 for errors not flagged as Bifrost errors and 500 for Bifrost errors.
// If the error itself cannot be encoded, a plain-text 500 is sent instead.
func SendBifrostError(ctx *fasthttp.RequestCtx, bifrostErr *schemas.BifrostError) {
	status := fasthttp.StatusInternalServerError
	switch {
	case bifrostErr.StatusCode != nil:
		status = *bifrostErr.StatusCode
	case !bifrostErr.IsBifrostError:
		status = fasthttp.StatusBadRequest
	}
	ctx.SetStatusCode(status)
	ctx.SetContentType("application/json")

	encodeErr := json.NewEncoder(ctx).Encode(bifrostErr)
	if encodeErr == nil {
		return
	}

	logger.Warn(fmt.Sprintf("Failed to encode error response: %v", encodeErr))
	ctx.SetStatusCode(fasthttp.StatusInternalServerError)
	ctx.SetBodyString(fmt.Sprintf("Failed to encode error response: %v", encodeErr))
}
// streamLargeResponseIfActive streams the response body directly to the client
// when the large-response-mode flag is set to true on bifrostCtx.
//
// It returns false when the flag is absent, not a bool, or false, in which case
// the caller continues with normal response handling. It returns true when the
// response has been handled here (the caller should return); that includes the
// case where streaming could not start and a 500 error was sent instead.
func streamLargeResponseIfActive(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext) bool {
	if active, isBool := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); !isBool || !active {
		return false
	}

	if streamed := lib.StreamLargeResponseBody(ctx, bifrostCtx); !streamed {
		SendError(ctx, fasthttp.StatusInternalServerError, "Large response reader not available")
	}
	return true
}
// SendSSEError writes bifrostErr to the response as a single Server-Sent Events
// "data:" frame whose payload is the JSON object {"error": <bifrostErr>}.
// If the error cannot be marshaled, the status is set to 500 and nothing is
// written; a failed write is only logged.
func SendSSEError(ctx *fasthttp.RequestCtx, bifrostErr *schemas.BifrostError) {
	envelope := map[string]interface{}{
		"error": bifrostErr,
	}
	payload, err := json.Marshal(envelope)
	if err != nil {
		logger.Error("failed to marshal error for SSE: %v", err)
		ctx.SetStatusCode(fasthttp.StatusInternalServerError)
		return
	}

	if _, writeErr := fmt.Fprintf(ctx, "data: %s\n\n", payload); writeErr != nil {
		logger.Warn(fmt.Sprintf("Failed to write SSE error: %v", writeErr))
	}
}
// IsOriginAllowed reports whether origin may access the server.
//
// Localhost origins are always allowed. Otherwise origin is allowed when any
// entry of allowedOrigins equals it exactly, is the catch-all "*", or is a
// wildcard pattern (for example https://*.example.com) that matches it.
func IsOriginAllowed(origin string, allowedOrigins []string) bool {
	if isLocalhostOrigin(origin) {
		return true
	}

	for _, candidate := range allowedOrigins {
		switch {
		case candidate == origin, candidate == "*":
			return true
		case strings.Contains(candidate, "*") && matchesWildcardPattern(origin, candidate):
			return true
		}
	}
	return false
}
// isLocalhostOrigin reports whether origin starts with one of the known
// loopback origin prefixes. Every prefix ends in ":", so only origins that
// carry an explicit port match.
func isLocalhostOrigin(origin string) bool {
	loopbackPrefixes := [...]string{
		"http://localhost:",
		"https://localhost:",
		"http://127.0.0.1:",
		"http://0.0.0.0:",
		"https://127.0.0.1:",
	}
	for _, prefix := range loopbackPrefixes {
		if strings.HasPrefix(origin, prefix) {
			return true
		}
	}
	return false
}
// matchesWildcardPattern reports whether origin matches a wildcard pattern such
// as *.example.com, https://*.example.com, or http://*.example.com.
//
// Each "*" stands for one or more characters other than "/" and ".", i.e. a
// single subdomain label. Everything else in the pattern is matched literally,
// and the pattern must cover the whole origin.
func matchesWildcardPattern(origin string, pattern string) bool {
	// Escape every regex metacharacter, then turn the escaped "*" back into a
	// single-label matcher and anchor the expression at both ends.
	literal := regexp.QuoteMeta(pattern)
	expr := "^" + strings.ReplaceAll(literal, `\*`, `[^/.]+`) + "$"

	matched, err := regexp.MatchString(expr, origin)
	return err == nil && matched
}
// ParseModel splits a model string of the form "provider/model" at the first
// slash. Anything after that slash, including further slashes
// ("provider/nested/model"), belongs to the model name.
//
// Surrounding whitespace is trimmed from the input and from both parts. An
// error is returned when the input is empty, contains no slash, or has an
// empty provider or model part.
func ParseModel(model string) (string, string, error) {
	trimmed := strings.TrimSpace(model)
	if trimmed == "" {
		return "", "", fmt.Errorf("model cannot be empty")
	}

	rawProvider, rawName, hasSlash := strings.Cut(trimmed, "/")
	if !hasSlash {
		return "", "", fmt.Errorf("model must be in the format 'provider/model'")
	}

	provider, name := strings.TrimSpace(rawProvider), strings.TrimSpace(rawName)
	if provider == "" || name == "" {
		return "", "", fmt.Errorf("model must be in the format 'provider/model' with non-empty provider and model")
	}
	return provider, name, nil
}
// ClampPaginationParams normalizes pagination input so that the handler
// response matches the values the store actually uses: a non-positive limit
// becomes the default of 25, a limit above 100 is capped at 100, and a
// negative offset becomes 0. It returns the adjusted limit and offset.
func ClampPaginationParams(limit, offset int) (int, int) {
	const (
		defaultLimit = 25
		maxLimit     = 100
	)

	switch {
	case limit <= 0:
		limit = defaultLimit
	case limit > maxLimit:
		limit = maxLimit
	}

	if offset < 0 {
		offset = 0
	}
	return limit, offset
}
// fuzzyMatch reports whether every character of query occurs in text in the
// same order, ignoring case; other characters may sit in between.
// An empty query matches any text.
// Example: "gpt4" matches "gpt-4", "gpt-4-turbo", etc.
func fuzzyMatch(text, query string) bool {
	if query == "" {
		return true
	}

	// remaining holds the query characters that still have to be found, in order.
	remaining := []rune(strings.ToLower(query))
	for _, r := range strings.ToLower(text) {
		if r != remaining[0] {
			continue
		}
		remaining = remaining[1:]
		if len(remaining) == 0 {
			return true
		}
	}
	return false
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,346 @@
package handlers
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/kvstore"
"github.com/maximhq/bifrost/framework/logstore"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
"github.com/valyala/fasthttp"
)
// testHandlerStore is a minimal handler-store stub for the tests in this file.
// Only the KV store is configurable; every other accessor returns a fixed value.
// NOTE(review): presumably this satisfies lib.HandlerStore — confirm against the interface definition.
type testHandlerStore struct {
// kv is the store returned by GetKVStore.
kv *kvstore.Store
}
// ShouldAllowDirectKeys always reports true.
func (s testHandlerStore) ShouldAllowDirectKeys() bool { return true }
// GetHeaderMatcher always returns nil.
func (s testHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher { return nil }
// GetAvailableProviders always returns nil.
func (s testHandlerStore) GetAvailableProviders() []schemas.ModelProvider { return nil }
// GetStreamChunkInterceptor always returns nil.
func (s testHandlerStore) GetStreamChunkInterceptor() lib.StreamChunkInterceptor {
return nil
}
// GetAsyncJobExecutor always returns nil.
func (s testHandlerStore) GetAsyncJobExecutor() *logstore.AsyncJobExecutor { return nil }
// GetAsyncJobResultTTL always returns 0.
func (s testHandlerStore) GetAsyncJobResultTTL() int { return 0 }
// GetKVStore returns the store the stub was constructed with.
func (s testHandlerStore) GetKVStore() *kvstore.Store { return s.kv }
// GetMCPHeaderCombinedAllowlist always returns nil.
func (s testHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList { return nil }
// TestResolveRealtimeSDPTarget_BaseRouteRequiresProviderPrefix checks that a
// session model without a provider prefix is rejected on the /v1/realtime
// route with the provider/model validation message.
func TestResolveRealtimeSDPTarget_BaseRouteRequiresProviderPrefix(t *testing.T) {
	const wantMessage = "session.model must use provider/model on /v1 realtime routes"
	sessionJSON := []byte(`{"model":"gpt-4o-realtime-preview"}`)

	_, _, _, bifrostErr := resolveRealtimeSDPTarget("/v1/realtime", sessionJSON)
	if bifrostErr == nil {
		t.Fatal("expected provider/model validation error")
	}
	if bifrostErr.Error == nil || bifrostErr.Error.Message != wantMessage {
		t.Fatalf("unexpected error: %#v", bifrostErr)
	}
}
// TestResolveRealtimeSDPTarget_BaseRouteNormalizesModel checks that a
// "provider/model" session model on /v1/realtime is split into provider and
// model, and that the returned session JSON carries the model without the
// provider prefix.
func TestResolveRealtimeSDPTarget_BaseRouteNormalizesModel(t *testing.T) {
	const wantModel = "gpt-4o-realtime-preview"
	sessionJSON := []byte(`{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}`)

	provider, model, normalized, err := resolveRealtimeSDPTarget("/v1/realtime", sessionJSON)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if provider != schemas.OpenAI {
		t.Fatalf("expected provider %s, got %s", schemas.OpenAI, provider)
	}
	if model != wantModel {
		t.Fatalf("unexpected normalized model: %s", model)
	}

	// Decode the normalized session and inspect its "model" field.
	var fields map[string]json.RawMessage
	if decodeErr := json.Unmarshal(normalized, &fields); decodeErr != nil {
		t.Fatalf("failed to unmarshal normalized session: %v", decodeErr)
	}
	var marshaledModel string
	if decodeErr := json.Unmarshal(fields["model"], &marshaledModel); decodeErr != nil {
		t.Fatalf("failed to unmarshal model: %v", decodeErr)
	}
	if marshaledModel != wantModel {
		t.Fatalf("unexpected marshaled model: %s", marshaledModel)
	}
}
// TestResolveRealtimeSDPTarget_OpenAIRouteDefaultsProvider checks that on the
// /openai/v1/realtime route a bare model name is accepted and resolved to the
// OpenAI provider.
func TestResolveRealtimeSDPTarget_OpenAIRouteDefaultsProvider(t *testing.T) {
	sessionJSON := []byte(`{"model":"gpt-4o-realtime-preview"}`)

	gotProvider, gotModel, _, err := resolveRealtimeSDPTarget("/openai/v1/realtime", sessionJSON)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if gotProvider != schemas.OpenAI {
		t.Fatalf("expected provider %s, got %s", schemas.OpenAI, gotProvider)
	}
	if gotModel != "gpt-4o-realtime-preview" {
		t.Fatalf("unexpected model: %s", gotModel)
	}
}
// TestParseCallsWebRTCRequest_RawSDPKeepsGARoute checks that a raw
// application/sdp POST to /openai/v1/realtime/calls yields the body as the SDP
// offer, takes the model from the query string, resolves to the OpenAI
// provider, and produces no session payload.
func TestParseCallsWebRTCRequest_RawSDPKeepsGARoute(t *testing.T) {
	const rawSDP = "v=0\r\n"

	reqCtx := new(fasthttp.RequestCtx)
	reqCtx.Request.Header.SetMethod(fasthttp.MethodPost)
	reqCtx.Request.SetRequestURI("/openai/v1/realtime/calls?model=gpt-realtime")
	reqCtx.Request.Header.SetContentType("application/sdp")
	reqCtx.Request.SetBodyString(rawSDP)

	sdpOffer, provider, model, session, err := parseCallsWebRTCRequest(reqCtx)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if sdpOffer != rawSDP {
		t.Fatalf("unexpected sdp offer: %q", sdpOffer)
	}
	if provider != schemas.OpenAI {
		t.Fatalf("expected provider %s, got %s", schemas.OpenAI, provider)
	}
	if model != "gpt-realtime" {
		t.Fatalf("unexpected model: %s", model)
	}
	if session != nil {
		t.Fatalf("expected nil session for raw SDP /calls request, got %s", string(session))
	}
}
// TestNewRealtimeRelayContextCopiesValuesWithoutRequestCancellation checks that
// the relay context created from a request context keeps the request's values
// (request type, integration type, virtual key ID) but is not cancelled when
// the request context is.
func TestNewRealtimeRelayContextCopiesValuesWithoutRequestCancellation(t *testing.T) {
requestCtx, requestCancel := schemas.NewBifrostContextWithCancel(context.Background())
requestCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest)
requestCtx.SetValue(schemas.BifrostContextKeyIntegrationType, "openai")
requestCtx.SetValue(schemas.BifrostContextKeyGovernanceVirtualKeyID, "vk_test")
relayCtx, relayCancel := newRealtimeRelayContext(requestCtx)
defer relayCancel()
// Cancel the originating request only after the relay context exists.
requestCancel()
// Wait (up to one second) for the request context to observe the cancellation.
select {
case <-requestCtx.Done():
case <-time.After(time.Second):
t.Fatal("expected request context to be cancelled")
}
// Non-blocking check: the relay context must still be live at this point.
select {
case <-relayCtx.Done():
t.Fatal("relay context should outlive cancelled request context")
default:
}
// The values set on the request context must be readable from the relay context.
if got := relayCtx.Value(schemas.BifrostContextKeyHTTPRequestType); got != schemas.RealtimeRequest {
t.Fatalf("request type = %v, want %v", got, schemas.RealtimeRequest)
}
if got := relayCtx.Value(schemas.BifrostContextKeyIntegrationType); got != "openai" {
t.Fatalf("integration type = %v, want %q", got, "openai")
}
if got := relayCtx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID); got != "vk_test" {
t.Fatalf("virtual key id = %v, want %q", got, "vk_test")
}
}
// TestParseRealtimeEventPreservesExtraParams checks that the item_id and
// content_index fields of a conversation.item.truncate event are kept in the
// parsed event's ExtraParams and decode back to their original values.
func TestParseRealtimeEventPreservesExtraParams(t *testing.T) {
	rawEvent := []byte(`{"type":"conversation.item.truncate","item_id":"item_123","content_index":0,"audio_end_ms":640}`)

	event, err := schemas.ParseRealtimeEvent(rawEvent)
	if err != nil {
		t.Fatalf("ParseRealtimeEvent() error = %v", err)
	}

	var itemID string
	if decodeErr := json.Unmarshal(event.ExtraParams["item_id"], &itemID); decodeErr != nil {
		t.Fatalf("json.Unmarshal(item_id) error = %v", decodeErr)
	}
	if itemID != "item_123" {
		t.Fatalf("item_id = %q, want %q", itemID, "item_123")
	}

	var contentIndex int
	if decodeErr := json.Unmarshal(event.ExtraParams["content_index"], &contentIndex); decodeErr != nil {
		t.Fatalf("json.Unmarshal(content_index) error = %v", decodeErr)
	}
	if contentIndex != 0 {
		t.Fatalf("content_index = %d, want 0", contentIndex)
	}
}
// TestExtractRealtimeBearerToken checks that the token is returned without the
// "Bearer " prefix of the Authorization header.
func TestExtractRealtimeBearerToken(t *testing.T) {
	reqCtx := new(fasthttp.RequestCtx)
	reqCtx.Request.Header.Set("Authorization", "Bearer ek_test_123")

	got := extractRealtimeBearerToken(reqCtx)
	if got != "ek_test_123" {
		t.Fatalf("extractRealtimeBearerToken() = %q, want %q", got, "ek_test_123")
	}
}
// TestLookupRealtimeEphemeralKeyMappingKeepsEntryUntilTTLExpiry checks that
// looking up an ephemeral key mapping returns the stored key ID and virtual
// key, and that the lookup does not remove the entry from the KV store: the
// entry must still be readable afterwards (it was stored with a one-minute TTL).
func TestLookupRealtimeEphemeralKeyMappingKeepsEntryUntilTTLExpiry(t *testing.T) {
	t.Parallel()

	store, err := kvstore.New(kvstore.Config{})
	if err != nil {
		t.Fatalf("kvstore.New() error = %v", err)
	}
	defer store.Close()

	// Seed the store with a JSON-encoded mapping for the ephemeral token.
	payload, err := json.Marshal(realtimeEphemeralKeyMapping{KeyID: "key_123", VirtualKey: "sk-bf-test"})
	if err != nil {
		t.Fatalf("json.Marshal() error = %v", err)
	}
	if err := store.SetWithTTL(buildRealtimeEphemeralKeyMappingKey("ek_test_123"), payload, time.Minute); err != nil {
		t.Fatalf("store.SetWithTTL() error = %v", err)
	}

	mapping, ok := lookupRealtimeEphemeralKeyMapping(store, "ek_test_123")
	if !ok {
		// The lookup is expected to find the entry, not consume it.
		t.Fatal("expected mapping to be found")
	}
	if mapping.KeyID != "key_123" {
		t.Fatalf("mapping.KeyID = %q, want %q", mapping.KeyID, "key_123")
	}
	if mapping.VirtualKey != "sk-bf-test" {
		t.Fatalf("mapping.VirtualKey = %q, want %q", mapping.VirtualKey, "sk-bf-test")
	}

	// The entry must survive the lookup.
	raw, err := store.Get(buildRealtimeEphemeralKeyMappingKey("ek_test_123"))
	if err != nil {
		t.Fatalf("expected mapping to remain until TTL expiry: %v", err)
	}
	if raw == nil {
		t.Fatal("expected mapping to remain in KV store")
	}
}
// TestLookupRealtimeEphemeralKeyMapping_BackwardsCompatibleStringValue checks
// that an entry stored as a plain string (rather than a JSON-encoded mapping)
// is still resolved: the string becomes the KeyID and VirtualKey stays empty.
func TestLookupRealtimeEphemeralKeyMapping_BackwardsCompatibleStringValue(t *testing.T) {
	t.Parallel()

	store, err := kvstore.New(kvstore.Config{})
	if err != nil {
		t.Fatalf("kvstore.New() error = %v", err)
	}
	defer store.Close()

	// Seed the store with the legacy plain-string value.
	if err := store.SetWithTTL(buildRealtimeEphemeralKeyMappingKey("ek_test_legacy"), "key_legacy", time.Minute); err != nil {
		t.Fatalf("store.SetWithTTL() error = %v", err)
	}

	mapping, ok := lookupRealtimeEphemeralKeyMapping(store, "ek_test_legacy")
	if !ok {
		t.Fatal("expected legacy mapping to be found")
	}
	if mapping.KeyID != "key_legacy" {
		t.Fatalf("mapping.KeyID = %q, want %q", mapping.KeyID, "key_legacy")
	}
	if mapping.VirtualKey != "" {
		t.Fatalf("mapping.VirtualKey = %q, want empty", mapping.VirtualKey)
	}
}
// TestWebRTCRealtimeRelayCloseFinalizesActiveTurnHooks checks that closing a
// WebRTC realtime relay while a turn is still active runs the turn's post-hook
// with a realtime "session closed" error, clears the turn hooks from the
// session, and invokes the hook cleanup callback.
func TestWebRTCRealtimeRelayCloseFinalizesActiveTurnHooks(t *testing.T) {
t.Parallel()
session := bfws.NewSession(nil)
session.SetProviderSessionID("sess_provider_123")
session.AddRealtimeInput("hello from user", `{"type":"conversation.item.added"}`)
// Captured by the hook callbacks below so the test can inspect what close() did.
var (
capturedErr *schemas.BifrostError
cleanedUp bool
)
session.SetRealtimeTurnHooks(&bfws.RealtimeTurnPluginState{
RequestID: "req_realtime_123",
StartedAt: time.Now().Add(-time.Second),
// Record the error handed to the post-hook and pass the result through unchanged.
PostHookRunner: func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
capturedErr = err
return result, nil
},
Cleanup: func() {
cleanedUp = true
},
})
relay := &webrtcRealtimeRelay{
session: session,
providerKey: schemas.OpenAI,
model: "gpt-realtime",
}
relay.close()
if capturedErr == nil {
t.Fatal("expected active turn to be finalized with an error on close")
}
if capturedErr.ExtraFields.RequestType != schemas.RealtimeRequest {
t.Fatalf("request type = %q, want %q", capturedErr.ExtraFields.RequestType, schemas.RealtimeRequest)
}
if capturedErr.Error == nil || capturedErr.Error.Message != "realtime WebRTC session closed before turn completed" {
t.Fatalf("error message = %#v, want realtime close message", capturedErr.Error)
}
// After close, the session must no longer hold turn hooks and cleanup must have run.
if session.PeekRealtimeTurnHooks() != nil {
t.Fatal("expected active realtime turn hooks to be cleared")
}
if !cleanedUp {
t.Fatal("expected realtime hook cleanup to run")
}
}
// TestResolveRealtimeWebRTCKeys_UnmappedEphemeralTokenStaysAnonymous checks
// that a bearer token with no mapping in the KV store is used as the auth key
// value as-is, that no selected key is returned, and that the key-related
// values previously set on the context (direct key, selected key ID/name,
// API key ID/name) read back as nil afterwards.
func TestResolveRealtimeWebRTCKeys_UnmappedEphemeralTokenStaysAnonymous(t *testing.T) {
t.Parallel()
// Empty KV store: the token below has no mapping.
store, err := kvstore.New(kvstore.Config{})
if err != nil {
t.Fatalf("kvstore.New() error = %v", err)
}
defer store.Close()
handler := &WebRTCRealtimeHandler{
handlerStore: testHandlerStore{kv: store},
}
var ctx fasthttp.RequestCtx
ctx.Request.Header.Set("Authorization", "Bearer ek_test_unmapped")
// Pre-populate the context with key information that must not survive resolution.
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, schemas.Key{ID: "header-provided"})
bifrostCtx.SetValue(schemas.BifrostContextKeySelectedKeyID, "selected")
bifrostCtx.SetValue(schemas.BifrostContextKeySelectedKeyName, "selected-name")
bifrostCtx.SetValue(schemas.BifrostContextKeyAPIKeyID, "mapped-id")
bifrostCtx.SetValue(schemas.BifrostContextKeyAPIKeyName, "mapped-name")
authKey, selectedKey, err := handler.resolveRealtimeWebRTCKeys(&ctx, bifrostCtx, schemas.OpenAI, "gpt-realtime")
if err != nil {
t.Fatalf("resolveRealtimeWebRTCKeys() error = %v", err)
}
if got := authKey.Value.GetValue(); got != "ek_test_unmapped" {
t.Fatalf("auth key value = %q, want %q", got, "ek_test_unmapped")
}
if selectedKey != nil {
t.Fatalf("selectedKey = %#v, want nil", selectedKey)
}
// Every key-related context value set above must now read back as nil.
if got := bifrostCtx.Value(schemas.BifrostContextKeyDirectKey); got != nil {
t.Fatalf("direct key context = %#v, want nil", got)
}
if got := bifrostCtx.Value(schemas.BifrostContextKeySelectedKeyID); got != nil {
t.Fatalf("selected key id context = %#v, want nil", got)
}
if got := bifrostCtx.Value(schemas.BifrostContextKeySelectedKeyName); got != nil {
t.Fatalf("selected key name context = %#v, want nil", got)
}
if got := bifrostCtx.Value(schemas.BifrostContextKeyAPIKeyID); got != nil {
t.Fatalf("api key id context = %#v, want nil", got)
}
if got := bifrostCtx.Value(schemas.BifrostContextKeyAPIKeyName); got != nil {
t.Fatalf("api key name context = %#v, want nil", got)
}
}
func TestApplyRealtimeEphemeralKeyMapping_RestoresVirtualKeyAndKeyID(t *testing.T) {
t.Parallel()
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
applyRealtimeEphemeralKeyMapping(bifrostCtx, realtimeEphemeralKeyMapping{
KeyID: "key_123",
VirtualKey: "sk-bf-test",
})
if got := bifrostCtx.Value(schemas.BifrostContextKeyVirtualKey); got != "sk-bf-test" {
t.Fatalf("virtual key context = %#v, want %q", got, "sk-bf-test")
}
if got := bifrostCtx.Value(schemas.BifrostContextKeyAPIKeyID); got != "key_123" {
t.Fatalf("api key id context = %#v, want %q", got, "key_123")
}
}

View File

@@ -0,0 +1,268 @@
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
// This file contains WebSocket handlers for real-time log streaming.
package handlers
import (
"context"
"strings"
"sync"
"time"
"github.com/bytedance/sonic"
"github.com/fasthttp/router"
"github.com/fasthttp/websocket"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// WebSocketClient represents a connected WebSocket client with its own mutex
type WebSocketClient struct {
// conn is the underlying WebSocket connection.
conn *websocket.Conn
// mu is held by sendMessageSafely for the duration of each write to conn.
mu sync.Mutex // Per-connection mutex for thread-safe writes
}
// WebSocketHandler manages WebSocket connections for real-time updates
type WebSocketHandler struct {
// ctx is the context passed to NewWebSocketHandler.
ctx context.Context
// allowedOrigins lists the origins accepted by the upgrader's origin check,
// in addition to localhost origins.
allowedOrigins []string
// clients holds every currently registered connection, keyed by the connection itself.
clients map[*websocket.Conn]*WebSocketClient
// mu guards clients.
mu sync.RWMutex
stopChan chan struct{} // Channel to signal heartbeat goroutine to stop
done chan struct{} // Channel to signal when heartbeat goroutine has stopped
}
// NewWebSocketHandler returns a WebSocketHandler with an empty client registry
// and freshly created stop/done channels. ctx and allowedOrigins are stored
// on the handler as given.
func NewWebSocketHandler(ctx context.Context, allowedOrigins []string) *WebSocketHandler {
	handler := &WebSocketHandler{
		ctx:            ctx,
		allowedOrigins: allowedOrigins,
	}
	handler.clients = make(map[*websocket.Conn]*WebSocketClient)
	handler.stopChan = make(chan struct{})
	handler.done = make(chan struct{})
	return handler
}
// RegisterRoutes registers all WebSocket-related routes: GET /ws is served by
// connectStream wrapped in the given middlewares.
func (h *WebSocketHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	streamHandler := lib.ChainMiddlewares(h.connectStream, middlewares...)
	r.GET("/ws", streamHandler)
}
// getUpgrader returns a WebSocket upgrader (1 KiB read and write buffers)
// whose origin check uses the handler's current allowed origins.
//
// Requests that carry an Origin header are accepted when IsOriginAllowed
// approves the origin (localhost is always allowed, plus the configured
// origins). Requests without an Origin header are treated as direct
// connections and accepted only when the Host header is localhost.
func (h *WebSocketHandler) getUpgrader() websocket.FastHTTPUpgrader {
	checkOrigin := func(ctx *fasthttp.RequestCtx) bool {
		if origin := string(ctx.Request.Header.Peek("Origin")); origin != "" {
			return IsOriginAllowed(origin, h.allowedOrigins)
		}
		// No Origin header: fall back to the Host header.
		return isLocalhost(string(ctx.Request.Header.Peek("Host")))
	}

	return websocket.FastHTTPUpgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		CheckOrigin:     checkOrigin,
	}
}
// isLocalhost reports whether host, a Host header value with an optional port,
// refers to the local machine: "localhost", "127.0.0.1", "::1", or an empty host.
//
// IPv6 literals are accepted both in the bracketed form used in Host headers
// ("[::1]" and "[::1]:8080") and as a bare "::1". Cutting at the last colon
// alone would turn "[::1]:8080" into "[::1]" and "::1" into "::", neither of
// which equals "::1", so those forms are handled before the port is stripped.
func isLocalhost(host string) bool {
	isLoopbackName := func(name string) bool {
		switch name {
		case "localhost", "127.0.0.1", "::1", "":
			return true
		}
		return false
	}

	// Bracketed IPv6 literal: "[addr]" optionally followed by ":port".
	if strings.HasPrefix(host, "[") {
		end := strings.Index(host, "]")
		if end == -1 {
			return false
		}
		if rest := host[end+1:]; rest != "" && !strings.HasPrefix(rest, ":") {
			return false
		}
		return host[1:end] == "::1"
	}

	// Host without a port (this also covers a bare "::1").
	if isLoopbackName(host) {
		return true
	}

	// Remove port if present and check again.
	if idx := strings.LastIndex(host, ":"); idx != -1 {
		return isLoopbackName(host[:idx])
	}
	return false
}
// connectStream upgrades the request to a WebSocket connection and keeps it
// registered in h.clients until the client disconnects or a read fails.
//
// The connection is read-only from the server's point of view: incoming
// messages are read and discarded. Reading is still required so that control
// frames (pings, pongs, close) are processed and a disconnect is noticed.
func (h *WebSocketHandler) connectStream(ctx *fasthttp.RequestCtx) {
	// The read deadline is set when the connection opens and pushed out again on every pong.
	const readTimeout = 60 * time.Second

	serve := func(ws *websocket.Conn) {
		// Read safety & liveness
		ws.SetReadLimit(50 << 20) // 50 MiB
		ws.SetReadDeadline(time.Now().Add(readTimeout))
		ws.SetPongHandler(func(string) error {
			ws.SetReadDeadline(time.Now().Add(readTimeout))
			return nil
		})

		// Register the new client, with its own write mutex.
		h.mu.Lock()
		h.clients[ws] = &WebSocketClient{conn: ws}
		h.mu.Unlock()

		// Unregister and close on disconnect.
		defer func() {
			h.mu.Lock()
			delete(h.clients, ws)
			h.mu.Unlock()
			ws.Close()
		}()

		// Drain incoming messages until the first read error.
		for {
			if _, _, err := ws.ReadMessage(); err != nil {
				// Only log unexpected close errors
				if websocket.IsUnexpectedCloseError(err,
					websocket.CloseNormalClosure,
					websocket.CloseGoingAway,
					websocket.CloseAbnormalClosure,
					websocket.CloseNoStatusReceived) {
					logger.Error("websocket read error: %v", err)
				}
				return
			}
		}
	}

	upgrader := h.getUpgrader()
	if err := upgrader.Upgrade(ctx, serve); err != nil {
		logger.Error("websocket upgrade error: %v", err)
	}
}
// sendMessageSafely writes one frame to client while holding the client's
// mutex and a 10-second write deadline. When the write fails, the client is
// unregistered and its connection closed in a separate goroutine, and the
// write error is returned.
func (h *WebSocketHandler) sendMessageSafely(client *WebSocketClient, messageType int, data []byte) error {
	client.mu.Lock()
	defer client.mu.Unlock()

	// Bound the write so a stalled peer cannot block the caller indefinitely;
	// the deadline is cleared again on return.
	client.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
	defer client.conn.SetWriteDeadline(time.Time{})

	writeErr := client.conn.WriteMessage(messageType, data)
	if writeErr == nil {
		return nil
	}

	// Drop the failed client asynchronously.
	go func() {
		h.mu.Lock()
		delete(h.clients, client.conn)
		h.mu.Unlock()
		client.conn.Close()
	}()
	return writeErr
}
// BroadcastUpdatesToClients notifies every connected WebSocket client that the
// given store tags changed. The payload is {"type":"store_update","tags":[...]}.
// Tags are expected to match RTK Query tagTypes (e.g., "Providers", "VirtualKeys", "MCPClients").
func (h *WebSocketHandler) BroadcastUpdatesToClients(tags []string) {
	type storeUpdate struct {
		Type string   `json:"type"`
		Tags []string `json:"tags"`
	}

	payload, err := sonic.Marshal(storeUpdate{Type: "store_update", Tags: tags})
	if err != nil {
		logger.Error("failed to marshal store update: %v", err)
		return
	}
	h.BroadcastMarshaledMessage(payload)
}
// BroadcastEvent pushes a typed event to every connected WebSocket client.
// The payload is {"type": eventType, "data": data}; marshalling failures are
// logged and nothing is sent.
func (h *WebSocketHandler) BroadcastEvent(eventType string, data interface{}) {
	type typedEvent struct {
		Type string      `json:"type"`
		Data interface{} `json:"data"`
	}

	payload, err := sonic.Marshal(typedEvent{Type: eventType, Data: data})
	if err != nil {
		logger.Error("failed to marshal event %s: %v", eventType, err)
		return
	}
	h.BroadcastMarshaledMessage(payload)
}
// BroadcastMarshaledMessage writes an already-serialized payload as a text
// frame to every currently connected WebSocket client. Per-client send
// failures are logged and do not stop the broadcast.
func (h *WebSocketHandler) BroadcastMarshaledMessage(data []byte) {
	// Copy the client set under the read lock so no lock is held during writes.
	h.mu.RLock()
	snapshot := make([]*WebSocketClient, 0, len(h.clients))
	for _, c := range h.clients {
		snapshot = append(snapshot, c)
	}
	h.mu.RUnlock()

	for i := range snapshot {
		if err := h.sendMessageSafely(snapshot[i], websocket.TextMessage, data); err != nil {
			logger.Error("failed to send message to client: %v", err)
		}
	}
}
// StartHeartbeat launches a goroutine that pings every connected client each
// 30 seconds. The goroutine exits when the handler context is cancelled or
// stopChan is signalled, then stops the ticker and closes h.done.
func (h *WebSocketHandler) StartHeartbeat() {
	ticker := time.NewTicker(30 * time.Second)

	// pingAll sends one ping frame to each client in a snapshot of the registry.
	pingAll := func() {
		// Copy the client set so the lock is not held during writes.
		h.mu.RLock()
		snapshot := make([]*WebSocketClient, 0, len(h.clients))
		for _, c := range h.clients {
			snapshot = append(snapshot, c)
		}
		h.mu.RUnlock()

		for i := range snapshot {
			if err := h.sendMessageSafely(snapshot[i], websocket.PingMessage, nil); err != nil {
				logger.Error("failed to send heartbeat: %v", err)
			}
		}
	}

	go func() {
		// Deferred calls run in reverse order: the ticker stops first, then done closes.
		defer close(h.done)
		defer ticker.Stop()
		for {
			select {
			case <-h.ctx.Done():
				logger.Info("got context cancel(), stopping webserver")
				return
			case <-ticker.C:
				pingAll()
			case <-h.stopChan:
				return
			}
		}
	}()
}
// Stop signals the heartbeat goroutine to exit, waits for it to finish, then
// closes every registered client connection and resets the client registry.
func (h *WebSocketHandler) Stop() {
	close(h.stopChan) // ask the heartbeat goroutine to stop
	<-h.done          // block until it has exited

	h.mu.Lock()
	defer h.mu.Unlock()
	for _, c := range h.clients {
		c.conn.Close()
	}
	h.clients = make(map[*websocket.Conn]*WebSocketClient)
}

View File

@@ -0,0 +1,102 @@
package handlers
import (
"crypto/rand"
"encoding/hex"
"sync"
"time"
)
// Timing parameters for WebSocket authentication tickets.
const (
// wsTicketTTL is how long an issued ticket remains valid.
wsTicketTTL = 30 * time.Second
// wsTicketCleanupHz is the interval between sweeps that purge expired
// tickets (an interval, despite the "Hz" suffix).
wsTicketCleanupHz = 60 * time.Second
)
// wsTicketEntry is the value stored for each outstanding ticket.
type wsTicketEntry struct {
// sessionToken is the session token the ticket was issued for.
sessionToken string
// expiresAt is the instant after which the ticket is no longer accepted.
expiresAt time.Time
}
// WSTicketStore provides short-lived, single-use tickets for WebSocket authentication.
// Instead of putting the long-lived session token in the WS URL (visible in logs/history),
// clients exchange their session for a 30-second one-time ticket via an authenticated endpoint.
type WSTicketStore struct {
// mu guards tickets.
mu sync.Mutex
// tickets maps a ticket string to its session token and expiry.
tickets map[string]wsTicketEntry
// done is closed by Stop to terminate the cleanup goroutine.
done chan struct{}
// stopOnce ensures done is closed at most once.
stopOnce sync.Once
}
// NewWSTicketStore returns an empty ticket store and starts its background
// cleanup goroutine, which runs until Stop is called.
func NewWSTicketStore() *WSTicketStore {
	store := new(WSTicketStore)
	store.tickets = make(map[string]wsTicketEntry)
	store.done = make(chan struct{})
	go store.cleanup()
	return store
}
// Issue creates a ticket for sessionToken: 32 bytes from crypto/rand, hex
// encoded. The ticket is recorded with an expiry of now + wsTicketTTL.
// It returns an error only when the random source fails.
func (s *WSTicketStore) Issue(sessionToken string) (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	ticket := hex.EncodeToString(raw[:])

	s.mu.Lock()
	defer s.mu.Unlock()
	s.tickets[ticket] = wsTicketEntry{
		sessionToken: sessionToken,
		expiresAt:    time.Now().Add(wsTicketTTL),
	}
	return ticket, nil
}
// Consume redeems a ticket. A known ticket is always removed from the store,
// so it can be used at most once. The bound session token is returned when
// the ticket exists and has not expired; otherwise the result is "".
func (s *WSTicketStore) Consume(ticket string) string {
	s.mu.Lock()
	defer s.mu.Unlock()

	entry, found := s.tickets[ticket]
	if found {
		// Remove before the expiry check so expired tickets are dropped too.
		delete(s.tickets, ticket)
	}
	if !found || time.Now().After(entry.expiresAt) {
		return ""
	}
	return entry.sessionToken
}
// Stop ends the background cleanup goroutine by closing the done channel.
// It is safe to call more than once; only the first call has an effect.
func (s *WSTicketStore) Stop() {
	signalDone := func() {
		close(s.done)
	}
	s.stopOnce.Do(signalDone)
}
// cleanup runs until done is closed, sweeping the store every
// wsTicketCleanupHz and deleting tickets whose expiry has passed, so
// unredeemed tickets do not accumulate.
func (s *WSTicketStore) cleanup() {
	ticker := time.NewTicker(wsTicketCleanupHz)
	defer ticker.Stop()

	// purge deletes every ticket that expired before now.
	purge := func(now time.Time) {
		s.mu.Lock()
		defer s.mu.Unlock()
		for ticket, entry := range s.tickets {
			if now.After(entry.expiresAt) {
				delete(s.tickets, ticket)
			}
		}
	}

	for {
		select {
		case <-ticker.C:
			purge(time.Now())
		case <-s.done:
			return
		}
	}
}

View File

@@ -0,0 +1,666 @@
package handlers
import (
"errors"
"io"
"net"
"strings"
"sync"
"time"
"github.com/fasthttp/router"
ws "github.com/fasthttp/websocket"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/integrations"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
"github.com/valyala/fasthttp"
)
// Liveness and write-timeout settings for client-facing realtime WebSockets.
const (
// realtimeWSPingInterval is how often the heartbeat goroutine pings the client.
realtimeWSPingInterval = 15 * time.Second
// realtimeWSPongTimeout is the read deadline applied after each successful
// read or received pong.
realtimeWSPongTimeout = 45 * time.Second
// realtimeWSPingWriteTimeout bounds a single ping write.
realtimeWSPingWriteTimeout = 10 * time.Second
// realtimeWSWriteTimeout bounds a single data-frame write.
realtimeWSWriteTimeout = 30 * time.Second
)
// WSRealtimeHandler handles bidirectional WebSocket proxying for the Realtime API.
type WSRealtimeHandler struct {
// client resolves providers and selects keys for realtime requests.
client *bifrost.Bifrost
// config supplies the allowed origins used by the upgrader's origin check.
config *lib.Config
// handlerStore is used to build request contexts and reach the KV store;
// the constructor sets it to config.
handlerStore lib.HandlerStore
// pool supplies upstream provider WebSocket connections.
pool *bfws.Pool
// sessions tracks client connections; created with the configured
// maximum connection count.
sessions *bfws.SessionManager
}
// NewWSRealtimeHandler builds a Realtime WebSocket handler. The config also
// serves as the handler store, and the session manager is created with the
// configured maximum connection count.
func NewWSRealtimeHandler(client *bifrost.Bifrost, config *lib.Config, pool *bfws.Pool) *WSRealtimeHandler {
	limit := config.WebSocketConfig.MaxConnections

	handler := &WSRealtimeHandler{client: client, config: config, pool: pool}
	handler.handlerStore = config
	handler.sessions = bfws.NewSessionManager(limit)
	return handler
}
// RegisterRoutes mounts the realtime upgrade handler, wrapped in the given
// middlewares, on GET /v1/realtime and on every path returned by
// integrations.OpenAIRealtimePaths("/openai").
func (h *WSRealtimeHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	upgrade := lib.ChainMiddlewares(h.handleUpgrade, middlewares...)

	r.GET("/v1/realtime", upgrade)
	for _, integrationPath := range integrations.OpenAIRealtimePaths("/openai") {
		r.GET(integrationPath, upgrade)
	}
}
// Close closes all realtime sessions tracked by the handler. It is a no-op on
// a nil handler or a handler without a session manager.
func (h *WSRealtimeHandler) Close() {
	if h != nil && h.sessions != nil {
		h.sessions.CloseAll()
	}
}
// handleUpgrade resolves the realtime target from the request, upgrades the
// connection and runs the proxy session. Requests whose target cannot be
// resolved, or whose provider lacks realtime support, are still upgraded so
// the error can be delivered as a realtime error event before closing.
func (h *WSRealtimeHandler) handleUpgrade(ctx *fasthttp.RequestCtx) {
	path := string(ctx.Path())
	modelParam := string(ctx.QueryArgs().Peek("model"))
	deploymentParam := string(ctx.QueryArgs().Peek("deployment"))

	auth := captureAuthHeaders(ctx)
	// OpenAI's SDK sends the API key via WebSocket subprotocol: "openai-insecure-api-key.<key>".
	// Promote it to a bearer token when no Authorization header was sent.
	if auth.authorization == "" {
		if token := extractRealtimeSubprotocolAPIKey(ctx); token != "" {
			auth.authorization = "Bearer " + token
		}
	}

	// logUpgradeFailure reports a failed upgrade; nil errors are ignored.
	logUpgradeFailure := func(upgradeErr error) {
		if upgradeErr != nil {
			logger.Warn("websocket upgrade failed for %s: %v", path, upgradeErr)
		}
	}

	// reject upgrades without a subprotocol, sends a 400 error event and closes.
	reject := func(message string) {
		upgrader := h.websocketUpgrader("")
		logUpgradeFailure(upgrader.Upgrade(ctx, func(conn *ws.Conn) {
			defer conn.Close()
			clientConn := newRealtimeClientConn(conn)
			clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", message))
		}))
	}

	providerKey, model, err := resolveRealtimeTarget(path, modelParam, deploymentParam)
	if err != nil {
		reject(err.Error())
		return
	}

	provider := h.client.GetProviderByKey(providerKey)
	rtProvider, ok := provider.(schemas.RealtimeProvider)
	if provider == nil || !ok || !rtProvider.SupportsRealtimeAPI() {
		reject("provider does not support realtime: " + string(providerKey))
		return
	}

	upgrader := h.websocketUpgrader(rtProvider.RealtimeWebSocketSubprotocol())
	logUpgradeFailure(upgrader.Upgrade(ctx, func(conn *ws.Conn) {
		defer conn.Close()
		clientConn := newRealtimeClientConn(conn)

		session, sessionErr := h.sessions.Create(conn)
		if sessionErr != nil {
			clientConn.writeRealtimeError(newRealtimeWireBifrostError(429, "rate_limit_exceeded", sessionErr.Error()))
			return
		}
		defer h.sessions.Remove(conn)

		h.runRealtimeSession(clientConn, session, auth, path, providerKey, model)
	}))
}
// websocketUpgrader returns an upgrader with 4 KiB read/write buffers. Requests
// without an Origin header are accepted; others are checked against the
// configured allowed origins. A non-blank subprotocol is advertised as the
// only supported subprotocol.
func (h *WSRealtimeHandler) websocketUpgrader(subprotocol string) ws.FastHTTPUpgrader {
	originCheck := func(ctx *fasthttp.RequestCtx) bool {
		if origin := string(ctx.Request.Header.Peek("Origin")); origin != "" {
			return IsOriginAllowed(origin, h.config.ClientConfig.AllowedOrigins)
		}
		return true
	}

	result := ws.FastHTTPUpgrader{
		ReadBufferSize:  4096,
		WriteBufferSize: 4096,
		CheckOrigin:     originCheck,
	}
	if strings.TrimSpace(subprotocol) == "" {
		return result
	}
	result.Subprotocols = []string{subprotocol}
	return result
}
// runRealtimeSession drives one proxied realtime session on an upgraded client
// connection: it builds the request context, selects a provider key, obtains an
// upstream WebSocket and relays traffic in both directions until either side
// ends. Setup failures are reported to the client as realtime error events.
func (h *WSRealtimeHandler) runRealtimeSession(
clientConn *realtimeClientConn,
session *bfws.Session,
auth *authHeaders,
path string,
providerKey schemas.ModelProvider,
model string,
) {
// Keep the client connection alive with pings for the lifetime of the session.
clientConn.startHeartbeat()
defer clientConn.stopHeartbeat()
bifrostCtx, cancel := createBifrostContextFromAuth(h.handlerStore, auth)
if bifrostCtx == nil {
// NOTE(review): cancel is not invoked on this path — confirm it is nil or a no-op
// when no context is returned.
clientConn.writeRealtimeError(newRealtimeWireBifrostError(500, "server_error", "failed to create request context"))
return
}
defer cancel()
// Resolve ephemeral key mapping to restore virtual key context.
token := extractRealtimeBearerTokenFromHeader(auth.authorization)
if isRealtimeEphemeralToken(token) {
mapping, ok := lookupRealtimeEphemeralKeyMapping(h.handlerStore.GetKVStore(), token)
if ok {
applyRealtimeEphemeralKeyMapping(bifrostCtx, mapping)
}
}
// Tag the context with the request type and, for /openai-prefixed paths, the integration.
bifrostCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest)
if strings.HasPrefix(path, "/openai") {
bifrostCtx.SetValue(schemas.BifrostContextKeyIntegrationType, "openai")
}
// The provider is looked up again here and re-checked for realtime support.
provider := h.client.GetProviderByKey(providerKey)
if provider == nil {
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "provider not found: "+string(providerKey)))
return
}
rtProvider, ok := provider.(schemas.RealtimeProvider)
if !ok || !rtProvider.SupportsRealtimeAPI() {
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "provider does not support realtime: "+string(providerKey)))
return
}
key, err := h.client.SelectKeyForProviderRequestType(bifrostCtx, schemas.RealtimeRequest, providerKey, model)
if err != nil {
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()))
return
}
// Resolve model alias so the provider receives the actual model identifier.
model = key.Aliases.Resolve(model)
wsURL := rtProvider.RealtimeWebSocketURL(key, model)
// Obtain an upstream connection keyed by provider, key ID and endpoint URL.
upstream, err := h.pool.Get(bfws.PoolKey{
Provider: providerKey,
KeyID: key.ID,
Endpoint: wsURL,
}, rtProvider.RealtimeHeaders(key))
if err != nil {
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", err.Error()))
return
}
// The upstream connection is discarded when the session ends.
defer h.pool.Discard(upstream)
// Capacity 2 lets both relay goroutines deliver their result without blocking.
errCh := make(chan error, 2)
go func() {
errCh <- h.relayClientToRealtimeProvider(clientConn, session, upstream, rtProvider, bifrostCtx, providerKey, model, key)
}()
go func() {
errCh <- h.relayRealtimeProviderToClient(clientConn, session, upstream, rtProvider, bifrostCtx, providerKey, model, key)
}()
// Once one direction finishes, close both sockets so the other relay's
// pending read fails, then wait for it before returning.
firstErr := <-errCh
_ = upstream.Close()
_ = clientConn.Close()
secondErr := <-errCh
// Log only errors that are not an expected part of relay shutdown.
if logErr := selectRealtimeRelayError(firstErr, secondErr); logErr != nil {
logger.Warn("realtime websocket relay ended for %s/%s on %s: %v", providerKey, model, path, logErr)
}
}
// relayClientToRealtimeProvider forwards client events to the upstream provider
// until the client connection ends or an unrecoverable error occurs.
//
// For each text message it parses the event, starts turn hooks when the provider
// says the event begins a turn, translates the event to the provider's wire
// format and writes it upstream. Recoverable problems (bad JSON, a turn already
// in progress, translation failure) are reported to the client and the loop
// continues. It returns nil when the relay ends without a transport error to
// report (normal closure, non-text message, hook failure) and the underlying
// error otherwise.
func (h *WSRealtimeHandler) relayClientToRealtimeProvider(
clientConn *realtimeClientConn,
session *bfws.Session,
upstream *bfws.UpstreamConn,
provider schemas.RealtimeProvider,
bifrostCtx *schemas.BifrostContext,
providerKey schemas.ModelProvider,
model string,
key schemas.Key,
) error {
for {
messageType, message, err := clientConn.ReadMessage()
if err != nil {
// The client went away: finalize turn hooks with a 499 before returning.
finalizeRealtimeTurnHooksOnTransportError(
h.client,
bifrostCtx,
session,
providerKey,
model,
&key,
499,
"client_closed_request",
"client realtime websocket disconnected before turn completed",
)
if isNormalWebSocketClosure(err) {
return nil
}
return err
}
// Only text frames are accepted; anything else ends the relay.
if messageType != ws.TextMessage {
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "realtime websocket only accepts text messages"))
return nil
}
event, err := schemas.ParseRealtimeEvent(message)
if err != nil {
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "failed to parse realtime event JSON"))
continue
}
// Extract pending tool/input summaries but defer recording until the event
// passes validation — rejected events must not pollute session state.
toolItemID, toolSummary := pendingRealtimeToolOutputUpdate(event)
inputItemID, inputSummary := pendingRealtimeInputUpdate(event)
startsTurn := provider.ShouldStartRealtimeTurn(event)
if startsTurn {
// Reject a new turn while the session still has turn hooks in flight.
if session.PeekRealtimeTurnHooks() != nil {
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "Conversation already has an active response in progress."))
continue
}
// For turn-starting events the summaries are recorded before the hooks start.
if toolSummary != "" {
session.RecordRealtimeToolOutput(toolItemID, toolSummary, string(message))
}
if inputSummary != "" {
session.RecordRealtimeInput(inputItemID, inputSummary, string(message))
}
if bifrostErr := startRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, event.Type); bifrostErr != nil {
clientConn.writeRealtimeError(bifrostErr)
return nil
}
}
providerEvent, err := provider.ToProviderRealtimeEvent(event)
if err != nil {
// Translation failed: finalize the turn that was just started (if any) with
// an error, then report the failure and keep reading.
if startsTurn {
if finalizeErr := finalizeRealtimeTurnHooksWithError(
h.client,
bifrostCtx,
session,
providerKey,
model,
&key,
schemas.RTEventError,
nil,
newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()),
); finalizeErr != nil {
clientConn.writeRealtimeError(finalizeErr)
return nil
}
}
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()))
continue
}
// Record tool output / input only after the event passed validation.
if !startsTurn {
if toolSummary != "" {
session.RecordRealtimeToolOutput(toolItemID, toolSummary, string(message))
}
if inputSummary != "" {
session.RecordRealtimeInput(inputItemID, inputSummary, string(message))
}
}
if err := upstream.WriteMessage(ws.TextMessage, providerEvent); err != nil {
// The upstream write failed: finalize with a 502, tell the client and stop.
// The finalize result is intentionally not inspected here.
finalizeRealtimeTurnHooksWithError(
h.client,
bifrostCtx,
session,
providerKey,
model,
&key,
schemas.RTEventError,
nil,
newRealtimeWireBifrostError(502, "server_error", "failed to write realtime event upstream"),
)
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to write realtime event upstream"))
return err
}
}
}
// relayRealtimeProviderToClient forwards upstream provider messages to the
// client until either connection ends or an unrecoverable error occurs.
//
// Text messages are translated into Bifrost realtime events so session state
// (provider session ID, accumulated output, turn hooks) can be updated; events
// the provider marks as not forwardable are dropped. Non-text messages are
// passed through unchanged. It returns nil when the relay ends without a
// transport error to report and the underlying error otherwise.
func (h *WSRealtimeHandler) relayRealtimeProviderToClient(
clientConn *realtimeClientConn,
session *bfws.Session,
upstream *bfws.UpstreamConn,
provider schemas.RealtimeProvider,
bifrostCtx *schemas.BifrostContext,
providerKey schemas.ModelProvider,
model string,
key schemas.Key,
) error {
for {
// Set when a terminal error event must still be written before disconnecting.
disconnectAfterWrite := false
messageType, message, err := upstream.ReadMessage()
if err != nil {
// The upstream went away: finalize turn hooks with a 502.
finalizeRealtimeTurnHooksOnTransportError(
h.client,
bifrostCtx,
session,
providerKey,
model,
&key,
502,
"upstream_connection_error",
"upstream realtime websocket closed before turn completed",
)
if isNormalWebSocketClosure(err) {
return nil
}
// NOTE(review): hooks were already finalized just above; this second finalize is
// presumably a no-op once they are consumed — confirm.
finalizeRealtimeTurnHooksWithError(
h.client,
bifrostCtx,
session,
providerKey,
model,
&key,
schemas.RTEventError,
nil,
newRealtimeWireBifrostError(502, "server_error", "upstream realtime websocket stream interrupted"),
)
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "upstream realtime websocket stream interrupted"))
return err
}
if messageType == ws.TextMessage {
event, err := provider.ToBifrostRealtimeEvent(message)
if err != nil {
// An untranslatable upstream event ends the relay with a 502.
finalizeRealtimeTurnHooksWithError(
h.client,
bifrostCtx,
session,
providerKey,
model,
&key,
schemas.RTEventError,
message,
newRealtimeWireBifrostError(502, "server_error", "failed to translate upstream realtime event"),
)
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to translate upstream realtime event"))
return err
}
// First pass: update session state from the event.
if event != nil {
if event.Session != nil && event.Session.ID != "" {
session.SetProviderSessionID(event.Session.ID)
}
if event.Delta != nil && provider.ShouldAccumulateRealtimeOutput(event.Type) {
session.AppendRealtimeOutputText(event.Delta.Text)
session.AppendRealtimeOutputText(event.Delta.Transcript)
}
// Start turn hooks for a provider-initiated turn when none are active.
if provider.ShouldStartRealtimeTurn(event) && session.PeekRealtimeTurnHooks() == nil {
if bifrostErr := startRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, event.Type); bifrostErr != nil {
clientConn.writeRealtimeError(bifrostErr)
return nil
}
}
}
// Second pass: decide whether to forward, and finalize the turn on terminal events.
if event != nil {
inputItemID, inputSummary := pendingRealtimeInputUpdate(event)
if !provider.ShouldForwardRealtimeEvent(event) {
continue
}
if event.Type == provider.RealtimeTurnFinalEvent() {
// Turn completed: finalize hooks with the output text accumulated so far.
contentOverride := session.ConsumeRealtimeOutputText()
if bifrostErr := finalizeRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, message, contentOverride); bifrostErr != nil {
clientConn.writeRealtimeError(bifrostErr)
return nil
}
} else if event.Error != nil {
turnErr := newBifrostErrorFromRealtimeError(providerKey, model, message, event.Error)
finalizeErr := finalizeRealtimeTurnHooksWithError(
h.client,
bifrostCtx,
session,
providerKey,
model,
&key,
event.Type,
message,
turnErr,
)
if finalizeErr != nil {
clientConn.writeRealtimeError(finalizeErr)
return nil
}
// Defer the disconnect so the normal translated-write path
// below still runs — otherwise terminal errors from translated
// providers would reach the client in provider-native format.
disconnectAfterWrite = shouldGracefullyDisconnectRealtime(turnErr)
} else if inputSummary != "" {
session.RecordRealtimeInput(inputItemID, inputSummary, string(message))
}
// Events without raw data are re-encoded before being written to the client.
if len(event.RawData) == 0 {
message, err = provider.ToProviderRealtimeEvent(event)
if err != nil {
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to encode translated realtime event"))
return err
}
}
}
}
if err := clientConn.WriteMessage(messageType, message); err != nil {
// The client write failed: finalize turn hooks with a 499.
finalizeRealtimeTurnHooksOnTransportError(
h.client,
bifrostCtx,
session,
providerKey,
model,
&key,
499,
"client_closed_request",
"client realtime websocket disconnected before turn completed",
)
if isNormalWebSocketClosure(err) {
return nil
}
return err
}
if disconnectAfterWrite {
return nil
}
}
}
// resolveRealtimeTarget determines the provider and model for a realtime
// connection. A non-blank model parameter takes precedence over deployment;
// the chosen value is parsed with the path's default provider. It returns
// errRealtimeModelRequired when both parameters are blank, and the matching
// format error when the chosen parameter yields no provider or model.
func resolveRealtimeTarget(path, modelParam, deploymentParam string) (schemas.ModelProvider, string, error) {
	defaultProvider := realtimeDefaultProviderForPath(path)

	// Pick the parameter to parse together with the error that reports it.
	target, formatErr := strings.TrimSpace(modelParam), errRealtimeModelFormat
	if target == "" {
		target, formatErr = strings.TrimSpace(deploymentParam), errRealtimeDeploymentFormat
	}
	if target == "" {
		return "", "", errRealtimeModelRequired
	}

	provider, model := schemas.ParseModelString(target, defaultProvider)
	model = strings.TrimSpace(model)
	if provider == "" || model == "" {
		return "", "", formatErr
	}
	return provider, model, nil
}
// realtimeDefaultProviderForPath returns schemas.OpenAI for paths under
// "/openai/" and the empty provider for every other path.
func realtimeDefaultProviderForPath(path string) schemas.ModelProvider {
	var provider schemas.ModelProvider
	if strings.HasPrefix(path, "/openai/") {
		provider = schemas.OpenAI
	}
	return provider
}
// isNormalWebSocketClosure reports whether err is a WebSocket close error with
// one of the codes treated as a clean shutdown: normal closure, going away, or
// no status received.
func isNormalWebSocketClosure(err error) bool {
	cleanCodes := []int{
		ws.CloseNormalClosure,
		ws.CloseGoingAway,
		ws.CloseNoStatusReceived,
	}
	return ws.IsCloseError(err, cleanCodes...)
}
// isExpectedRealtimeRelayShutdown reports whether err is an ordinary
// by-product of tearing a relay down and therefore not worth logging. A nil
// error counts as expected.
func isExpectedRealtimeRelayShutdown(err error) bool {
	switch {
	case err == nil:
		return true
	case isNormalWebSocketClosure(err), errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed):
		return true
	default:
		// Relay teardown closes the opposite socket after the first side exits,
		// which can surface as a plain network-close read error rather than a
		// websocket close frame.
		return strings.Contains(err.Error(), "use of closed network connection")
	}
}
// selectRealtimeRelayError returns the first error in errs that is not an
// expected relay shutdown, or nil when there is none.
func selectRealtimeRelayError(errs ...error) error {
	for i := range errs {
		candidate := errs[i]
		if candidate == nil || isExpectedRealtimeRelayShutdown(candidate) {
			continue
		}
		return candidate
	}
	return nil
}
// Errors returned by resolveRealtimeTarget when the target cannot be determined.
var (
// errRealtimeModelRequired: neither model nor deployment was supplied.
errRealtimeModelRequired = errorf("model or deployment query parameter is required for realtime websocket")
// errRealtimeModelFormat: the model parameter did not yield both a provider and a model.
errRealtimeModelFormat = errorf("model query parameter must resolve to provider/model for realtime websocket")
// errRealtimeDeploymentFormat: the deployment parameter did not yield both a provider and a model.
errRealtimeDeploymentFormat = errorf("deployment query parameter must resolve to provider/model for realtime websocket")
)
// realtimeClientConn wraps the client-facing WebSocket connection with
// serialized writes, read-deadline refresh and a ping heartbeat.
type realtimeClientConn struct {
// conn is the underlying client WebSocket connection.
conn *ws.Conn
// writeMu serializes data writes and pings on conn.
writeMu sync.Mutex
// closeOnce ensures done is closed at most once.
closeOnce sync.Once
// done is closed to stop the heartbeat goroutine.
done chan struct{}
}
// newRealtimeClientConn wraps conn and allocates the channel used to stop its
// heartbeat. The heartbeat itself is not started here.
func newRealtimeClientConn(conn *ws.Conn) *realtimeClientConn {
	wrapped := new(realtimeClientConn)
	wrapped.conn = conn
	wrapped.done = make(chan struct{})
	return wrapped
}
// ReadMessage reads the next message from the client. After a successful read
// the read deadline is pushed out by realtimeWSPongTimeout; the result of that
// refresh is not reported to the caller.
func (c *realtimeClientConn) ReadMessage() (messageType int, p []byte, err error) {
	if messageType, p, err = c.conn.ReadMessage(); err != nil {
		return messageType, p, err
	}
	c.refreshReadDeadline()
	return messageType, p, nil
}
// WriteMessage writes one frame to the client under the write mutex with a
// realtimeWSWriteTimeout deadline. The deadline is cleared after a successful
// write. The first error encountered is returned.
func (c *realtimeClientConn) WriteMessage(messageType int, data []byte) error {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()

	err := c.conn.SetWriteDeadline(time.Now().Add(realtimeWSWriteTimeout))
	if err == nil {
		err = c.conn.WriteMessage(messageType, data)
	}
	if err == nil {
		// Only a successful write clears the deadline.
		err = c.conn.SetWriteDeadline(time.Time{})
	}
	return err
}
// startHeartbeat installs the pong handler, arms the initial read deadline and
// starts a goroutine that pings the client every realtimeWSPingInterval. The
// goroutine exits when done is closed, or after closing the connection when a
// ping cannot be written.
func (c *realtimeClientConn) startHeartbeat() {
	c.installPongHandler()
	c.refreshReadDeadline()

	go func() {
		pings := time.NewTicker(realtimeWSPingInterval)
		defer pings.Stop()
		for {
			select {
			case <-c.done:
				return
			case <-pings.C:
				if pingErr := c.writePing(); pingErr != nil {
					_ = c.Close()
					return
				}
			}
		}
	}()
}
// stopHeartbeat signals the heartbeat goroutine to exit by closing done. It
// does not close the connection and is safe to call repeatedly.
func (c *realtimeClientConn) stopHeartbeat() {
	c.closeOnce.Do(func() {
		close(c.done)
	})
}
// installPongHandler makes every pong received from the client extend the
// read deadline.
func (c *realtimeClientConn) installPongHandler() {
	onPong := func(string) error {
		return c.refreshReadDeadline()
	}
	c.conn.SetPongHandler(onPong)
}
// refreshReadDeadline sets the connection's read deadline to
// realtimeWSPongTimeout from now.
func (c *realtimeClientConn) refreshReadDeadline() error {
	deadline := time.Now().Add(realtimeWSPongTimeout)
	return c.conn.SetReadDeadline(deadline)
}
// writePing sends a ping frame under the write mutex with a
// realtimeWSPingWriteTimeout deadline, clearing the deadline after a
// successful write.
func (c *realtimeClientConn) writePing() error {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()

	pingDeadline := time.Now().Add(realtimeWSPingWriteTimeout)
	if deadlineErr := c.conn.SetWriteDeadline(pingDeadline); deadlineErr != nil {
		return deadlineErr
	}
	if writeErr := c.conn.WriteMessage(ws.PingMessage, nil); writeErr != nil {
		return writeErr
	}
	var noDeadline time.Time
	return c.conn.SetWriteDeadline(noDeadline)
}
// closeDone closes the done channel exactly once, no matter how many times or
// from how many goroutines it is called.
func (c *realtimeClientConn) closeDone() {
	signal := func() {
		close(c.done)
	}
	c.closeOnce.Do(signal)
}
// writeRealtimeError sends bifrostErr to the client as a realtime error event.
// Delivery is best-effort: a failed write is ignored.
func (c *realtimeClientConn) writeRealtimeError(bifrostErr *schemas.BifrostError) {
	_ = c.WriteMessage(ws.TextMessage, newRealtimeTurnErrorEventPayload(bifrostErr))
}
// Close stops the heartbeat and closes the underlying connection, returning
// the connection's close error.
func (c *realtimeClientConn) Close() error {
	c.closeOnce.Do(func() {
		close(c.done)
	})
	closeErr := c.conn.Close()
	return closeErr
}
// realtimeSubprotocolAPIKeyPrefix is the prefix of the Sec-WebSocket-Protocol
// entry that carries an API key; the key is whatever follows the prefix.
const realtimeSubprotocolAPIKeyPrefix = "openai-insecure-api-key."
// extractRealtimeSubprotocolAPIKey returns the API key carried in the
// Sec-WebSocket-Protocol header, or "" when no entry has the key prefix.
// The OpenAI SDK sends: "realtime, openai-insecure-api-key.<key>".
func extractRealtimeSubprotocolAPIKey(ctx *fasthttp.RequestCtx) string {
	offered := strings.Split(string(ctx.Request.Header.Peek("Sec-WebSocket-Protocol")), ",")
	for i := range offered {
		candidate := strings.TrimSpace(offered[i])
		if !strings.HasPrefix(candidate, realtimeSubprotocolAPIKeyPrefix) {
			continue
		}
		// The first matching entry wins; the key is the remainder after the prefix.
		return candidate[len(realtimeSubprotocolAPIKeyPrefix):]
	}
	return ""
}
// newRealtimeWireBifrostError builds a BifrostError for the realtime wire
// protocol. The code is used as the error type at both levels and as the
// error code; all three fields point at the same string.
func newRealtimeWireBifrostError(status int, code, message string) *schemas.BifrostError {
	kind := code
	detail := &schemas.ErrorField{
		Type:    &kind,
		Code:    &kind,
		Message: message,
	}
	result := &schemas.BifrostError{
		StatusCode: &status,
		Type:       &kind,
		Error:      detail,
	}
	return result
}

View File

@@ -0,0 +1,702 @@
package handlers
import (
"context"
"strings"
"github.com/bytedance/sonic"
"github.com/fasthttp/router"
ws "github.com/fasthttp/websocket"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/integrations"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
"github.com/valyala/fasthttp"
)
// wsWriter abstracts a WebSocket write target. Both *ws.Conn (pre-session)
// and *bfws.Session (post-session, mutex-protected) satisfy this interface.
type wsWriter interface {
// WriteMessage writes a single frame of the given message type.
WriteMessage(messageType int, data []byte) error
}
// WSResponsesHandler handles WebSocket connections for the Responses API WebSocket Mode.
// Clients connect via `GET /v1/responses` with a WS upgrade and send `response.create` events.
// Each event is routed through the standard Bifrost inference pipeline (PreLLMHook, key selection,
// provider call, PostLLMHook) via the HTTP bridge, with native WS upstream as an optimization.
type WSResponsesHandler struct {
// client resolves providers and selects keys.
client *bifrost.Bifrost
// config supplies provider configuration and the allowed origins.
config *lib.Config
// handlerStore is used to build per-event request contexts; the
// constructor sets it to config.
handlerStore lib.HandlerStore
// pool supplies upstream provider WebSocket connections.
pool *bfws.Pool
// sessions tracks client connections; created with the configured
// maximum connection count.
sessions *bfws.SessionManager
// upgrader performs the HTTP-to-WebSocket upgrade and origin check.
upgrader ws.FastHTTPUpgrader
}
// NewWSResponsesHandler builds a WebSocket Responses handler. The config also
// serves as the handler store. The upgrader uses 4 KiB buffers, accepts
// requests without an Origin header, and checks all other origins against the
// configured allowed origins.
func NewWSResponsesHandler(client *bifrost.Bifrost, config *lib.Config, pool *bfws.Pool) *WSResponsesHandler {
	limit := config.WebSocketConfig.MaxConnections

	originCheck := func(ctx *fasthttp.RequestCtx) bool {
		if origin := string(ctx.Request.Header.Peek("Origin")); origin != "" {
			return IsOriginAllowed(origin, config.ClientConfig.AllowedOrigins)
		}
		return true
	}

	handler := &WSResponsesHandler{client: client, config: config, pool: pool}
	handler.handlerStore = config
	handler.sessions = bfws.NewSessionManager(limit)
	handler.upgrader = ws.FastHTTPUpgrader{
		ReadBufferSize:  4096,
		WriteBufferSize: 4096,
		CheckOrigin:     originCheck,
	}
	return handler
}
// Close closes all WebSocket responses sessions tracked by the handler. It is
// a no-op on a nil handler or a handler without a session manager.
func (h *WSResponsesHandler) Close() {
	if h != nil && h.sessions != nil {
		h.sessions.CloseAll()
	}
}
// RegisterRoutes mounts the upgrade handler, wrapped in the given middlewares,
// on GET /v1/responses and on every path returned by
// integrations.OpenAIWSResponsesPaths("/openai").
func (h *WSResponsesHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	upgrade := lib.ChainMiddlewares(h.handleUpgrade, middlewares...)

	// Base path, outside the integration prefix.
	r.GET("/v1/responses", upgrade)

	// OpenAI integration paths (/openai/v1/responses, /openai/responses, /openai/openai/responses).
	for _, integrationPath := range integrations.OpenAIWSResponsesPaths("/openai") {
		r.GET(integrationPath, upgrade)
	}
}
// handleUpgrade upgrades the HTTP connection to WebSocket and runs the event
// loop until the client disconnects. When the session limit is reached the
// client receives a 429 error event and the connection is closed.
func (h *WSResponsesHandler) handleUpgrade(ctx *fasthttp.RequestCtx) {
	// Snapshot everything needed from the request before upgrading. The
	// upgrade callback runs after this request handler has returned, so it
	// should not read from ctx. This mirrors the realtime handler, which also
	// captures auth headers up front.
	path := string(ctx.Path())
	auth := captureAuthHeaders(ctx)

	err := h.upgrader.Upgrade(ctx, func(conn *ws.Conn) {
		defer conn.Close()

		session, sessionErr := h.sessions.Create(conn)
		if sessionErr != nil {
			writeWSError(conn, 429, "websocket_connection_limit_reached", sessionErr.Error())
			return
		}
		defer h.sessions.Remove(conn)

		// The captured headers are reused to build a fresh context per event.
		h.eventLoop(conn, session, auth)
	})
	if err != nil {
		// Log the actual path: the handler is registered on several routes.
		logger.Warn("websocket upgrade failed for %s: %v", path, err)
	}
}
// authHeaders holds auth-related headers captured during the WS upgrade.
type authHeaders struct {
// authorization is the value of the Authorization header.
authorization string
// virtualKey is the value of the x-bf-vk header.
virtualKey string
// apiKey is the value of the x-api-key header.
apiKey string
// googAPIKey is the value of the x-goog-api-key header.
googAPIKey string
// baggage is the value of the baggage header.
baggage string
// extraHeaders holds every request header whose name starts with "x-bf-"
// (case-insensitive), keyed by the name as received.
extraHeaders map[string]string
}
// captureAuthHeaders copies the auth-related request headers into a new
// authHeaders value. Every header whose name starts with "x-bf-"
// (case-insensitive) is also stored in extraHeaders under its name as received.
func captureAuthHeaders(ctx *fasthttp.RequestCtx) *authHeaders {
	header := &ctx.Request.Header
	peek := func(name string) string {
		return string(header.Peek(name))
	}

	captured := &authHeaders{
		authorization: peek("Authorization"),
		virtualKey:    peek("x-bf-vk"),
		apiKey:        peek("x-api-key"),
		googAPIKey:    peek("x-goog-api-key"),
		baggage:       peek("baggage"),
		extraHeaders:  make(map[string]string),
	}

	for rawName, rawValue := range header.All() {
		name := string(rawName)
		if !strings.HasPrefix(strings.ToLower(name), "x-bf-") {
			continue
		}
		captured.extraHeaders[name] = string(rawValue)
	}
	return captured
}
// eventLoop reads client messages from conn until a read fails. Each message
// is dispatched on its JSON "type": response.create events are processed,
// anything else is answered with a 400 error event written to the session.
func (h *WSResponsesHandler) eventLoop(conn *ws.Conn, session *bfws.Session, auth *authHeaders) {
	for {
		_, raw, readErr := conn.ReadMessage()
		if readErr != nil {
			// Going-away and normal closures are not logged.
			if ws.IsUnexpectedCloseError(readErr, ws.CloseGoingAway, ws.CloseNormalClosure) {
				logger.Warn("websocket read error: %v", readErr)
			}
			return
		}

		// Decode only the discriminator field first.
		var envelope struct {
			Type string `json:"type"`
		}
		if decodeErr := sonic.Unmarshal(raw, &envelope); decodeErr != nil {
			writeWSError(session, 400, "invalid_request_error", "failed to parse event JSON")
			continue
		}

		if schemas.WebSocketEventType(envelope.Type) == schemas.WSEventResponseCreate {
			h.handleResponseCreate(session, auth, raw)
		} else {
			writeWSError(session, 400, "invalid_request_error", "unsupported event type: "+envelope.Type)
		}
	}
}
// handleResponseCreate processes a response.create event.
// Strategy: try native WS upstream for providers that support it, otherwise use HTTP bridge.
// If native WS upstream fails mid-stream, falls back to HTTP bridge.
// Validation and setup failures are reported to the client as error events.
func (h *WSResponsesHandler) handleResponseCreate(session *bfws.Session, auth *authHeaders, message []byte) {
var event schemas.WebSocketResponsesEvent
if err := sonic.Unmarshal(message, &event); err != nil {
writeWSError(session, 400, "invalid_request_error", "failed to parse response.create event")
return
}
// Store override: default to store=true (Codex sends false by default but expects true).
// If DisableStore is set in provider config, force store=false.
// NOTE(review): the branch below always assigns event.Store, so any store value sent
// by the client is overwritten — confirm this is intended.
provider, modelName := schemas.ParseModelString(event.Model, "")
if provider == "" || modelName == "" {
writeWSError(session, 400, "invalid_request_error", "failed to parse model string")
return
}
if providerCfg, cfgErr := h.config.GetProviderConfigRaw(provider); cfgErr == nil &&
providerCfg.OpenAIConfig != nil && providerCfg.OpenAIConfig.DisableStore {
event.Store = schemas.Ptr(false)
} else {
event.Store = schemas.Ptr(true)
}
bifrostReq, err := h.convertEventToRequest(&event)
if err != nil {
writeWSError(session, 400, "invalid_request_error", err.Error())
return
}
// Extract extra params (unknown fields) and forward them, matching the HTTP path behavior
// An extraction error is ignored; the request then proceeds without extra params.
extraParams, extractErr := extractExtraParams(message, wsResponsesKnownFields)
if extractErr == nil && len(extraParams) > 0 {
if bifrostReq.Params == nil {
bifrostReq.Params = &schemas.ResponsesParameters{}
}
bifrostReq.Params.ExtraParams = extraParams
}
bifrostCtx, cancel := createBifrostContextFromAuth(h.handlerStore, auth)
if bifrostCtx == nil {
// NOTE(review): cancel is not invoked on this path — confirm it is nil or a no-op
// when no context is returned.
writeWSError(session, 500, "server_error", "failed to create request context")
return
}
// Try native WS upstream first
if h.tryNativeWSUpstream(session, bifrostCtx, bifrostReq, message) {
cancel()
return
}
// Fall back to HTTP bridge
// cancel is handed to the bridge; presumably the bridge calls it when done — verify.
h.executeHTTPBridge(session, bifrostCtx, cancel, bifrostReq)
}
// tryNativeWSUpstream attempts to serve the event over a native WS upstream connection.
//
// The raw client event is forwarded unchanged to the upstream and every upstream
// message is relayed unchanged to the client until a terminal stream event is seen.
// Plugin pre-hooks run before the forward and post-hooks run per parsed chunk.
//
// Returns true when the event was handled: the stream completed, an error event
// was sent to the client, a plugin short-circuited the request, or writing to the
// client failed.
// Returns false when the caller should fall back to the HTTP bridge: the provider
// is unknown or not WS-capable, no upstream connection could be obtained, the
// upstream write failed, or the upstream read failed before anything was relayed.
func (h *WSResponsesHandler) tryNativeWSUpstream(
	session *bfws.Session,
	ctx *schemas.BifrostContext,
	req *schemas.BifrostResponsesRequest,
	rawEvent []byte,
) bool {
	provider := h.client.GetProviderByKey(req.Provider)
	if provider == nil {
		return false
	}
	wsProvider, ok := provider.(schemas.WebSocketCapableProvider)
	if !ok || !wsProvider.SupportsWebSocketMode() {
		return false
	}

	key, err := h.client.SelectKeyForProviderRequestType(ctx, schemas.WebSocketResponsesRequest, req.Provider, req.Model)
	if err != nil {
		writeWSError(session, 400, "invalid_request_error", err.Error())
		return true
	}
	wsURL := wsProvider.WebSocketResponsesURL(key)

	upstream := session.Upstream()
	// Validate the pinned upstream matches the current request's provider/key;
	// a mismatched connection is discarded so a fresh one is obtained below.
	if upstream != nil && !upstream.IsClosed() &&
		(upstream.Provider() != req.Provider || upstream.KeyID() != key.ID) {
		h.pool.Discard(upstream)
		session.SetUpstream(nil)
		upstream = nil
	}

	// If no usable upstream connection is pinned, get one from the pool and pin it.
	if upstream == nil || upstream.IsClosed() {
		headers := wsProvider.WebSocketHeaders(key)
		poolKey := bfws.PoolKey{
			Provider: req.Provider,
			KeyID:    key.ID,
			Endpoint: wsURL,
		}
		upstream, err = h.pool.Get(poolKey, headers)
		if err != nil {
			logger.Warn("failed to get upstream WS connection for %s: %v, falling back to HTTP bridge", req.Provider, err)
			return false
		}
		session.SetUpstream(upstream)
	}

	// Run plugin pre-hooks before forwarding to upstream.
	bifrostReq := &schemas.BifrostRequest{
		RequestType:      schemas.WebSocketResponsesRequest,
		ResponsesRequest: req,
	}
	hooks, preErr := h.client.RunStreamPreHooks(ctx, bifrostReq)
	if preErr != nil {
		writeWSBifrostError(session, preErr)
		return true
	}
	defer hooks.Cleanup()

	// If a plugin short-circuited with a response, write it and skip upstream.
	if hooks.ShortCircuitResponse != nil {
		writeWSShortCircuitResponse(session, hooks.ShortCircuitResponse)
		return true
	}

	// Forward the raw event to upstream. Nothing has reached the client yet,
	// so a write failure can still fall back to the HTTP bridge.
	if err := upstream.WriteMessage(ws.TextMessage, rawEvent); err != nil {
		logger.Warn("upstream WS write failed for %s: %v, falling back to HTTP bridge", req.Provider, err)
		h.pool.Discard(upstream)
		session.SetUpstream(nil)
		return false
	}

	// Retrieve tracer and traceID for chunk accumulation.
	tracer, _ := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer)
	traceID, _ := ctx.Value(schemas.BifrostContextKeyTraceID).(string)

	// Read response events from upstream and relay them to the client,
	// running post-hooks per parsed chunk.
	forwardedAny := false
	for {
		msgType, data, readErr := upstream.ReadMessage()
		if readErr != nil {
			h.pool.Discard(upstream)
			session.SetUpstream(nil)
			if !forwardedAny {
				// Nothing was relayed yet, so the HTTP bridge can take over cleanly.
				logger.Warn("upstream WS read failed for %s: %v, falling back to HTTP bridge", req.Provider, readErr)
				return false
			}
			// Part of the stream already reached the client; falling back would
			// replay the response, so report the interruption instead.
			logger.Warn("upstream WS read failed mid-stream for %s: %v", req.Provider, readErr)
			writeWSError(session, 502, "upstream_connection_error", "upstream websocket stream interrupted")
			return true
		}

		// Parsing is best-effort: unparseable messages skip hooks but are still relayed.
		streamResp := parseUpstreamWSEvent(data, req.Provider, req.Model)
		isTerminal := streamResp != nil && isTerminalStreamType(streamResp.Type)
		if isTerminal {
			// Set before the post-hooks run so they can see this is the final chunk.
			ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
		}
		if streamResp != nil {
			resp := &schemas.BifrostResponse{ResponsesStreamResponse: streamResp}
			if tracer != nil && traceID != "" {
				tracer.AddStreamingChunk(traceID, resp)
			}
			_, postErr := hooks.PostHookRunner(ctx, resp, nil)
			if postErr != nil {
				// The upstream stream is abandoned mid-response, so the connection
				// is discarded rather than kept pinned.
				h.pool.Discard(upstream)
				session.SetUpstream(nil)
				writeWSBifrostError(session, postErr)
				return true
			}
		}

		// The original upstream bytes are relayed, not a re-serialized chunk.
		if writeErr := session.WriteMessage(msgType, data); writeErr != nil {
			h.pool.Discard(upstream)
			session.SetUpstream(nil)
			return true
		}
		forwardedAny = true

		if isTerminal {
			h.trackResponseID(session, data)
			return true
		}
	}
}
// writeWSShortCircuitResponse sends a plugin-provided (short-circuit) response to
// the client as a single WS text message.
//
// A full ResponsesResponse takes precedence over a ResponsesStreamResponse. After
// a full response is written successfully, its non-empty ID is recorded as the
// session's last response ID. Marshal and write failures are silently dropped.
func writeWSShortCircuitResponse(session *bfws.Session, resp *schemas.BifrostResponse) {
	switch {
	case resp.ResponsesResponse != nil:
		payload, marshalErr := sonic.Marshal(resp.ResponsesResponse)
		if marshalErr != nil {
			return
		}
		if writeErr := session.WriteMessage(ws.TextMessage, payload); writeErr != nil {
			return
		}
		if id := resp.ResponsesResponse.ID; id != nil && *id != "" {
			session.SetLastResponseID(*id)
		}

	case resp.ResponsesStreamResponse != nil:
		payload, marshalErr := sonic.Marshal(resp.ResponsesStreamResponse)
		if marshalErr != nil {
			return
		}
		// Best-effort write: the result is intentionally not inspected.
		session.WriteMessage(ws.TextMessage, payload)
	}
}
// parseUpstreamWSEvent decodes a raw upstream WS message into a
// BifrostResponsesStreamResponse and stamps its ExtraFields with the request
// type, provider, and originally requested model.
//
// Returns nil when the payload is not valid JSON for that type or carries no
// event type. Callers treat nil as non-fatal and still relay the raw bytes.
func parseUpstreamWSEvent(data []byte, provider schemas.ModelProvider, model string) *schemas.BifrostResponsesStreamResponse {
	parsed := new(schemas.BifrostResponsesStreamResponse)
	if sonic.Unmarshal(data, parsed) != nil || parsed.Type == "" {
		return nil
	}

	extra := &parsed.ExtraFields
	extra.RequestType = schemas.ResponsesStreamRequest
	extra.Provider = provider
	extra.OriginalModelRequested = model
	return parsed
}
// isTerminalStreamType reports whether t marks the end of a response stream:
// completed, failed, incomplete, or error.
func isTerminalStreamType(t schemas.ResponsesStreamResponseType) bool {
	return t == schemas.ResponsesStreamResponseTypeCompleted ||
		t == schemas.ResponsesStreamResponseTypeFailed ||
		t == schemas.ResponsesStreamResponseTypeIncomplete ||
		t == schemas.ResponsesStreamResponseTypeError
}
// trackResponseID reads "response.id" from a raw terminal event and, when it is
// present and non-empty, records it as the session's last response ID.
// Payloads that fail to decode are ignored.
func (h *WSResponsesHandler) trackResponseID(session *bfws.Session, data []byte) {
	var terminal struct {
		Response struct {
			ID string `json:"id"`
		} `json:"response"`
	}
	if sonic.Unmarshal(data, &terminal) != nil {
		return
	}
	if id := terminal.Response.ID; id != "" {
		session.SetLastResponseID(id)
	}
}
// convertEventToRequest builds a BifrostResponsesRequest from a response.create event.
//
// The model must parse into a non-empty provider and model name, otherwise
// errModelFormat is returned. Input may be a JSON array of messages or a plain
// JSON string (wrapped as a single user message); a missing, empty, or
// undecodable input yields errInputRequired.
//
// The structured optional fields (tools, tool_choice, reasoning, text, metadata)
// are decoded best-effort: a field that fails to decode is left unset rather
// than failing the request.
func (h *WSResponsesHandler) convertEventToRequest(event *schemas.WebSocketResponsesEvent) (*schemas.BifrostResponsesRequest, error) {
	provider, modelName := schemas.ParseModelString(event.Model, "")
	if provider == "" || modelName == "" {
		return nil, errModelFormat
	}

	var messages []schemas.ResponsesMessage
	if event.Input != nil {
		// Array form first; on failure, fall back to the plain-string form.
		if arrErr := sonic.Unmarshal(event.Input, &messages); arrErr != nil {
			var text string
			if strErr := sonic.Unmarshal(event.Input, &text); strErr != nil {
				return nil, errInputRequired
			}
			messages = []schemas.ResponsesMessage{
				{
					Role:    schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
					Content: &schemas.ResponsesMessageContent{ContentStr: &text},
				},
			}
		}
	}
	if len(messages) == 0 {
		return nil, errInputRequired
	}

	params := &schemas.ResponsesParameters{}

	// Pointer-valued scalars are copied as-is; a nil source leaves the field nil.
	params.Temperature = event.Temperature
	params.TopP = event.TopP
	params.MaxOutputTokens = event.MaxOutputTokens
	params.Store = event.Store

	// String-valued options are only set when non-empty.
	if event.Instructions != "" {
		params.Instructions = &event.Instructions
	}
	if event.PreviousResponseID != "" {
		params.PreviousResponseID = &event.PreviousResponseID
	}
	if event.Truncation != "" {
		params.Truncation = &event.Truncation
	}

	// Raw JSON options: decode failures are ignored and the field stays unset.
	if event.Tools != nil {
		var tools []schemas.ResponsesTool
		if sonic.Unmarshal(event.Tools, &tools) == nil {
			params.Tools = tools
		}
	}
	if event.ToolChoice != nil {
		var choice schemas.ResponsesToolChoice
		if sonic.Unmarshal(event.ToolChoice, &choice) == nil {
			params.ToolChoice = &choice
		}
	}
	if event.Reasoning != nil {
		var reasoning schemas.ResponsesParametersReasoning
		if sonic.Unmarshal(event.Reasoning, &reasoning) == nil {
			params.Reasoning = &reasoning
		}
	}
	if event.Text != nil {
		var textCfg schemas.ResponsesTextConfig
		if sonic.Unmarshal(event.Text, &textCfg) == nil {
			params.Text = &textCfg
		}
	}
	if event.Metadata != nil {
		var metadata map[string]any
		if sonic.Unmarshal(event.Metadata, &metadata) == nil {
			params.Metadata = &metadata
		}
	}

	return &schemas.BifrostResponsesRequest{
		Provider: schemas.ModelProvider(provider),
		Model:    modelName,
		Input:    messages,
		Params:   params,
	}, nil
}
// createBifrostContextFromAuth builds a cancellable BifrostContext from the auth
// headers captured during the WebSocket upgrade.
//
// It sets, in order: the parent request ID from the baggage session ID (when
// non-empty), the explicit virtual key, and then a credential from each of the
// Authorization bearer token, the API key header, and the Google API key header.
// A credential starting with "sk-bf-" is stored as the virtual key with the
// prefix removed; any other credential is stored as a direct key only when the
// handler store allows direct keys. Later credentials overwrite earlier ones.
// Finally the x-bf-api-key and x-bf-eh-* extra headers are copied to the context.
//
// The caller owns the returned cancel function.
func createBifrostContextFromAuth(handlerStore lib.HandlerStore, auth *authHeaders) (*schemas.BifrostContext, context.CancelFunc) {
	const (
		bearerPrefix      = "Bearer "
		virtualKeyPrefix  = "sk-bf-"
		extraHeaderPrefix = "x-bf-eh-"
	)

	bfCtx, cancel := schemas.NewBifrostContextWithCancel(context.Background())

	if sid := lib.ParseSessionIDFromBaggage(auth.baggage); sid != "" {
		bfCtx.SetValue(schemas.BifrostContextKeyParentRequestID, sid)
	}
	if auth.virtualKey != "" {
		bfCtx.SetValue(schemas.BifrostContextKeyVirtualKey, auth.virtualKey)
	}

	// applyCredential classifies one header-supplied secret as either a virtual
	// key (sk-bf- prefix) or, when permitted, a direct provider key.
	applyCredential := func(credential string) {
		if strings.HasPrefix(credential, virtualKeyPrefix) {
			bfCtx.SetValue(schemas.BifrostContextKeyVirtualKey, strings.TrimPrefix(credential, virtualKeyPrefix))
			return
		}
		if !handlerStore.ShouldAllowDirectKeys() {
			return
		}
		bfCtx.SetValue(schemas.BifrostContextKeyDirectKey, schemas.Key{
			ID:     "header-provided",
			Value:  *schemas.NewEnvVar(credential),
			Models: schemas.WhiteList{"*"},
			Weight: 1.0,
		})
	}

	// Authorization is only considered when it uses the Bearer scheme; the
	// token is applied even if empty.
	if strings.HasPrefix(auth.authorization, bearerPrefix) {
		applyCredential(strings.TrimPrefix(auth.authorization, bearerPrefix))
	}
	for _, headerKey := range []string{auth.apiKey, auth.googAPIKey} {
		if headerKey != "" {
			applyCredential(headerKey)
		}
	}

	// Forward x-bf-* headers.
	for name, value := range auth.extraHeaders {
		lower := strings.ToLower(name)
		if lower == "x-bf-vk" {
			// The virtual key was already applied from auth.virtualKey.
			continue
		}
		if lower == "x-bf-api-key" {
			bfCtx.SetValue(schemas.BifrostContextKeyAPIKeyName, value)
			continue
		}
		if !strings.HasPrefix(lower, extraHeaderPrefix) {
			continue
		}
		// x-bf-eh-<name> headers accumulate in a single map keyed by <name>.
		forwarded, _ := bfCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string]string)
		if forwarded == nil {
			forwarded = make(map[string]string)
		}
		forwarded[strings.TrimPrefix(lower, extraHeaderPrefix)] = value
		bfCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, forwarded)
	}

	return bfCtx, cancel
}
// executeHTTPBridge serves the request through the streaming Responses pipeline
// and relays each stream chunk to the client as a WS text message.
//
// It takes ownership of cancel and invokes it on return. A failure to start the
// stream is reported to the client as a WS error event. Chunks that are nil or
// fail to marshal are skipped; a failed client write ends the relay. The ID of
// any chunk carrying a non-empty response ID is recorded on the session for
// response chaining.
func (h *WSResponsesHandler) executeHTTPBridge(
	session *bfws.Session,
	ctx *schemas.BifrostContext,
	cancel context.CancelFunc,
	req *schemas.BifrostResponsesRequest,
) {
	defer cancel()

	chunks, streamErr := h.client.ResponsesStreamRequest(ctx, req)
	if streamErr != nil {
		writeWSBifrostError(session, streamErr)
		return
	}

	for chunk := range chunks {
		if chunk == nil {
			continue
		}

		payload, marshalErr := sonic.Marshal(chunk)
		if marshalErr != nil {
			logger.Warn("failed to marshal stream chunk: %v", marshalErr)
			continue
		}
		if session.WriteMessage(ws.TextMessage, payload) != nil {
			return
		}

		// Track last response ID for session chaining.
		streamResp := chunk.BifrostResponsesStreamResponse
		if streamResp == nil || streamResp.Response == nil {
			continue
		}
		if id := streamResp.Response.ID; id != nil && *id != "" {
			session.SetLastResponseID(*id)
		}
	}
}
// writeWSError serializes an error event with the given status, code, and
// message and writes it to w as a WS text message.
// Accepts either a raw *ws.Conn (pre-session) or a *bfws.Session (mutex-protected).
// Marshal failures are dropped and the write result is not inspected.
func writeWSError(w wsWriter, status int, code, message string) {
	payload, marshalErr := sonic.Marshal(schemas.WebSocketErrorEvent{
		Type:   schemas.WSEventError,
		Status: status,
		Error: &schemas.WebSocketErrorBody{
			Code:    code,
			Message: message,
		},
	})
	if marshalErr != nil {
		return
	}
	w.WriteMessage(ws.TextMessage, payload)
}
// writeWSBifrostError maps a BifrostError onto a WS error event and writes it to w.
//
// Defaults are status 500, code "server_error", and message "internal server
// error". A positive StatusCode overrides the status; the error's Code (or, when
// absent or empty, its Type) overrides the code; a non-empty Message overrides
// the message.
func writeWSBifrostError(w wsWriter, bifrostErr *schemas.BifrostError) {
	status, code, msg := 500, "server_error", "internal server error"

	if sc := bifrostErr.StatusCode; sc != nil && *sc > 0 {
		status = *sc
	}

	if details := bifrostErr.Error; details != nil {
		switch {
		case details.Code != nil && *details.Code != "":
			code = *details.Code
		case details.Type != nil && *details.Type != "":
			code = *details.Type
		}
		if details.Message != "" {
			msg = details.Message
		}
	}

	writeWSError(w, status, code, msg)
}
// wsResponsesKnownFields is the set of top-level JSON keys of a response.create
// event that are handled explicitly: the event "type" plus the fields that
// convertEventToRequest reads from WebSocketResponsesEvent.
// handleResponseCreate passes this set to extractExtraParams; the extra params
// it returns are attached to the request as Params.ExtraParams.
// Keep this set in sync with the fields of WebSocketResponsesEvent.
var wsResponsesKnownFields = map[string]bool{
"type": true,
"model": true,
"store": true,
"input": true,
"instructions": true,
"previous_response_id": true,
"tools": true,
"tool_choice": true,
"temperature": true,
"top_p": true,
"max_output_tokens": true,
"reasoning": true,
"metadata": true,
"text": true,
"truncation": true,
}
// simpleError is a minimal error carrying a fixed message.
type simpleError struct {
	msg string
}

// Error returns the stored message, satisfying the error interface.
func (s *simpleError) Error() string {
	return s.msg
}

// errorf wraps msg in a *simpleError. Despite the name it performs no formatting.
func errorf(msg string) error {
	return &simpleError{msg: msg}
}

// Validation errors returned by convertEventToRequest.
var (
	// errModelFormat signals a model string that does not split into provider and model.
	errModelFormat = errorf("model should be in provider/model format")
	// errInputRequired signals a missing, empty, or undecodable input.
	errInputRequired = errorf("input is required for responses")
)

View File

@@ -0,0 +1,68 @@
package handlers
import (
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/kvstore"
"github.com/maximhq/bifrost/framework/logstore"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
)
// testWSHandlerStore is a stub handler store for tests. Only the direct-key
// policy is configurable; every other accessor returns its zero value.
type testWSHandlerStore struct {
	allowDirectKeys bool
}

// ShouldAllowDirectKeys reports the configured direct-key policy.
func (store testWSHandlerStore) ShouldAllowDirectKeys() bool {
	return store.allowDirectKeys
}

// GetHeaderMatcher always returns nil.
func (testWSHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher { return nil }

// GetAvailableProviders always returns nil.
func (testWSHandlerStore) GetAvailableProviders() []schemas.ModelProvider { return nil }

// GetStreamChunkInterceptor always returns nil.
func (testWSHandlerStore) GetStreamChunkInterceptor() lib.StreamChunkInterceptor { return nil }

// GetAsyncJobExecutor always returns nil.
func (testWSHandlerStore) GetAsyncJobExecutor() *logstore.AsyncJobExecutor { return nil }

// GetAsyncJobResultTTL always returns 0.
func (testWSHandlerStore) GetAsyncJobResultTTL() int { return 0 }

// GetKVStore always returns nil.
func (testWSHandlerStore) GetKVStore() *kvstore.Store { return nil }

// GetMCPHeaderCombinedAllowlist always returns nil.
func (testWSHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList { return nil }
func TestCreateBifrostContextFromAuth_BaggageSessionIDSetsGrouping(t *testing.T) {
ctx, cancel := createBifrostContextFromAuth(testWSHandlerStore{}, &authHeaders{
baggage: "foo=bar, session-id=rt-ws-123, baz=qux",
})
defer cancel()
if got, _ := ctx.Value(schemas.BifrostContextKeyParentRequestID).(string); got != "rt-ws-123" {
t.Fatalf("parent request id = %q, want %q", got, "rt-ws-123")
}
}
func TestCreateBifrostContextFromAuth_EmptyBaggageSessionIDIgnored(t *testing.T) {
ctx, cancel := createBifrostContextFromAuth(testWSHandlerStore{}, &authHeaders{
baggage: "session-id= ",
})
defer cancel()
if got := ctx.Value(schemas.BifrostContextKeyParentRequestID); got != nil {
t.Fatalf("parent request id should be unset, got %#v", got)
}
}