first commit
This commit is contained in:
548
transports/bifrost-http/handlers/asyncinference.go
Normal file
548
transports/bifrost-http/handlers/asyncinference.go
Normal file
@@ -0,0 +1,548 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/logstore"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// --- HTTP Handler ---
|
||||
|
||||
// AsyncHandler handles async job HTTP endpoints.
|
||||
type AsyncHandler struct {
|
||||
client *bifrost.Bifrost
|
||||
executor *logstore.AsyncJobExecutor
|
||||
handlerStore lib.HandlerStore
|
||||
config *lib.Config
|
||||
}
|
||||
|
||||
// AsyncPathToTypeMapping maps exact paths to request types (only for non-parameterized paths)
|
||||
// Parameterized paths are set per-route in RegisterRoutes
|
||||
var AsyncPathToTypeMapping = map[string]schemas.RequestType{
|
||||
"/v1/async/completions": schemas.TextCompletionRequest,
|
||||
"/v1/async/chat/completions": schemas.ChatCompletionRequest,
|
||||
"/v1/async/responses": schemas.ResponsesRequest,
|
||||
"/v1/async/embeddings": schemas.EmbeddingRequest,
|
||||
"/v1/async/audio/speech": schemas.SpeechRequest,
|
||||
"/v1/async/audio/transcriptions": schemas.TranscriptionRequest,
|
||||
"/v1/async/images/generations": schemas.ImageGenerationRequest,
|
||||
"/v1/async/images/edits": schemas.ImageEditRequest,
|
||||
"/v1/async/images/variations": schemas.ImageVariationRequest,
|
||||
"/v1/async/rerank": schemas.RerankRequest,
|
||||
"/v1/async/ocr": schemas.OCRRequest,
|
||||
}
|
||||
|
||||
// RegisterAsyncRequestTypeMiddleware handles exact path matching for non-parameterized routes
|
||||
func RegisterAsyncRequestTypeMiddleware(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||
return func(ctx *fasthttp.RequestCtx) {
|
||||
path := string(ctx.Path())
|
||||
if requestType, ok := AsyncPathToTypeMapping[path]; ok {
|
||||
ctx.SetUserValue(schemas.BifrostContextKeyHTTPRequestType, requestType)
|
||||
}
|
||||
next(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// NewAsyncHandler creates a new AsyncHandler.
|
||||
// If the async job executor is not available (e.g., LogsStore or governance plugin not configured),
|
||||
// the handler is created with a nil executor and RegisterRoutes will skip async route registration.
|
||||
func NewAsyncHandler(client *bifrost.Bifrost, config *lib.Config) *AsyncHandler {
|
||||
return &AsyncHandler{
|
||||
client: client,
|
||||
executor: config.GetAsyncJobExecutor(),
|
||||
handlerStore: config,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers async job endpoints.
|
||||
func (h *AsyncHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
if h.executor == nil {
|
||||
return // LogStore not configured, skip async routes
|
||||
}
|
||||
|
||||
baseMiddlewares := append([]schemas.BifrostHTTPMiddleware{RegisterAsyncRequestTypeMiddleware}, middlewares...)
|
||||
|
||||
// Async submission endpoints (non-parameterized, request type set via AsyncPathToTypeMapping)
|
||||
r.POST("/v1/async/completions", lib.ChainMiddlewares(h.asyncTextCompletion, baseMiddlewares...))
|
||||
r.POST("/v1/async/chat/completions", lib.ChainMiddlewares(h.asyncChatCompletion, baseMiddlewares...))
|
||||
r.POST("/v1/async/responses", lib.ChainMiddlewares(h.asyncResponses, baseMiddlewares...))
|
||||
r.POST("/v1/async/embeddings", lib.ChainMiddlewares(h.asyncEmbeddings, baseMiddlewares...))
|
||||
r.POST("/v1/async/audio/speech", lib.ChainMiddlewares(h.asyncSpeech, baseMiddlewares...))
|
||||
r.POST("/v1/async/audio/transcriptions", lib.ChainMiddlewares(h.asyncTranscription, baseMiddlewares...))
|
||||
r.POST("/v1/async/images/generations", lib.ChainMiddlewares(h.asyncImageGeneration, baseMiddlewares...))
|
||||
r.POST("/v1/async/images/edits", lib.ChainMiddlewares(h.asyncImageEdit, baseMiddlewares...))
|
||||
r.POST("/v1/async/images/variations", lib.ChainMiddlewares(h.asyncImageVariation, baseMiddlewares...))
|
||||
r.POST("/v1/async/rerank", lib.ChainMiddlewares(h.asyncRerank, baseMiddlewares...))
|
||||
r.POST("/v1/async/ocr", lib.ChainMiddlewares(h.asyncOCR, baseMiddlewares...))
|
||||
|
||||
// Async job retrieval endpoints
|
||||
r.GET("/v1/async/completions/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.TextCompletionRequest), middlewares...))
|
||||
r.GET("/v1/async/chat/completions/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.ChatCompletionRequest), middlewares...))
|
||||
r.GET("/v1/async/responses/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.ResponsesRequest), middlewares...))
|
||||
r.GET("/v1/async/embeddings/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.EmbeddingRequest), middlewares...))
|
||||
r.GET("/v1/async/audio/speech/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.SpeechRequest), middlewares...))
|
||||
r.GET("/v1/async/audio/transcriptions/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.TranscriptionRequest), middlewares...))
|
||||
r.GET("/v1/async/images/generations/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.ImageGenerationRequest), middlewares...))
|
||||
r.GET("/v1/async/images/edits/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.ImageEditRequest), middlewares...))
|
||||
r.GET("/v1/async/images/variations/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.ImageVariationRequest), middlewares...))
|
||||
r.GET("/v1/async/rerank/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.RerankRequest), middlewares...))
|
||||
r.GET("/v1/async/ocr/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.OCRRequest), middlewares...))
|
||||
}
|
||||
|
||||
// --- Async submission handlers ---
|
||||
|
||||
// asyncTextCompletion handles POST /v1/async/completions
|
||||
func (h *AsyncHandler) asyncTextCompletion(ctx *fasthttp.RequestCtx) {
|
||||
req, bifrostTextReq, err := prepareTextCompletionRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.Stream != nil && *req.Stream {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async text completions")
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.TextCompletionRequest(bgCtx, bifrostTextReq)
|
||||
},
|
||||
schemas.TextCompletionRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncChatCompletion handles POST /v1/async/chat/completions
|
||||
func (h *AsyncHandler) asyncChatCompletion(ctx *fasthttp.RequestCtx) {
|
||||
req, bifrostChatReq, err := prepareChatCompletionRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.Stream != nil && *req.Stream {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async chat completions")
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.ChatCompletionRequest(bgCtx, bifrostChatReq)
|
||||
},
|
||||
schemas.ChatCompletionRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncResponses handles POST /v1/async/responses
|
||||
func (h *AsyncHandler) asyncResponses(ctx *fasthttp.RequestCtx) {
|
||||
req, bifrostResponsesReq, err := prepareResponsesRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.Stream != nil && *req.Stream {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async responses")
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.ResponsesRequest(bgCtx, bifrostResponsesReq)
|
||||
},
|
||||
schemas.ResponsesRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Failed to create async job: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncEmbeddings handles POST /v1/async/embeddings
|
||||
func (h *AsyncHandler) asyncEmbeddings(ctx *fasthttp.RequestCtx) {
|
||||
_, bifrostEmbeddingReq, err := prepareEmbeddingRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.EmbeddingRequest(bgCtx, bifrostEmbeddingReq)
|
||||
},
|
||||
schemas.EmbeddingRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncSpeech handles POST /v1/async/audio/speech
|
||||
func (h *AsyncHandler) asyncSpeech(ctx *fasthttp.RequestCtx) {
|
||||
req, bifrostSpeechReq, err := prepareSpeechRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.StreamFormat != nil && *req.StreamFormat == "sse" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async speech")
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.SpeechRequest(bgCtx, bifrostSpeechReq)
|
||||
},
|
||||
schemas.SpeechRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncTranscription handles POST /v1/async/audio/transcriptions
|
||||
func (h *AsyncHandler) asyncTranscription(ctx *fasthttp.RequestCtx) {
|
||||
bifrostTranscriptionReq, stream, err := prepareTranscriptionRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if stream {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async transcriptions")
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.TranscriptionRequest(bgCtx, bifrostTranscriptionReq)
|
||||
},
|
||||
schemas.TranscriptionRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncImageGeneration handles POST /v1/async/images/generations
|
||||
func (h *AsyncHandler) asyncImageGeneration(ctx *fasthttp.RequestCtx) {
|
||||
req, bifrostReq, err := prepareImageGenerationRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.BifrostParams.Stream != nil && *req.BifrostParams.Stream {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async image generations")
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.ImageGenerationRequest(bgCtx, bifrostReq)
|
||||
},
|
||||
schemas.ImageGenerationRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncImageEdit handles POST /v1/async/images/edits
|
||||
func (h *AsyncHandler) asyncImageEdit(ctx *fasthttp.RequestCtx) {
|
||||
req, bifrostReq, err := prepareImageEditRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.Stream != nil && *req.Stream {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async image edits")
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.ImageEditRequest(bgCtx, bifrostReq)
|
||||
},
|
||||
schemas.ImageEditRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncImageVariation handles POST /v1/async/images/variations
|
||||
func (h *AsyncHandler) asyncImageVariation(ctx *fasthttp.RequestCtx) {
|
||||
bifrostReq, err := prepareImageVariationRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.ImageVariationRequest(bgCtx, bifrostReq)
|
||||
},
|
||||
schemas.ImageVariationRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncRerank handles POST /v1/async/rerank
|
||||
func (h *AsyncHandler) asyncRerank(ctx *fasthttp.RequestCtx) {
|
||||
_, bifrostReq, err := prepareRerankRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.RerankRequest(bgCtx, bifrostReq)
|
||||
},
|
||||
schemas.RerankRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncOCR handles POST /v1/async/ocr
|
||||
func (h *AsyncHandler) asyncOCR(ctx *fasthttp.RequestCtx) {
|
||||
_, bifrostReq, err := prepareOCRRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.OCRRequest(bgCtx, bifrostReq)
|
||||
},
|
||||
schemas.OCRRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// --- Job retrieval handler ---
|
||||
|
||||
// getJob handles GET /v1/async/{type}/{job_id}
|
||||
func (h *AsyncHandler) getJob(operationType schemas.RequestType) fasthttp.RequestHandler {
|
||||
return func(ctx *fasthttp.RequestCtx) {
|
||||
jobID, ok := ctx.UserValue("job_id").(string)
|
||||
if !ok || jobID == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "job_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Get the requesting user's VK for auth check
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
job, err := h.executor.RetrieveJob(bifrostCtx, jobID, getVirtualKeyFromContext(bifrostCtx), operationType)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp := job.ToResponse()
|
||||
|
||||
// Return 202 for pending/processing, 200 for completed/failed
|
||||
switch job.Status {
|
||||
case schemas.AsyncJobStatusPending, schemas.AsyncJobStatusProcessing:
|
||||
SendJSONWithStatus(ctx, resp, fasthttp.StatusAccepted)
|
||||
default:
|
||||
SendJSON(ctx, resp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helper functions ---
|
||||
|
||||
// getVirtualKeyFromContext extracts the virtual key value from context.
|
||||
// Returns nil if no VK is present (e.g., direct key mode or no governance).
|
||||
func getVirtualKeyFromContext(ctx *schemas.BifrostContext) *string {
|
||||
vkValue := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey)
|
||||
if vkValue == "" {
|
||||
return nil
|
||||
}
|
||||
return &vkValue
|
||||
}
|
||||
|
||||
func getResultTTLFromHeaderWithDefault(ctx *fasthttp.RequestCtx, defaultTTL int) int {
|
||||
resultTTL := string(ctx.Request.Header.Peek(schemas.AsyncHeaderResultTTL))
|
||||
if resultTTL == "" {
|
||||
return defaultTTL
|
||||
}
|
||||
resultTTLInt, err := strconv.Atoi(resultTTL)
|
||||
if err != nil || resultTTLInt < 0 {
|
||||
return defaultTTL
|
||||
}
|
||||
return resultTTLInt
|
||||
}
|
||||
61
transports/bifrost-http/handlers/cache.go
Normal file
61
transports/bifrost-http/handlers/cache.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/plugins/semanticcache"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type CacheHandler struct {
|
||||
plugin *semanticcache.Plugin
|
||||
}
|
||||
|
||||
func NewCacheHandler(plugin schemas.LLMPlugin) *CacheHandler {
|
||||
semanticCachePlugin, ok := plugin.(*semanticcache.Plugin)
|
||||
if !ok {
|
||||
logger.Fatal("Cache handler requires a semantic cache plugin")
|
||||
}
|
||||
|
||||
return &CacheHandler{
|
||||
plugin: semanticCachePlugin,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *CacheHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
r.DELETE("/api/cache/clear/{requestId}", lib.ChainMiddlewares(h.clearCache, middlewares...))
|
||||
r.DELETE("/api/cache/clear-by-key/{cacheKey}", lib.ChainMiddlewares(h.clearCacheByKey, middlewares...))
|
||||
}
|
||||
|
||||
func (h *CacheHandler) clearCache(ctx *fasthttp.RequestCtx) {
|
||||
requestID, ok := ctx.UserValue("requestId").(string)
|
||||
if !ok {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Invalid request ID")
|
||||
return
|
||||
}
|
||||
if err := h.plugin.ClearCacheForRequestID(requestID); err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to clear cache")
|
||||
return
|
||||
}
|
||||
|
||||
SendJSON(ctx, map[string]any{
|
||||
"message": "Cache cleared successfully",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *CacheHandler) clearCacheByKey(ctx *fasthttp.RequestCtx) {
|
||||
cacheKey, ok := ctx.UserValue("cacheKey").(string)
|
||||
if !ok {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Invalid cache key")
|
||||
return
|
||||
}
|
||||
if err := h.plugin.ClearCacheForKey(cacheKey); err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to clear cache")
|
||||
return
|
||||
}
|
||||
|
||||
SendJSON(ctx, map[string]any{
|
||||
"message": "Cache cleared successfully",
|
||||
})
|
||||
}
|
||||
887
transports/bifrost-http/handlers/config.go
Normal file
887
transports/bifrost-http/handlers/config.go
Normal file
@@ -0,0 +1,887 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/network"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"github.com/maximhq/bifrost/framework/modelcatalog"
|
||||
"github.com/maximhq/bifrost/plugins/compat"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// securityHeaders is the list of headers that cannot be configured in allowlist/denylist
|
||||
// These headers are always blocked for security reasons regardless of user configuration
|
||||
var securityHeaders = []string{
|
||||
"authorization",
|
||||
"proxy-authorization",
|
||||
"cookie",
|
||||
"host",
|
||||
"content-length",
|
||||
"connection",
|
||||
"transfer-encoding",
|
||||
"x-api-key",
|
||||
"x-goog-api-key",
|
||||
"x-bf-api-key",
|
||||
"x-bf-vk",
|
||||
}
|
||||
|
||||
// ConfigManager is the interface for the config manager
|
||||
type ConfigManager interface {
|
||||
UpdateAuthConfig(ctx context.Context, authConfig *configstore.AuthConfig) error
|
||||
ReloadClientConfigFromConfigStore(ctx context.Context) error
|
||||
UpdateSyncConfig(ctx context.Context) error
|
||||
ForceReloadPricing(ctx context.Context) error
|
||||
UpdateDropExcessRequests(ctx context.Context, value bool)
|
||||
UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string, disableAutoToolInject bool) error
|
||||
ReloadPlugin(ctx context.Context, name string, path *string, pluginConfig any, placement *schemas.PluginPlacement, order *int) error
|
||||
RemovePlugin(ctx context.Context, name string) error
|
||||
ReloadProxyConfig(ctx context.Context, config *configstoreTables.GlobalProxyConfig) error
|
||||
ReloadHeaderFilterConfig(ctx context.Context, config *configstoreTables.GlobalHeaderFilterConfig) error
|
||||
}
|
||||
|
||||
// ConfigHandler manages runtime configuration updates for Bifrost.
|
||||
// It provides endpoints to update and retrieve settings persisted via the ConfigStore backed by sql database.
|
||||
type ConfigHandler struct {
|
||||
store *lib.Config
|
||||
configManager ConfigManager
|
||||
}
|
||||
|
||||
// NewConfigHandler creates a new handler for configuration management.
|
||||
// It requires the Bifrost client, a logger, and the config store.
|
||||
func NewConfigHandler(configManager ConfigManager, store *lib.Config) *ConfigHandler {
|
||||
return &ConfigHandler{
|
||||
configManager: configManager,
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the configuration-related routes.
|
||||
// It adds the `PUT /api/config` endpoint.
|
||||
func (h *ConfigHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
r.GET("/api/config", lib.ChainMiddlewares(h.getConfig, middlewares...))
|
||||
r.PUT("/api/config", lib.ChainMiddlewares(h.updateConfig, middlewares...))
|
||||
r.GET("/api/version", lib.ChainMiddlewares(h.getVersion, middlewares...))
|
||||
r.GET("/api/proxy-config", lib.ChainMiddlewares(h.getProxyConfig, middlewares...))
|
||||
r.PUT("/api/proxy-config", lib.ChainMiddlewares(h.updateProxyConfig, middlewares...))
|
||||
r.POST("/api/pricing/force-sync", lib.ChainMiddlewares(h.forceSyncPricing, middlewares...))
|
||||
}
|
||||
|
||||
// getVersion handles GET /api/version - Get the current version
|
||||
func (h *ConfigHandler) getVersion(ctx *fasthttp.RequestCtx) {
|
||||
SendJSON(ctx, version)
|
||||
}
|
||||
|
||||
// getConfig handles GET /config - Get the current configuration
|
||||
func (h *ConfigHandler) getConfig(ctx *fasthttp.RequestCtx) {
|
||||
mapConfig := make(map[string]any)
|
||||
|
||||
if query := string(ctx.QueryArgs().Peek("from_db")); query == "true" {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "config store not available")
|
||||
return
|
||||
}
|
||||
cc, err := h.store.ConfigStore.GetClientConfig(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError,
|
||||
fmt.Sprintf("failed to fetch config from db: %v", err))
|
||||
return
|
||||
}
|
||||
if cc != nil {
|
||||
mapConfig["client_config"] = *cc
|
||||
}
|
||||
// Fetching framework config
|
||||
fc, err := h.store.ConfigStore.GetFrameworkConfig(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to fetch framework config from db: %v", err))
|
||||
return
|
||||
}
|
||||
normalizedFrameworkConfig, _, _ := lib.ResolveFrameworkPricingConfig(fc, nil)
|
||||
mapConfig["framework_config"] = *normalizedFrameworkConfig
|
||||
} else {
|
||||
mapConfig["client_config"] = h.store.ClientConfig
|
||||
normalizedFrameworkConfig, _, _ := lib.ResolveFrameworkPricingConfig(nil, h.store.FrameworkConfig)
|
||||
mapConfig["framework_config"] = *normalizedFrameworkConfig
|
||||
}
|
||||
if h.store.ConfigStore != nil {
|
||||
// Fetching governance config
|
||||
authConfig, err := h.store.ConfigStore.GetAuthConfig(ctx)
|
||||
if err != nil {
|
||||
logger.Warn("failed to get auth config from store: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get auth config from store: %v", err))
|
||||
return
|
||||
}
|
||||
// Getting username and password from auth config
|
||||
// This username password is for the dashboard authentication
|
||||
if authConfig != nil {
|
||||
// For password, return EnvVar structure with redacted value
|
||||
// If from env, preserve env_var reference but clear value
|
||||
// If not from env, show <redacted> as the value
|
||||
var passwordEnvVar *schemas.EnvVar
|
||||
if authConfig.AdminPassword != nil && authConfig.AdminPassword.IsFromEnv() {
|
||||
passwordEnvVar = &schemas.EnvVar{
|
||||
Val: "",
|
||||
EnvVar: authConfig.AdminPassword.EnvVar,
|
||||
FromEnv: true,
|
||||
}
|
||||
} else {
|
||||
passwordEnvVar = &schemas.EnvVar{
|
||||
Val: "<redacted>",
|
||||
EnvVar: "",
|
||||
FromEnv: false,
|
||||
}
|
||||
}
|
||||
mapConfig["auth_config"] = map[string]any{
|
||||
"admin_username": authConfig.AdminUserName,
|
||||
"admin_password": passwordEnvVar,
|
||||
"is_enabled": authConfig.IsEnabled,
|
||||
"disable_auth_on_inference": authConfig.DisableAuthOnInference,
|
||||
}
|
||||
} else {
|
||||
// No auth config exists yet, return default empty EnvVar values
|
||||
mapConfig["auth_config"] = map[string]any{
|
||||
"admin_username": &schemas.EnvVar{Val: "", EnvVar: "", FromEnv: false},
|
||||
"admin_password": &schemas.EnvVar{Val: "", EnvVar: "", FromEnv: false},
|
||||
"is_enabled": false,
|
||||
"disable_auth_on_inference": false,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mapConfig["auth_config"] = map[string]any{
|
||||
"admin_username": &schemas.EnvVar{Val: "", EnvVar: "", FromEnv: false},
|
||||
"admin_password": &schemas.EnvVar{Val: "", EnvVar: "", FromEnv: false},
|
||||
"is_enabled": false,
|
||||
"disable_auth_on_inference": false,
|
||||
}
|
||||
}
|
||||
mapConfig["is_db_connected"] = h.store.ConfigStore != nil
|
||||
mapConfig["is_cache_connected"] = h.store.VectorStore != nil
|
||||
mapConfig["is_logs_connected"] = h.store.LogsStore != nil
|
||||
// Fetching proxy config
|
||||
if h.store.ConfigStore != nil {
|
||||
proxyConfig, err := h.store.ConfigStore.GetProxyConfig(ctx)
|
||||
if err != nil {
|
||||
logger.Warn("failed to get proxy config from store: %v", err)
|
||||
} else if proxyConfig != nil {
|
||||
// Redact password if present
|
||||
if proxyConfig.Password != "" {
|
||||
proxyConfig.Password = "<redacted>"
|
||||
}
|
||||
mapConfig["proxy_config"] = proxyConfig
|
||||
}
|
||||
// Fetching restart required config
|
||||
restartConfig, err := h.store.ConfigStore.GetRestartRequiredConfig(ctx)
|
||||
if err != nil {
|
||||
logger.Warn("failed to get restart required config from store: %v", err)
|
||||
} else if restartConfig != nil {
|
||||
mapConfig["restart_required"] = restartConfig
|
||||
}
|
||||
}
|
||||
SendJSON(ctx, mapConfig)
|
||||
}
|
||||
|
||||
// updateConfig updates the core configuration settings.
|
||||
// Currently, it supports hot-reloading of the `drop_excess_requests` setting.
|
||||
// Note that settings like `prometheus_labels` cannot be changed at runtime.
|
||||
func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Config store not initialized")
|
||||
return
|
||||
}
|
||||
|
||||
payload := struct {
|
||||
ClientConfig configstore.ClientConfig `json:"client_config"`
|
||||
FrameworkConfig configstoreTables.TableFrameworkConfig `json:"framework_config"`
|
||||
AuthConfig *configstore.AuthConfig `json:"auth_config"`
|
||||
}{}
|
||||
|
||||
if err := json.Unmarshal(ctx.PostBody(), &payload); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Validating framework config
|
||||
if payload.FrameworkConfig.PricingURL != nil && *payload.FrameworkConfig.PricingURL != modelcatalog.DefaultPricingURL {
|
||||
// Checking the accessibility of the pricing URL
|
||||
resp, err := http.Get(*payload.FrameworkConfig.PricingURL)
|
||||
if err != nil {
|
||||
logger.Warn("failed to check the accessibility of the pricing URL: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to check the accessibility of the pricing URL: %v", err))
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.Warn("failed to check the accessibility of the pricing URL: %v", resp.StatusCode)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to check the accessibility of the pricing URL: %v", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Checking the pricing sync interval
|
||||
if payload.FrameworkConfig.PricingSyncInterval != nil && *payload.FrameworkConfig.PricingSyncInterval <= 0 {
|
||||
logger.Warn("pricing sync interval must be greater than 0")
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "pricing sync interval must be greater than 0")
|
||||
return
|
||||
}
|
||||
|
||||
// Get current config with proper locking
|
||||
currentConfig := h.store.ClientConfig
|
||||
updatedConfig := currentConfig
|
||||
|
||||
var restartReasons []string
|
||||
|
||||
if payload.ClientConfig.DropExcessRequests != currentConfig.DropExcessRequests {
|
||||
h.configManager.UpdateDropExcessRequests(ctx, payload.ClientConfig.DropExcessRequests)
|
||||
updatedConfig.DropExcessRequests = payload.ClientConfig.DropExcessRequests
|
||||
}
|
||||
|
||||
if payload.ClientConfig.MCPCodeModeBindingLevel != "" {
|
||||
if payload.ClientConfig.MCPCodeModeBindingLevel != string(schemas.CodeModeBindingLevelServer) && payload.ClientConfig.MCPCodeModeBindingLevel != string(schemas.CodeModeBindingLevelTool) {
|
||||
logger.Warn("mcp_code_mode_binding_level must be 'server' or 'tool'")
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "mcp_code_mode_binding_level must be 'server' or 'tool'")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
shouldReloadMCPToolManagerConfig := false
|
||||
|
||||
// Only process MCPAgentDepth if explicitly provided (> 0) and different from current
|
||||
if payload.ClientConfig.MCPAgentDepth > 0 && payload.ClientConfig.MCPAgentDepth != currentConfig.MCPAgentDepth {
|
||||
updatedConfig.MCPAgentDepth = payload.ClientConfig.MCPAgentDepth
|
||||
shouldReloadMCPToolManagerConfig = true
|
||||
}
|
||||
|
||||
// Only process MCPToolExecutionTimeout if explicitly provided (> 0) and different from current
|
||||
if payload.ClientConfig.MCPToolExecutionTimeout > 0 && payload.ClientConfig.MCPToolExecutionTimeout != currentConfig.MCPToolExecutionTimeout {
|
||||
updatedConfig.MCPToolExecutionTimeout = payload.ClientConfig.MCPToolExecutionTimeout
|
||||
shouldReloadMCPToolManagerConfig = true
|
||||
}
|
||||
|
||||
if payload.ClientConfig.MCPCodeModeBindingLevel != "" && payload.ClientConfig.MCPCodeModeBindingLevel != currentConfig.MCPCodeModeBindingLevel {
|
||||
updatedConfig.MCPCodeModeBindingLevel = payload.ClientConfig.MCPCodeModeBindingLevel
|
||||
shouldReloadMCPToolManagerConfig = true
|
||||
}
|
||||
|
||||
if payload.ClientConfig.MCPDisableAutoToolInject != currentConfig.MCPDisableAutoToolInject {
|
||||
updatedConfig.MCPDisableAutoToolInject = payload.ClientConfig.MCPDisableAutoToolInject
|
||||
shouldReloadMCPToolManagerConfig = true
|
||||
}
|
||||
|
||||
// Reload MCP tool manager config with all current values in one call
|
||||
if shouldReloadMCPToolManagerConfig && h.store.MCPConfig != nil {
|
||||
if err := h.configManager.UpdateMCPToolManagerConfig(ctx, updatedConfig.MCPAgentDepth, updatedConfig.MCPToolExecutionTimeout, updatedConfig.MCPCodeModeBindingLevel, updatedConfig.MCPDisableAutoToolInject); err != nil {
|
||||
logger.Warn(fmt.Sprintf("failed to update mcp tool manager config: %v", err))
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to update mcp tool manager config: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !slices.Equal(payload.ClientConfig.PrometheusLabels, currentConfig.PrometheusLabels) {
|
||||
updatedConfig.PrometheusLabels = payload.ClientConfig.PrometheusLabels
|
||||
restartReasons = append(restartReasons, "Prometheus labels")
|
||||
}
|
||||
|
||||
if !slices.Equal(payload.ClientConfig.AllowedOrigins, currentConfig.AllowedOrigins) {
|
||||
updatedConfig.AllowedOrigins = payload.ClientConfig.AllowedOrigins
|
||||
restartReasons = append(restartReasons, "Allowed origins")
|
||||
}
|
||||
|
||||
if !slices.Equal(payload.ClientConfig.AllowedHeaders, currentConfig.AllowedHeaders) {
|
||||
updatedConfig.AllowedHeaders = payload.ClientConfig.AllowedHeaders
|
||||
restartReasons = append(restartReasons, "Allowed headers")
|
||||
}
|
||||
|
||||
// Only update InitialPoolSize if explicitly provided (> 0) to avoid clearing stored value
|
||||
if payload.ClientConfig.InitialPoolSize > 0 {
|
||||
if payload.ClientConfig.InitialPoolSize != currentConfig.InitialPoolSize {
|
||||
restartReasons = append(restartReasons, "Initial pool size")
|
||||
}
|
||||
updatedConfig.InitialPoolSize = payload.ClientConfig.InitialPoolSize
|
||||
}
|
||||
|
||||
if payload.ClientConfig.EnableLogging != nil {
|
||||
payloadLogging := *payload.ClientConfig.EnableLogging
|
||||
currentLogging := currentConfig.EnableLogging == nil || *currentConfig.EnableLogging
|
||||
if payloadLogging != currentLogging {
|
||||
restartReasons = append(restartReasons, "Logging changed")
|
||||
}
|
||||
updatedConfig.EnableLogging = payload.ClientConfig.EnableLogging
|
||||
}
|
||||
|
||||
if payload.ClientConfig.DisableContentLogging != currentConfig.DisableContentLogging {
|
||||
restartReasons = append(restartReasons, "Content logging")
|
||||
}
|
||||
updatedConfig.DisableContentLogging = payload.ClientConfig.DisableContentLogging
|
||||
updatedConfig.DisableDBPingsInHealth = payload.ClientConfig.DisableDBPingsInHealth
|
||||
updatedConfig.AllowDirectKeys = payload.ClientConfig.AllowDirectKeys
|
||||
|
||||
updatedConfig.EnforceAuthOnInference = payload.ClientConfig.EnforceAuthOnInference
|
||||
// Sync deprecated columns to match new field so they stay consistent in the DB
|
||||
updatedConfig.EnforceGovernanceHeader = payload.ClientConfig.EnforceAuthOnInference
|
||||
updatedConfig.EnforceSCIMAuth = payload.ClientConfig.EnforceAuthOnInference
|
||||
|
||||
// Only update MaxRequestBodySizeMB if explicitly provided (> 0) to avoid clearing stored value
|
||||
if payload.ClientConfig.MaxRequestBodySizeMB > 0 {
|
||||
if payload.ClientConfig.MaxRequestBodySizeMB != currentConfig.MaxRequestBodySizeMB {
|
||||
restartReasons = append(restartReasons, "Max request body size")
|
||||
}
|
||||
updatedConfig.MaxRequestBodySizeMB = payload.ClientConfig.MaxRequestBodySizeMB
|
||||
}
|
||||
|
||||
// Handle compat plugin toggle
|
||||
newCompat := payload.ClientConfig.Compat
|
||||
oldCompat := currentConfig.Compat
|
||||
if newCompat != oldCompat {
|
||||
newEnabled := newCompat.ConvertTextToChat || newCompat.ConvertChatToResponses || newCompat.ShouldDropParams || newCompat.ShouldConvertParams
|
||||
if newEnabled {
|
||||
compatCfg := &compat.Config{
|
||||
ConvertTextToChat: newCompat.ConvertTextToChat,
|
||||
ConvertChatToResponses: newCompat.ConvertChatToResponses,
|
||||
ShouldDropParams: newCompat.ShouldDropParams,
|
||||
ShouldConvertParams: newCompat.ShouldConvertParams,
|
||||
}
|
||||
if err := h.configManager.ReloadPlugin(ctx, compat.PluginName, nil, compatCfg, nil, nil); err != nil {
|
||||
logger.Warn("failed to load compat plugin: %v", err)
|
||||
SendError(ctx, 400, "Failed to load compat plugin")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
disabledCtx := context.WithValue(ctx, PluginDisabledKey, true)
|
||||
if err := h.configManager.RemovePlugin(disabledCtx, compat.PluginName); err != nil {
|
||||
logger.Warn("failed to remove compat plugin: %v", err)
|
||||
SendError(ctx, 400, "Failed to remove compat plugin")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
updatedConfig.Compat = newCompat
|
||||
// Only update MCP fields if explicitly provided (non-zero) to avoid clearing stored values
|
||||
if payload.ClientConfig.MCPAgentDepth > 0 {
|
||||
updatedConfig.MCPAgentDepth = payload.ClientConfig.MCPAgentDepth
|
||||
}
|
||||
if payload.ClientConfig.MCPToolExecutionTimeout > 0 {
|
||||
updatedConfig.MCPToolExecutionTimeout = payload.ClientConfig.MCPToolExecutionTimeout
|
||||
}
|
||||
// Only update MCPCodeModeBindingLevel if payload is non-empty to avoid clearing stored value
|
||||
if payload.ClientConfig.MCPCodeModeBindingLevel != "" {
|
||||
updatedConfig.MCPCodeModeBindingLevel = payload.ClientConfig.MCPCodeModeBindingLevel
|
||||
}
|
||||
|
||||
// Only update AsyncJobResultTTL if explicitly provided (> 0) to avoid clearing stored value
|
||||
if payload.ClientConfig.AsyncJobResultTTL > 0 {
|
||||
updatedConfig.AsyncJobResultTTL = payload.ClientConfig.AsyncJobResultTTL
|
||||
}
|
||||
|
||||
// Handle RequiredHeaders changes (no restart needed - governance plugin reads via pointer)
|
||||
updatedConfig.RequiredHeaders = payload.ClientConfig.RequiredHeaders
|
||||
|
||||
// Handle LoggingHeaders changes (no restart needed - logging plugin reads via pointer)
|
||||
updatedConfig.LoggingHeaders = payload.ClientConfig.LoggingHeaders
|
||||
|
||||
// Handle WhitelistedRoutes changes (updated dynamically via AuthMiddleware)
|
||||
updatedConfig.WhitelistedRoutes = payload.ClientConfig.WhitelistedRoutes
|
||||
|
||||
// Toggle whether deleted virtual keys should appear in logs filter data.
|
||||
updatedConfig.HideDeletedVirtualKeysInFilters = payload.ClientConfig.HideDeletedVirtualKeysInFilters
|
||||
|
||||
// No restart needed - routing engine reads via pointer, change is effective immediately.
|
||||
if payload.ClientConfig.RoutingChainMaxDepth > 0 {
|
||||
updatedConfig.RoutingChainMaxDepth = payload.ClientConfig.RoutingChainMaxDepth
|
||||
}
|
||||
|
||||
// Handle HeaderFilterConfig changes
|
||||
if !headerFilterConfigEqual(payload.ClientConfig.HeaderFilterConfig, currentConfig.HeaderFilterConfig) {
|
||||
// Validate that no security headers are in the allowlist or denylist
|
||||
if err := validateHeaderFilterConfig(payload.ClientConfig.HeaderFilterConfig); err != nil {
|
||||
logger.Warn("invalid header filter config: %v", err)
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
updatedConfig.HeaderFilterConfig = payload.ClientConfig.HeaderFilterConfig
|
||||
if err := h.configManager.ReloadHeaderFilterConfig(ctx, payload.ClientConfig.HeaderFilterConfig); err != nil {
|
||||
logger.Warn("failed to reload header filter config: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to reload header filter config: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Validate LogRetentionDays
|
||||
if payload.ClientConfig.LogRetentionDays < 1 {
|
||||
logger.Warn("log_retention_days must be at least 1")
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "log_retention_days must be at least 1")
|
||||
return
|
||||
}
|
||||
updatedConfig.LogRetentionDays = payload.ClientConfig.LogRetentionDays
|
||||
|
||||
// Update the store with the new config
|
||||
h.store.ClientConfig = updatedConfig
|
||||
|
||||
if err := h.store.ConfigStore.UpdateClientConfig(ctx, updatedConfig); err != nil {
|
||||
logger.Warn("failed to save configuration: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to save configuration: %v", err))
|
||||
return
|
||||
}
|
||||
// Reloading client config from config store
|
||||
if err := h.configManager.ReloadClientConfigFromConfigStore(ctx); err != nil {
|
||||
logger.Warn("failed to reload client config from config store: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to reload client config from config store: %v", err))
|
||||
return
|
||||
}
|
||||
// Fetching existing framework config
|
||||
frameworkConfig, err := h.store.ConfigStore.GetFrameworkConfig(ctx)
|
||||
if err != nil {
|
||||
logger.Warn("failed to get framework config from store: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get framework config from store: %v", err))
|
||||
return
|
||||
}
|
||||
// if framework config is nil, we will use the default pricing config
|
||||
if frameworkConfig == nil {
|
||||
frameworkConfig = &configstoreTables.TableFrameworkConfig{
|
||||
ID: 0,
|
||||
PricingURL: bifrost.Ptr(modelcatalog.DefaultPricingURL),
|
||||
PricingSyncInterval: bifrost.Ptr(int64(modelcatalog.DefaultSyncInterval.Seconds())),
|
||||
}
|
||||
}
|
||||
// Handling individual nil cases
|
||||
if frameworkConfig.PricingURL == nil {
|
||||
frameworkConfig.PricingURL = bifrost.Ptr(modelcatalog.DefaultPricingURL)
|
||||
}
|
||||
if frameworkConfig.PricingSyncInterval == nil {
|
||||
frameworkConfig.PricingSyncInterval = bifrost.Ptr(int64(modelcatalog.DefaultSyncInterval.Seconds()))
|
||||
}
|
||||
// Updating framework config
|
||||
shouldReloadFrameworkConfig := false
|
||||
if payload.FrameworkConfig.PricingURL != nil && *payload.FrameworkConfig.PricingURL != *frameworkConfig.PricingURL {
|
||||
// Checking the accessibility of the pricing URL
|
||||
resp, err := http.Get(*payload.FrameworkConfig.PricingURL)
|
||||
if err != nil {
|
||||
logger.Warn("failed to check the accessibility of the pricing URL: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to check the accessibility of the pricing URL: %v", err))
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.Warn("failed to check the accessibility of the pricing URL: %v", resp.StatusCode)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to check the accessibility of the pricing URL: %v", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
frameworkConfig.PricingURL = payload.FrameworkConfig.PricingURL
|
||||
shouldReloadFrameworkConfig = true
|
||||
}
|
||||
if payload.FrameworkConfig.PricingSyncInterval != nil {
|
||||
syncInterval := int64(*payload.FrameworkConfig.PricingSyncInterval)
|
||||
if syncInterval != *frameworkConfig.PricingSyncInterval {
|
||||
frameworkConfig.PricingSyncInterval = &syncInterval
|
||||
shouldReloadFrameworkConfig = true
|
||||
}
|
||||
}
|
||||
// Reload config if required
|
||||
if shouldReloadFrameworkConfig {
|
||||
var syncSeconds int64
|
||||
if frameworkConfig.PricingSyncInterval != nil {
|
||||
syncSeconds = *frameworkConfig.PricingSyncInterval
|
||||
} else {
|
||||
syncSeconds = int64(modelcatalog.DefaultSyncInterval.Seconds())
|
||||
}
|
||||
h.store.FrameworkConfig = &framework.FrameworkConfig{
|
||||
Pricing: &modelcatalog.Config{
|
||||
PricingURL: frameworkConfig.PricingURL,
|
||||
PricingSyncInterval: &syncSeconds,
|
||||
},
|
||||
}
|
||||
// Saving framework config
|
||||
if err := h.store.ConfigStore.UpdateFrameworkConfig(ctx, frameworkConfig); err != nil {
|
||||
logger.Warn("failed to save framework configuration: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to save framework configuration: %v", err))
|
||||
return
|
||||
}
|
||||
// Reloading pricing manager
|
||||
h.configManager.UpdateSyncConfig(ctx)
|
||||
}
|
||||
// Checking auth config and trying to update if required
|
||||
if payload.AuthConfig != nil {
|
||||
// Getting current governance config
|
||||
authConfig, err := h.store.ConfigStore.GetAuthConfig(ctx)
|
||||
if err != nil {
|
||||
if !errors.Is(err, configstore.ErrNotFound) {
|
||||
logger.Warn("failed to get auth config from store: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get auth config from store: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check if auth config has changed
|
||||
authChanged := false
|
||||
if authConfig == nil {
|
||||
// No existing config, any enabled state is a change
|
||||
if payload.AuthConfig.IsEnabled {
|
||||
authChanged = true
|
||||
}
|
||||
} else {
|
||||
// Compare with existing config using value comparison (not pointer comparison)
|
||||
// Password is considered changed only if it's NOT redacted and has a value
|
||||
// (IsRedacted() returns true for <redacted>, asterisk patterns, and env var references)
|
||||
passwordChanged := payload.AuthConfig.AdminPassword != nil &&
|
||||
!payload.AuthConfig.AdminPassword.IsRedacted() &&
|
||||
payload.AuthConfig.AdminPassword.GetValue() != ""
|
||||
usernameChanged := payload.AuthConfig.AdminUserName != nil &&
|
||||
!payload.AuthConfig.AdminUserName.Equals(authConfig.AdminUserName)
|
||||
if payload.AuthConfig.IsEnabled != authConfig.IsEnabled ||
|
||||
usernameChanged ||
|
||||
passwordChanged {
|
||||
authChanged = true
|
||||
}
|
||||
}
|
||||
|
||||
if payload.AuthConfig.IsEnabled {
|
||||
// Initialize nil pointers to empty EnvVar to prevent nil-pointer dereference
|
||||
if payload.AuthConfig.AdminUserName == nil {
|
||||
payload.AuthConfig.AdminUserName = &schemas.EnvVar{}
|
||||
}
|
||||
if payload.AuthConfig.AdminPassword == nil {
|
||||
payload.AuthConfig.AdminPassword = &schemas.EnvVar{}
|
||||
}
|
||||
|
||||
// Validate env variables are set if referenced
|
||||
if payload.AuthConfig.AdminUserName.IsFromEnv() && payload.AuthConfig.AdminUserName.GetValue() == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("environment variable %s is not set", payload.AuthConfig.AdminUserName.EnvVar))
|
||||
return
|
||||
}
|
||||
if payload.AuthConfig.AdminPassword.IsFromEnv() && payload.AuthConfig.AdminPassword.GetValue() == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("environment variable %s is not set", payload.AuthConfig.AdminPassword.EnvVar))
|
||||
return
|
||||
}
|
||||
|
||||
if authConfig == nil && (payload.AuthConfig.AdminUserName.GetValue() == "" || payload.AuthConfig.AdminPassword.GetValue() == "") {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "auth username and password must be provided")
|
||||
return
|
||||
}
|
||||
// Fetching current Auth config
|
||||
if payload.AuthConfig.AdminUserName.GetValue() != "" {
|
||||
if payload.AuthConfig.AdminPassword.IsRedacted() {
|
||||
if authConfig == nil || authConfig.AdminPassword.GetValue() == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "auth password must be provided")
|
||||
return
|
||||
}
|
||||
// Assuming that password hasn't been changed
|
||||
payload.AuthConfig.AdminPassword = authConfig.AdminPassword
|
||||
} else {
|
||||
// Password has been changed
|
||||
// We will hash the password
|
||||
hashedPassword, err := encrypt.Hash(payload.AuthConfig.AdminPassword.GetValue())
|
||||
if err != nil {
|
||||
logger.Warn("failed to hash password: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to hash password: %v", err))
|
||||
return
|
||||
}
|
||||
// Preserve env-var metadata when storing hashed password
|
||||
payload.AuthConfig.AdminPassword = &schemas.EnvVar{
|
||||
Val: hashedPassword,
|
||||
FromEnv: payload.AuthConfig.AdminPassword.IsFromEnv(),
|
||||
EnvVar: payload.AuthConfig.AdminPassword.EnvVar,
|
||||
}
|
||||
}
|
||||
}
|
||||
// Save auth config - this handles both first-time creation and updates
|
||||
err = h.configManager.UpdateAuthConfig(ctx, payload.AuthConfig)
|
||||
if err != nil {
|
||||
logger.Warn("failed to update auth config: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to update auth config: %v", err))
|
||||
return
|
||||
}
|
||||
} else if authConfig != nil {
|
||||
// Auth is being disabled but there's an existing config - preserve credentials and update disabled state
|
||||
if payload.AuthConfig.AdminPassword == nil || payload.AuthConfig.AdminPassword.IsRedacted() || payload.AuthConfig.AdminPassword.GetValue() == "" {
|
||||
payload.AuthConfig.AdminPassword = authConfig.AdminPassword
|
||||
}
|
||||
if payload.AuthConfig.AdminUserName == nil || payload.AuthConfig.AdminUserName.GetValue() == "" {
|
||||
payload.AuthConfig.AdminUserName = authConfig.AdminUserName
|
||||
}
|
||||
err = h.configManager.UpdateAuthConfig(ctx, payload.AuthConfig)
|
||||
if err != nil {
|
||||
logger.Warn("failed to update auth config: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to update auth config: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Flush all existing sessions if auth details have been changed
|
||||
if authChanged {
|
||||
if err := h.store.ConfigStore.FlushSessions(ctx); err != nil {
|
||||
logger.Warn("updated auth config but failed to flush existing sessions, please restart the server: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("updated auth config but failed to flush existing sessions, please restart the server: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
// Note: AuthMiddleware is updated via ServerCallbacks.UpdateAuthConfig (handled by BifrostHTTPServer)
|
||||
}
|
||||
|
||||
// Set restart required flag if any restart-requiring configs changed
|
||||
if len(restartReasons) > 0 {
|
||||
reason := fmt.Sprintf("%s settings have been updated. A restart is required for changes to take full effect.", strings.Join(restartReasons, ", "))
|
||||
if err := h.store.ConfigStore.SetRestartRequiredConfig(ctx, &configstoreTables.RestartRequiredConfig{
|
||||
Required: true,
|
||||
Reason: reason,
|
||||
}); err != nil {
|
||||
logger.Warn("failed to set restart required config: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
SendJSON(ctx, map[string]any{
|
||||
"status": "success",
|
||||
"message": "configuration updated successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// forceSyncPricing triggers an immediate pricing sync and resets the pricing sync timer
|
||||
func (h *ConfigHandler) forceSyncPricing(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "config store not available")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.configManager.ForceReloadPricing(ctx); err != nil {
|
||||
logger.Warn("failed to force pricing sync: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to force pricing sync: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
SendJSON(ctx, map[string]any{
|
||||
"status": "success",
|
||||
"message": "pricing sync triggered",
|
||||
})
|
||||
}
|
||||
|
||||
// getProxyConfig handles GET /api/proxy-config - Get the current proxy configuration
|
||||
func (h *ConfigHandler) getProxyConfig(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "config store not available")
|
||||
return
|
||||
}
|
||||
proxyConfig, err := h.store.ConfigStore.GetProxyConfig(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get proxy config: %v", err))
|
||||
return
|
||||
}
|
||||
if proxyConfig == nil {
|
||||
// Return default empty config
|
||||
SendJSON(ctx, configstoreTables.GlobalProxyConfig{
|
||||
Enabled: false,
|
||||
Type: network.GlobalProxyTypeHTTP,
|
||||
})
|
||||
return
|
||||
}
|
||||
// Redact password if present
|
||||
if proxyConfig.Password != "" {
|
||||
proxyConfig.Password = "<redacted>"
|
||||
}
|
||||
SendJSON(ctx, proxyConfig)
|
||||
}
|
||||
|
||||
// updateProxyConfig handles PUT /api/proxy-config - Update the proxy configuration
|
||||
func (h *ConfigHandler) updateProxyConfig(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "config store not initialized")
|
||||
return
|
||||
}
|
||||
|
||||
var payload configstoreTables.GlobalProxyConfig
|
||||
if err := json.Unmarshal(ctx.PostBody(), &payload); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("invalid request format: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate proxy config
|
||||
if payload.Enabled {
|
||||
// Validate proxy type
|
||||
switch payload.Type {
|
||||
case network.GlobalProxyTypeHTTP:
|
||||
// HTTP proxy is supported
|
||||
// Make sure the URL is provided
|
||||
if payload.URL == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "proxy URL is required when proxy is enabled")
|
||||
return
|
||||
}
|
||||
// Validate timeout if provided
|
||||
if payload.Timeout < 0 {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "proxy timeout must be non-negative")
|
||||
return
|
||||
}
|
||||
case network.GlobalProxyTypeSOCKS5, network.GlobalProxyTypeTCP:
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("proxy type %s is not yet supported", payload.Type))
|
||||
return
|
||||
default:
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("invalid proxy type: %s", payload.Type))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate URL is provided when enabled
|
||||
if payload.URL == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "proxy URL is required when proxy is enabled")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate timeout if provided
|
||||
if payload.Timeout < 0 {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "proxy timeout must be non-negative")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Handle password - if it's "<redacted>", keep the existing password
|
||||
if payload.Password == "<redacted>" {
|
||||
existingConfig, err := h.store.ConfigStore.GetProxyConfig(ctx)
|
||||
if err != nil && !errors.Is(err, configstore.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get existing proxy config: %v", err))
|
||||
return
|
||||
}
|
||||
if existingConfig != nil {
|
||||
payload.Password = existingConfig.Password
|
||||
} else {
|
||||
payload.Password = ""
|
||||
}
|
||||
}
|
||||
|
||||
// Save proxy config
|
||||
if err := h.store.ConfigStore.UpdateProxyConfig(ctx, &payload); err != nil {
|
||||
logger.Warn("failed to save proxy configuration: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to save proxy configuration: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Pulling the proxy config from the config store
|
||||
newProxyConfig, err := h.store.ConfigStore.GetProxyConfig(ctx)
|
||||
if err != nil {
|
||||
logger.Warn("failed to get proxy config from store: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get proxy config from store: %v", err))
|
||||
return
|
||||
}
|
||||
if newProxyConfig == nil {
|
||||
newProxyConfig = &configstoreTables.GlobalProxyConfig{
|
||||
Enabled: false,
|
||||
Type: network.GlobalProxyTypeHTTP,
|
||||
URL: "",
|
||||
Username: "",
|
||||
Password: "",
|
||||
NoProxy: "",
|
||||
Timeout: 0,
|
||||
SkipTLSVerify: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Reload proxy config in the server
|
||||
if err := h.configManager.ReloadProxyConfig(ctx, newProxyConfig); err != nil {
|
||||
logger.Warn("failed to reload proxy config: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to reload proxy config: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Set restart required flag for proxy config changes
|
||||
if err := h.store.ConfigStore.SetRestartRequiredConfig(ctx, &configstoreTables.RestartRequiredConfig{
|
||||
Required: true,
|
||||
Reason: "Proxy configuration has been updated. A restart is required for all changes to take full effect.",
|
||||
}); err != nil {
|
||||
logger.Warn("failed to set restart required config: %v", err)
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
SendJSON(ctx, map[string]any{
|
||||
"status": "success",
|
||||
"message": "proxy configuration updated successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// headerFilterConfigEqual compares two GlobalHeaderFilterConfig for equality
|
||||
func headerFilterConfigEqual(a, b *configstoreTables.GlobalHeaderFilterConfig) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return slices.Equal(a.Allowlist, b.Allowlist) && slices.Equal(a.Denylist, b.Denylist)
|
||||
}
|
||||
|
||||
// validateHeaderFilterConfig validates that no exact security header names are in the allowlist or denylist
|
||||
// and that wildcard patterns use valid syntax (only trailing * is supported).
|
||||
// Wildcard patterns that would match security headers are allowed because security headers
|
||||
// are unconditionally stripped at runtime regardless of configuration.
|
||||
// Returns an error if any exact security headers are found or patterns are invalid.
|
||||
func validateHeaderFilterConfig(config *configstoreTables.GlobalHeaderFilterConfig) error {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate pattern syntax and normalize entries (trim, lowercase, drop empties)
|
||||
filteredAllow := config.Allowlist[:0]
|
||||
for _, header := range config.Allowlist {
|
||||
h := strings.ToLower(strings.TrimSpace(header))
|
||||
if h == "" {
|
||||
continue
|
||||
}
|
||||
if idx := strings.Index(h, "*"); idx != -1 && idx != len(h)-1 {
|
||||
return fmt.Errorf("invalid pattern %q: wildcard (*) is only supported at the end of a pattern", h)
|
||||
}
|
||||
filteredAllow = append(filteredAllow, h)
|
||||
}
|
||||
config.Allowlist = filteredAllow
|
||||
filteredDeny := config.Denylist[:0]
|
||||
for _, header := range config.Denylist {
|
||||
h := strings.ToLower(strings.TrimSpace(header))
|
||||
if h == "" {
|
||||
continue
|
||||
}
|
||||
if idx := strings.Index(h, "*"); idx != -1 && idx != len(h)-1 {
|
||||
return fmt.Errorf("invalid pattern %q: wildcard (*) is only supported at the end of a pattern", h)
|
||||
}
|
||||
filteredDeny = append(filteredDeny, h)
|
||||
}
|
||||
config.Denylist = filteredDeny
|
||||
|
||||
var foundSecurityHeaders []string
|
||||
|
||||
// Check allowlist for exact security header names.
|
||||
// Wildcard patterns are allowed — security headers are always stripped at runtime
|
||||
// unconditionally in ctx.go, regardless of allowlist/denylist configuration.
|
||||
for _, header := range config.Allowlist {
|
||||
headerLower := strings.ToLower(strings.TrimSpace(header))
|
||||
if strings.Contains(headerLower, "*") {
|
||||
continue
|
||||
}
|
||||
if slices.Contains(securityHeaders, headerLower) {
|
||||
foundSecurityHeaders = append(foundSecurityHeaders, headerLower)
|
||||
}
|
||||
}
|
||||
|
||||
// Check denylist for exact security header names.
|
||||
for _, header := range config.Denylist {
|
||||
headerLower := strings.ToLower(strings.TrimSpace(header))
|
||||
if strings.Contains(headerLower, "*") {
|
||||
continue
|
||||
}
|
||||
if slices.Contains(securityHeaders, headerLower) && !slices.Contains(foundSecurityHeaders, headerLower) {
|
||||
foundSecurityHeaders = append(foundSecurityHeaders, headerLower)
|
||||
}
|
||||
}
|
||||
|
||||
if len(foundSecurityHeaders) > 0 {
|
||||
return fmt.Errorf("the following headers are not allowed to be configured: %s. These headers are security headers and are always blocked", strings.Join(foundSecurityHeaders, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
198
transports/bifrost-http/handlers/config_headerfilter_test.go
Normal file
198
transports/bifrost-http/handlers/config_headerfilter_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
)
|
||||
|
||||
func TestValidateHeaderFilterConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *configstoreTables.GlobalHeaderFilterConfig
|
||||
wantErr bool
|
||||
errSubstr string
|
||||
}{
|
||||
{
|
||||
name: "nil config",
|
||||
config: nil,
|
||||
},
|
||||
{
|
||||
name: "empty lists",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{},
|
||||
},
|
||||
{
|
||||
name: "empty allowlist and denylist slices",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{},
|
||||
Denylist: []string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid allowlist patterns",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"anthropic-beta", "x-custom-*", "content-type"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid denylist patterns",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Denylist: []string{"x-internal-*", "x-debug"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid allowlist and denylist together",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"anthropic-*", "content-type"},
|
||||
Denylist: []string{"x-internal-*"},
|
||||
},
|
||||
},
|
||||
// Empty/whitespace entries should be silently dropped, not cause errors
|
||||
{
|
||||
name: "whitespace-only entries in allowlist are dropped",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{" ", "anthropic-beta", ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "whitespace-only entries in denylist are dropped",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Denylist: []string{"", "x-debug", " "},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all-empty allowlist becomes effectively empty",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"", " ", "\t"},
|
||||
},
|
||||
},
|
||||
// Security header checks
|
||||
{
|
||||
name: "security header in allowlist rejected",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"authorization"},
|
||||
},
|
||||
wantErr: true,
|
||||
errSubstr: "not allowed to be configured",
|
||||
},
|
||||
{
|
||||
name: "security header in denylist rejected",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Denylist: []string{"x-api-key"},
|
||||
},
|
||||
wantErr: true,
|
||||
errSubstr: "not allowed to be configured",
|
||||
},
|
||||
{
|
||||
name: "wildcard matching security header allowed (runtime strips security headers)",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"authorization*"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "wildcard prefix matching security headers allowed (runtime strips security headers)",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"x-api-*"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bare wildcard in allowlist allowed (runtime strips security headers)",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"*"},
|
||||
},
|
||||
},
|
||||
// Invalid wildcard syntax
|
||||
{
|
||||
name: "wildcard in middle of pattern rejected",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"x-*-header"},
|
||||
},
|
||||
wantErr: true,
|
||||
errSubstr: "wildcard (*) is only supported at the end",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateHeaderFilterConfig(tt.config)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error containing %q, got nil", tt.errSubstr)
|
||||
}
|
||||
if tt.errSubstr != "" && !contains(err.Error(), tt.errSubstr) {
|
||||
t.Fatalf("expected error containing %q, got %q", tt.errSubstr, err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateHeaderFilterConfig_EmptyEntriesDropped(t *testing.T) {
|
||||
// Verify that empty/whitespace entries are actually removed from the stored config
|
||||
config := &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{" ", "anthropic-beta", "", "content-type", "\t"},
|
||||
Denylist: []string{"", "x-debug", " "},
|
||||
}
|
||||
if err := validateHeaderFilterConfig(config); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(config.Allowlist) != 2 {
|
||||
t.Fatalf("expected allowlist length 2, got %d: %v", len(config.Allowlist), config.Allowlist)
|
||||
}
|
||||
if config.Allowlist[0] != "anthropic-beta" || config.Allowlist[1] != "content-type" {
|
||||
t.Fatalf("unexpected allowlist: %v", config.Allowlist)
|
||||
}
|
||||
if len(config.Denylist) != 1 {
|
||||
t.Fatalf("expected denylist length 1, got %d: %v", len(config.Denylist), config.Denylist)
|
||||
}
|
||||
if config.Denylist[0] != "x-debug" {
|
||||
t.Fatalf("unexpected denylist: %v", config.Denylist)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateHeaderFilterConfig_EmptyConfigStillForwardsHeaders verifies that when
|
||||
// all entries are empty/whitespace, validation strips them and the compiled matcher
|
||||
// allows all headers through (same behavior as no config — x-bf-eh-* headers forwarded as-is).
|
||||
func TestValidateHeaderFilterConfig_EmptyConfigStillForwardsHeaders(t *testing.T) {
|
||||
// Config where all entries are whitespace-only
|
||||
config := &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"", " ", "\t"},
|
||||
Denylist: []string{"", " "},
|
||||
}
|
||||
if err := validateHeaderFilterConfig(config); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// After validation, both lists should be empty
|
||||
if len(config.Allowlist) != 0 {
|
||||
t.Fatalf("expected empty allowlist, got %v", config.Allowlist)
|
||||
}
|
||||
if len(config.Denylist) != 0 {
|
||||
t.Fatalf("expected empty denylist, got %v", config.Denylist)
|
||||
}
|
||||
// Compile the validated config into a matcher — should allow everything
|
||||
m := lib.NewHeaderMatcher(config)
|
||||
// Matcher with empty lists should allow all headers (x-bf-eh-* forwarded as-is)
|
||||
for _, header := range []string{"anthropic-beta", "x-custom-header", "content-type", "x-anything"} {
|
||||
if !m.ShouldAllow(header) {
|
||||
t.Errorf("expected header %q to be allowed with empty config, but it was denied", header)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && searchString(s, substr)
|
||||
}
|
||||
|
||||
func searchString(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
778
transports/bifrost-http/handlers/devpprof.go
Normal file
778
transports/bifrost-http/handlers/devpprof.go
Normal file
@@ -0,0 +1,778 @@
|
||||
//go:build dev
|
||||
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/google/pprof/profile"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
const (
|
||||
// Collection interval for metrics
|
||||
metricsCollectionInterval = 10 * time.Second
|
||||
// Number of data points to keep (5 minutes / 10 seconds = 30 points)
|
||||
historySize = 30
|
||||
// Top allocations to return per table (cumulative and in-use)
|
||||
topAllocationsCount = 50
|
||||
)
|
||||
|
||||
// MemoryStats represents memory statistics at a point in time
|
||||
type MemoryStats struct {
|
||||
Alloc uint64 `json:"alloc"`
|
||||
TotalAlloc uint64 `json:"total_alloc"`
|
||||
HeapInuse uint64 `json:"heap_inuse"`
|
||||
HeapObjects uint64 `json:"heap_objects"`
|
||||
Sys uint64 `json:"sys"`
|
||||
}
|
||||
|
||||
// CPUStats represents CPU statistics
|
||||
type CPUStats struct {
|
||||
UsagePercent float64 `json:"usage_percent"`
|
||||
UserTime float64 `json:"user_time"`
|
||||
SystemTime float64 `json:"system_time"`
|
||||
}
|
||||
|
||||
// RuntimeStats represents runtime statistics
|
||||
type RuntimeStats struct {
|
||||
NumGoroutine int `json:"num_goroutine"`
|
||||
NumGC uint32 `json:"num_gc"`
|
||||
GCPauseNs uint64 `json:"gc_pause_ns"`
|
||||
NumCPU int `json:"num_cpu"`
|
||||
GOMAXPROCS int `json:"gomaxprocs"`
|
||||
}
|
||||
|
||||
// AllocationInfo represents a single allocation site
|
||||
type AllocationInfo struct {
|
||||
Function string `json:"function"`
|
||||
File string `json:"file"`
|
||||
Line int `json:"line"`
|
||||
Bytes int64 `json:"bytes"`
|
||||
Count int64 `json:"count"`
|
||||
Stack []string `json:"stack"`
|
||||
}
|
||||
|
||||
// GoroutineGroup represents a group of goroutines with the same stack trace
|
||||
type GoroutineGroup struct {
|
||||
Count int `json:"count"`
|
||||
State string `json:"state"`
|
||||
WaitReason string `json:"wait_reason,omitempty"`
|
||||
WaitMinutes int `json:"wait_minutes,omitempty"` // Parsed wait time in minutes
|
||||
TopFunc string `json:"top_func"`
|
||||
Stack []string `json:"stack"`
|
||||
Category string `json:"category"` // "background", "per-request", "unknown"
|
||||
}
|
||||
|
||||
// GoroutineProfile represents the goroutine profile response
|
||||
type GoroutineProfile struct {
|
||||
Timestamp string `json:"timestamp"`
|
||||
TotalGoroutines int `json:"total_goroutines"`
|
||||
Groups []GoroutineGroup `json:"groups"`
|
||||
Summary GoroutineSummary `json:"summary"`
|
||||
RawProfile string `json:"raw_profile,omitempty"`
|
||||
}
|
||||
|
||||
// GoroutineSummary provides a quick overview of goroutine health
|
||||
type GoroutineSummary struct {
|
||||
Background int `json:"background"` // Expected long-running goroutines
|
||||
PerRequest int `json:"per_request"` // Goroutines that should complete with requests
|
||||
LongWaiting int `json:"long_waiting"` // Goroutines waiting > 1 minute (potential leaks)
|
||||
PotentiallyStuck int `json:"potentially_stuck"` // Per-request goroutines waiting > 1 minute
|
||||
}
|
||||
|
||||
// HistoryPoint represents a single point in the metrics history
|
||||
type HistoryPoint struct {
|
||||
Timestamp string `json:"timestamp"`
|
||||
Alloc uint64 `json:"alloc"`
|
||||
HeapInuse uint64 `json:"heap_inuse"`
|
||||
Goroutines int `json:"goroutines"`
|
||||
GCPauseNs uint64 `json:"gc_pause_ns"`
|
||||
CPUPercent float64 `json:"cpu_percent"`
|
||||
}
|
||||
|
||||
// PprofData represents the complete pprof response
|
||||
type PprofData struct {
|
||||
Timestamp string `json:"timestamp"`
|
||||
Memory MemoryStats `json:"memory"`
|
||||
CPU CPUStats `json:"cpu"`
|
||||
Runtime RuntimeStats `json:"runtime"`
|
||||
TopAllocations []AllocationInfo `json:"top_allocations"`
|
||||
InuseAllocations []AllocationInfo `json:"inuse_allocations"`
|
||||
History []HistoryPoint `json:"history"`
|
||||
}
|
||||
|
||||
// cpuSample holds a CPU time sample for calculating usage
|
||||
type cpuSample struct {
|
||||
timestamp time.Time
|
||||
userTime time.Duration
|
||||
systemTime time.Duration
|
||||
}
|
||||
|
||||
// MetricsCollector collects and stores runtime metrics
|
||||
type MetricsCollector struct {
|
||||
mu sync.RWMutex
|
||||
history []HistoryPoint
|
||||
stopCh chan struct{}
|
||||
started bool
|
||||
lastCPUSample cpuSample
|
||||
currentCPU CPUStats
|
||||
}
|
||||
|
||||
// DevPprofHandler handles development profiling endpoints
|
||||
type DevPprofHandler struct {
|
||||
collector *MetricsCollector
|
||||
}
|
||||
|
||||
// Global collector instance
|
||||
var globalCollector *MetricsCollector
|
||||
var collectorOnce sync.Once
|
||||
|
||||
// IsDevMode checks if dev mode is enabled via environment variable
|
||||
func IsDevMode() bool {
|
||||
return os.Getenv("BIFROST_UI_DEV") == "true"
|
||||
}
|
||||
|
||||
// getOrCreateCollector returns the global metrics collector, creating it if needed
|
||||
func getOrCreateCollector() *MetricsCollector {
|
||||
collectorOnce.Do(func() {
|
||||
globalCollector = &MetricsCollector{
|
||||
history: make([]HistoryPoint, 0, historySize),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
})
|
||||
return globalCollector
|
||||
}
|
||||
|
||||
// NewDevPprofHandler creates a new dev pprof handler
|
||||
func NewDevPprofHandler() *DevPprofHandler {
|
||||
return &DevPprofHandler{
|
||||
collector: getOrCreateCollector(),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the background metrics collection
|
||||
func (c *MetricsCollector) Start() {
|
||||
c.mu.Lock()
|
||||
if c.started {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
c.stopCh = make(chan struct{})
|
||||
c.started = true
|
||||
c.mu.Unlock()
|
||||
|
||||
go c.collectLoop()
|
||||
}
|
||||
|
||||
// Stop stops the background metrics collection
|
||||
func (c *MetricsCollector) Stop() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if !c.started {
|
||||
return
|
||||
}
|
||||
close(c.stopCh)
|
||||
c.stopCh = nil
|
||||
c.started = false
|
||||
}
|
||||
|
||||
func (c *MetricsCollector) collectLoop() {
|
||||
// Initialize CPU sample
|
||||
c.lastCPUSample = getCPUSample()
|
||||
|
||||
// Wait a bit before first collection to get accurate CPU reading
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Collect immediately on start
|
||||
c.collect()
|
||||
|
||||
ticker := time.NewTicker(metricsCollectionInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
c.collect()
|
||||
case <-c.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// calculateCPUUsage calculates CPU usage percentage between two samples
|
||||
func calculateCPUUsage(prev, curr cpuSample, numCPU int) CPUStats {
|
||||
elapsed := curr.timestamp.Sub(prev.timestamp)
|
||||
if elapsed <= 0 {
|
||||
return CPUStats{}
|
||||
}
|
||||
|
||||
userDelta := curr.userTime - prev.userTime
|
||||
systemDelta := curr.systemTime - prev.systemTime
|
||||
totalCPUTime := userDelta + systemDelta
|
||||
|
||||
// Calculate percentage: (CPU time used / wall time) * 100
|
||||
// Normalized by number of CPUs to get 0-100% range
|
||||
cpuPercent := (float64(totalCPUTime) / float64(elapsed)) * 100.0
|
||||
|
||||
// Cap at 100% * numCPU (in case of measurement errors)
|
||||
maxPercent := float64(numCPU) * 100.0
|
||||
if cpuPercent > maxPercent {
|
||||
cpuPercent = maxPercent
|
||||
}
|
||||
|
||||
return CPUStats{
|
||||
UsagePercent: cpuPercent,
|
||||
UserTime: userDelta.Seconds(),
|
||||
SystemTime: systemDelta.Seconds(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MetricsCollector) collect() {
|
||||
var memStats runtime.MemStats
|
||||
runtime.ReadMemStats(&memStats)
|
||||
|
||||
// Get current CPU sample and calculate usage
|
||||
currentSample := getCPUSample()
|
||||
cpuStats := calculateCPUUsage(c.lastCPUSample, currentSample, runtime.NumCPU())
|
||||
c.lastCPUSample = currentSample
|
||||
|
||||
point := HistoryPoint{
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
Alloc: memStats.Alloc,
|
||||
HeapInuse: memStats.HeapInuse,
|
||||
Goroutines: runtime.NumGoroutine(),
|
||||
GCPauseNs: memStats.PauseNs[(memStats.NumGC+255)%256],
|
||||
CPUPercent: cpuStats.UsagePercent,
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Store current CPU stats for API response
|
||||
c.currentCPU = cpuStats
|
||||
|
||||
// Append to history, maintaining ring buffer behavior
|
||||
if len(c.history) >= historySize {
|
||||
// Shift left by one and append
|
||||
copy(c.history, c.history[1:])
|
||||
c.history[len(c.history)-1] = point
|
||||
} else {
|
||||
c.history = append(c.history, point)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MetricsCollector) getHistory() []HistoryPoint {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
// Return a copy to avoid race conditions
|
||||
result := make([]HistoryPoint, len(c.history))
|
||||
copy(result, c.history)
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *MetricsCollector) getCPUStats() CPUStats {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.currentCPU
|
||||
}
|
||||
|
||||
// getAllocations analyzes the heap profile and returns two allocation lists
|
||||
// aggregated by full call stack:
|
||||
// - cumulative: alloc_space / alloc_objects (total since process start)
|
||||
// - inuse: inuse_space / inuse_objects (currently live on the heap)
|
||||
//
|
||||
// Both are produced from a single pprof.WriteHeapProfile call.
|
||||
func getAllocations() (cumulative, inuse []AllocationInfo) {
|
||||
var buf bytes.Buffer
|
||||
if err := pprof.WriteHeapProfile(&buf); err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
p, err := profile.Parse(&buf)
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
allocObjectsIdx, allocSpaceIdx := -1, -1
|
||||
inuseObjectsIdx, inuseSpaceIdx := -1, -1
|
||||
for i, st := range p.SampleType {
|
||||
switch st.Type {
|
||||
case "alloc_objects":
|
||||
allocObjectsIdx = i
|
||||
case "alloc_space":
|
||||
allocSpaceIdx = i
|
||||
case "inuse_objects":
|
||||
inuseObjectsIdx = i
|
||||
case "inuse_space":
|
||||
inuseSpaceIdx = i
|
||||
}
|
||||
}
|
||||
|
||||
allocMap := make(map[string]*AllocationInfo)
|
||||
inuseMap := make(map[string]*AllocationInfo)
|
||||
|
||||
for _, sample := range p.Sample {
|
||||
if len(sample.Location) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
topLoc := sample.Location[0]
|
||||
if len(topLoc.Line) == 0 {
|
||||
continue
|
||||
}
|
||||
topLine := topLoc.Line[0]
|
||||
topFn := topLine.Function
|
||||
if topFn == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Filter only the top frame — filtering inner frames would drop real
|
||||
// user allocations that merely pass through runtime/profiler code.
|
||||
if isProfilerFunction(topFn.Name, topFn.Filename) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Build full stack in goroutine-dump format: alternating "funcName" and
|
||||
// "\tfile:line" entries, top-down. Matches GoroutineGroup.Stack so the
|
||||
// UI can render both with the same code path.
|
||||
stack := make([]string, 0, len(sample.Location)*2)
|
||||
for _, loc := range sample.Location {
|
||||
if len(loc.Line) == 0 {
|
||||
continue
|
||||
}
|
||||
frame := loc.Line[0]
|
||||
if frame.Function == nil {
|
||||
continue
|
||||
}
|
||||
stack = append(stack, frame.Function.Name)
|
||||
stack = append(stack, "\t"+frame.Function.Filename+":"+strconv.FormatInt(frame.Line, 10))
|
||||
}
|
||||
if len(stack) == 0 {
|
||||
continue
|
||||
}
|
||||
key := strings.Join(stack, "\n")
|
||||
|
||||
if allocSpaceIdx >= 0 && allocObjectsIdx >= 0 {
|
||||
b := sample.Value[allocSpaceIdx]
|
||||
c := sample.Value[allocObjectsIdx]
|
||||
if existing, ok := allocMap[key]; ok {
|
||||
existing.Bytes += b
|
||||
existing.Count += c
|
||||
} else {
|
||||
allocMap[key] = &AllocationInfo{
|
||||
Function: topFn.Name,
|
||||
File: topFn.Filename,
|
||||
Line: int(topLine.Line),
|
||||
Bytes: b,
|
||||
Count: c,
|
||||
Stack: stack,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if inuseSpaceIdx >= 0 && inuseObjectsIdx >= 0 {
|
||||
b := sample.Value[inuseSpaceIdx]
|
||||
c := sample.Value[inuseObjectsIdx]
|
||||
// Most samples have inuse=0 (already freed) — skip them so the live
|
||||
// table isn't padded with noise.
|
||||
if b == 0 && c == 0 {
|
||||
continue
|
||||
}
|
||||
if existing, ok := inuseMap[key]; ok {
|
||||
existing.Bytes += b
|
||||
existing.Count += c
|
||||
} else {
|
||||
inuseMap[key] = &AllocationInfo{
|
||||
Function: topFn.Name,
|
||||
File: topFn.Filename,
|
||||
Line: int(topLine.Line),
|
||||
Bytes: b,
|
||||
Count: c,
|
||||
Stack: stack,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return flattenAndTopN(allocMap), flattenAndTopN(inuseMap)
|
||||
}
|
||||
|
||||
// flattenAndTopN sorts an allocation map by bytes desc and caps it.
|
||||
func flattenAndTopN(m map[string]*AllocationInfo) []AllocationInfo {
|
||||
out := make([]AllocationInfo, 0, len(m))
|
||||
for _, a := range m {
|
||||
out = append(out, *a)
|
||||
}
|
||||
sort.Slice(out, func(i, j int) bool { return out[i].Bytes > out[j].Bytes })
|
||||
if len(out) > topAllocationsCount {
|
||||
out = out[:topAllocationsCount]
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the dev pprof routes
|
||||
func (h *DevPprofHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
// Start the collector when routes are registered
|
||||
h.collector.Start()
|
||||
|
||||
r.GET("/api/dev/pprof", lib.ChainMiddlewares(h.getPprof, middlewares...))
|
||||
r.GET("/api/dev/pprof/goroutines", lib.ChainMiddlewares(h.getGoroutines, middlewares...))
|
||||
}
|
||||
|
||||
// getPprof handles GET /api/dev/pprof
|
||||
func (h *DevPprofHandler) getPprof(ctx *fasthttp.RequestCtx) {
|
||||
var memStats runtime.MemStats
|
||||
runtime.ReadMemStats(&memStats)
|
||||
|
||||
data := PprofData{
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
Memory: MemoryStats{
|
||||
Alloc: memStats.Alloc,
|
||||
TotalAlloc: memStats.TotalAlloc,
|
||||
HeapInuse: memStats.HeapInuse,
|
||||
HeapObjects: memStats.HeapObjects,
|
||||
Sys: memStats.Sys,
|
||||
},
|
||||
CPU: h.collector.getCPUStats(),
|
||||
Runtime: RuntimeStats{
|
||||
NumGoroutine: runtime.NumGoroutine(),
|
||||
NumGC: memStats.NumGC,
|
||||
GCPauseNs: memStats.PauseNs[(memStats.NumGC+255)%256],
|
||||
NumCPU: runtime.NumCPU(),
|
||||
GOMAXPROCS: runtime.GOMAXPROCS(0),
|
||||
},
|
||||
History: h.collector.getHistory(),
|
||||
}
|
||||
data.TopAllocations, data.InuseAllocations = getAllocations()
|
||||
|
||||
SendJSON(ctx, data)
|
||||
}
|
||||
|
||||
// getGoroutines handles GET /api/dev/pprof/goroutines
|
||||
// Returns goroutine stack traces grouped by stack signature
|
||||
func (h *DevPprofHandler) getGoroutines(ctx *fasthttp.RequestCtx) {
|
||||
// Check if raw output is requested
|
||||
includeRaw := string(ctx.QueryArgs().Peek("raw")) == "true"
|
||||
|
||||
// Get goroutine profile
|
||||
var buf bytes.Buffer
|
||||
if err := pprof.Lookup("goroutine").WriteTo(&buf, 2); err != nil {
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
SendJSON(ctx, map[string]string{"error": "failed to get goroutine profile"})
|
||||
return
|
||||
}
|
||||
|
||||
rawProfile := buf.String()
|
||||
allGroups := parseGoroutineProfile(rawProfile)
|
||||
|
||||
// Filter out profiler goroutines and calculate summary
|
||||
groups := make([]GoroutineGroup, 0, len(allGroups))
|
||||
summary := GoroutineSummary{}
|
||||
profilerGoroutineCount := 0
|
||||
|
||||
for i := range allGroups {
|
||||
categorizeGoroutine(&allGroups[i])
|
||||
|
||||
// Skip profiler's own goroutines
|
||||
if isProfilerGoroutine(&allGroups[i]) {
|
||||
profilerGoroutineCount += allGroups[i].Count
|
||||
continue
|
||||
}
|
||||
|
||||
groups = append(groups, allGroups[i])
|
||||
|
||||
switch allGroups[i].Category {
|
||||
case "background":
|
||||
summary.Background += allGroups[i].Count
|
||||
case "per-request":
|
||||
summary.PerRequest += allGroups[i].Count
|
||||
}
|
||||
|
||||
if allGroups[i].WaitMinutes >= 1 {
|
||||
summary.LongWaiting += allGroups[i].Count
|
||||
if allGroups[i].Category == "per-request" {
|
||||
summary.PotentiallyStuck += allGroups[i].Count
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort: potentially stuck first, then by wait time, then by count
|
||||
sort.Slice(groups, func(i, j int) bool {
|
||||
// Potentially stuck (per-request + long wait) first
|
||||
iStuck := groups[i].Category == "per-request" && groups[i].WaitMinutes >= 1
|
||||
jStuck := groups[j].Category == "per-request" && groups[j].WaitMinutes >= 1
|
||||
if iStuck != jStuck {
|
||||
return iStuck
|
||||
}
|
||||
// Then by wait time
|
||||
if groups[i].WaitMinutes != groups[j].WaitMinutes {
|
||||
return groups[i].WaitMinutes > groups[j].WaitMinutes
|
||||
}
|
||||
// Then by count
|
||||
return groups[i].Count > groups[j].Count
|
||||
})
|
||||
|
||||
// Calculate app goroutines (total minus profiler goroutines)
|
||||
// Calculate total goroutines from profile snapshot
|
||||
totalFromProfile := 0
|
||||
for _, g := range groups {
|
||||
totalFromProfile += g.Count
|
||||
}
|
||||
|
||||
response := GoroutineProfile{
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
TotalGoroutines: totalFromProfile,
|
||||
Groups: groups,
|
||||
Summary: summary,
|
||||
}
|
||||
|
||||
if includeRaw {
|
||||
response.RawProfile = rawProfile
|
||||
}
|
||||
|
||||
SendJSON(ctx, response)
|
||||
}
|
||||
|
||||
// categorizeGoroutine determines if a goroutine is a background worker or per-request
|
||||
func categorizeGoroutine(g *GoroutineGroup) {
|
||||
// Parse wait time from wait reason (e.g., "5 minutes" -> 5)
|
||||
g.WaitMinutes = parseWaitMinutes(g.WaitReason)
|
||||
|
||||
stackStr := strings.Join(g.Stack, " ")
|
||||
|
||||
// Background goroutines - expected to run forever
|
||||
backgroundPatterns := []string{
|
||||
"requestWorker", // Provider queue workers
|
||||
"collectLoop", // Metrics collector
|
||||
"cleanupWorker", // Various cleanup workers
|
||||
"startAccumulatorMapCleanup", // Stream accumulator cleanup
|
||||
"cleanupOldTraces", // Trace store cleanup
|
||||
"startCleanup", // Generic cleanup
|
||||
"monitorLoop", // Health monitor
|
||||
"StartHeartbeat", // WebSocket heartbeat
|
||||
"time.Sleep", // Ticker-based workers
|
||||
"runtime.gopark", // Runtime parking (often tickers)
|
||||
"sync.(*Cond).Wait", // Condition variable waits
|
||||
"net/http.(*persistConn)", // HTTP connection pool
|
||||
"internal/poll.runtime_pollWait", // Network polling
|
||||
}
|
||||
|
||||
for _, pattern := range backgroundPatterns {
|
||||
if strings.Contains(stackStr, pattern) {
|
||||
g.Category = "background"
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Per-request goroutines - should complete when request ends
|
||||
perRequestPatterns := []string{
|
||||
"PreLLMHook",
|
||||
"PostLLMHook",
|
||||
"PreMCPHook",
|
||||
"PostMCPHook",
|
||||
"HTTPTransportPreHook",
|
||||
"HTTPTransportPostHook",
|
||||
"CompleteAndFlushTrace",
|
||||
"ProcessAndSend",
|
||||
"handleProvider",
|
||||
"Inject", // Observability plugin inject
|
||||
"insertInitialLogEntry", // Logging
|
||||
"updateLogEntry", // Logging
|
||||
"retryOnNotFound",
|
||||
"BroadcastLogUpdate",
|
||||
}
|
||||
|
||||
for _, pattern := range perRequestPatterns {
|
||||
if strings.Contains(stackStr, pattern) {
|
||||
g.Category = "per-request"
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
g.Category = "unknown"
|
||||
}
|
||||
|
||||
// parseWaitMinutes extracts wait time in minutes from wait reason string
|
||||
func parseWaitMinutes(waitReason string) int {
|
||||
if waitReason == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Match patterns like "5 minutes", "1 minute", "30 seconds", "2 hours"
|
||||
minuteRegex := regexp.MustCompile(`(\d+)\s*minute`)
|
||||
if matches := minuteRegex.FindStringSubmatch(waitReason); len(matches) >= 2 {
|
||||
if mins, err := strconv.Atoi(matches[1]); err == nil {
|
||||
return mins
|
||||
}
|
||||
}
|
||||
|
||||
hourRegex := regexp.MustCompile(`(\d+)\s*hour`)
|
||||
if matches := hourRegex.FindStringSubmatch(waitReason); len(matches) >= 2 {
|
||||
if hours, err := strconv.Atoi(matches[1]); err == nil {
|
||||
return hours * 60
|
||||
}
|
||||
}
|
||||
|
||||
secondRegex := regexp.MustCompile(`(\d+)\s*second`)
|
||||
if matches := secondRegex.FindStringSubmatch(waitReason); len(matches) >= 2 {
|
||||
if secs, err := strconv.Atoi(matches[1]); err == nil {
|
||||
return secs / 60 // Convert to minutes, will be 0 for < 60 seconds
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// parseGoroutineProfile parses the text output of pprof goroutine profile
|
||||
// and groups goroutines by their stack trace
|
||||
func parseGoroutineProfile(profile string) []GoroutineGroup {
|
||||
// Regex to match goroutine header: "goroutine N [state, wait reason]:"
|
||||
// Examples:
|
||||
// goroutine 1 [running]:
|
||||
// goroutine 42 [select, 5 minutes]:
|
||||
// goroutine 100 [chan receive]:
|
||||
headerRegex := regexp.MustCompile(`goroutine \d+ \[([^\]]+)\]:`)
|
||||
|
||||
// Split by "goroutine " to get individual goroutine blocks
|
||||
blocks := strings.Split(profile, "goroutine ")
|
||||
|
||||
// Map to group goroutines by stack signature
|
||||
groupMap := make(map[string]*GoroutineGroup)
|
||||
|
||||
for _, block := range blocks {
|
||||
block = strings.TrimSpace(block)
|
||||
if block == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Re-add "goroutine " prefix for regex matching
|
||||
fullBlock := "goroutine " + block
|
||||
|
||||
// Extract state from header
|
||||
matches := headerRegex.FindStringSubmatch(fullBlock)
|
||||
if len(matches) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
stateInfo := matches[1]
|
||||
state := stateInfo
|
||||
waitReason := ""
|
||||
|
||||
// Parse state and wait reason (e.g., "select, 5 minutes" -> state="select", waitReason="5 minutes")
|
||||
if idx := strings.Index(stateInfo, ","); idx != -1 {
|
||||
state = strings.TrimSpace(stateInfo[:idx])
|
||||
waitReason = strings.TrimSpace(stateInfo[idx+1:])
|
||||
}
|
||||
|
||||
// Get stack trace (everything after the header line)
|
||||
lines := strings.Split(block, "\n")
|
||||
if len(lines) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract stack frames (skip the header line which is lines[0])
|
||||
var stackLines []string
|
||||
var topFunc string
|
||||
for i := 1; i < len(lines); i++ {
|
||||
line := strings.TrimSpace(lines[i])
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
stackLines = append(stackLines, line)
|
||||
|
||||
// First function line (not a file:line) is the top function
|
||||
if topFunc == "" && !strings.HasPrefix(line, "/") && !strings.Contains(line, ".go:") {
|
||||
topFunc = line
|
||||
}
|
||||
}
|
||||
|
||||
if len(stackLines) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create a signature from the stack (top 10 frames for grouping)
|
||||
maxFrames := 10
|
||||
if len(stackLines) < maxFrames {
|
||||
maxFrames = len(stackLines)
|
||||
}
|
||||
signature := state + "|" + strings.Join(stackLines[:maxFrames], "|")
|
||||
|
||||
// Group by signature
|
||||
if existing, ok := groupMap[signature]; ok {
|
||||
existing.Count++
|
||||
} else {
|
||||
groupMap[signature] = &GoroutineGroup{
|
||||
Count: 1,
|
||||
State: state,
|
||||
WaitReason: waitReason,
|
||||
TopFunc: topFunc,
|
||||
Stack: stackLines,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert map to slice
|
||||
groups := make([]GoroutineGroup, 0, len(groupMap))
|
||||
for _, group := range groupMap {
|
||||
groups = append(groups, *group)
|
||||
}
|
||||
|
||||
return groups
|
||||
}
|
||||
|
||||
// profilerPatterns contains patterns to identify profiler-related code
|
||||
var profilerPatterns = []string{
|
||||
"devpprof",
|
||||
"pprof.WriteHeapProfile",
|
||||
"pprof.Lookup",
|
||||
"profile.Parse",
|
||||
"MetricsCollector",
|
||||
"collectLoop",
|
||||
"getAllocations",
|
||||
"flattenAndTopN",
|
||||
"parseGoroutineProfile",
|
||||
"getGoroutines",
|
||||
"getCPUSample",
|
||||
}
|
||||
|
||||
// isProfilerFunction checks if a function belongs to the profiler itself
|
||||
func isProfilerFunction(funcName, fileName string) bool {
|
||||
for _, pattern := range profilerPatterns {
|
||||
if strings.Contains(funcName, pattern) || strings.Contains(fileName, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isProfilerGoroutine checks if a goroutine belongs to the profiler
|
||||
func isProfilerGoroutine(g *GoroutineGroup) bool {
|
||||
stackStr := strings.Join(g.Stack, " ")
|
||||
for _, pattern := range profilerPatterns {
|
||||
if strings.Contains(stackStr, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Cleanup stops the metrics collector
|
||||
func (h *DevPprofHandler) Cleanup() {
|
||||
if h.collector != nil {
|
||||
h.collector.Stop()
|
||||
}
|
||||
}
|
||||
23
transports/bifrost-http/handlers/devpprof_prod.go
Normal file
23
transports/bifrost-http/handlers/devpprof_prod.go
Normal file
@@ -0,0 +1,23 @@
|
||||
//go:build !dev
|
||||
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// DevPprofHandler is a no-op stub for production builds (built without the "dev" tag).
|
||||
type DevPprofHandler struct{}
|
||||
|
||||
// IsDevMode always returns false in production builds.
|
||||
func IsDevMode() bool { return false }
|
||||
|
||||
// NewDevPprofHandler returns nil in production builds.
|
||||
func NewDevPprofHandler() *DevPprofHandler { return nil }
|
||||
|
||||
// RegisterRoutes is a no-op in production builds.
|
||||
func (h *DevPprofHandler) RegisterRoutes(_ *router.Router, _ ...schemas.BifrostHTTPMiddleware) {}
|
||||
|
||||
// Cleanup is a no-op in production builds.
|
||||
func (h *DevPprofHandler) Cleanup() {}
|
||||
26
transports/bifrost-http/handlers/devpprof_unix.go
Normal file
26
transports/bifrost-http/handlers/devpprof_unix.go
Normal file
@@ -0,0 +1,26 @@
|
||||
//go:build dev && !windows
|
||||
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// getCPUSample gets the current CPU time sample using syscall
|
||||
func getCPUSample() cpuSample {
|
||||
var rusage syscall.Rusage
|
||||
if err := syscall.Getrusage(syscall.RUSAGE_SELF, &rusage); err != nil {
|
||||
return cpuSample{timestamp: time.Now()}
|
||||
}
|
||||
|
||||
userTime := time.Duration(rusage.Utime.Sec)*time.Second + time.Duration(rusage.Utime.Usec)*time.Microsecond
|
||||
systemTime := time.Duration(rusage.Stime.Sec)*time.Second + time.Duration(rusage.Stime.Usec)*time.Microsecond
|
||||
|
||||
return cpuSample{
|
||||
timestamp: time.Now(),
|
||||
userTime: userTime,
|
||||
systemTime: systemTime,
|
||||
}
|
||||
}
|
||||
|
||||
12
transports/bifrost-http/handlers/devpprof_windows.go
Normal file
12
transports/bifrost-http/handlers/devpprof_windows.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build dev && windows
|
||||
|
||||
package handlers
|
||||
|
||||
import "time"
|
||||
|
||||
// getCPUSample returns a zeroed CPU sample on Windows
|
||||
// Windows does not support syscall.Getrusage
|
||||
func getCPUSample() cpuSample {
|
||||
return cpuSample{timestamp: time.Now()}
|
||||
}
|
||||
|
||||
3850
transports/bifrost-http/handlers/governance.go
Normal file
3850
transports/bifrost-http/handlers/governance.go
Normal file
File diff suppressed because it is too large
Load Diff
337
transports/bifrost-http/handlers/governance_test.go
Normal file
337
transports/bifrost-http/handlers/governance_test.go
Normal file
@@ -0,0 +1,337 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/plugins/governance"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// mockGovernanceManagerForVK embeds the interface so unimplemented methods panic.
|
||||
// Only GetGovernanceData is needed for the getVirtualKeys handler path.
|
||||
type mockGovernanceManagerForVK struct {
|
||||
GovernanceManager
|
||||
}
|
||||
|
||||
func (m *mockGovernanceManagerForVK) GetGovernanceData(ctx context.Context) *governance.GovernanceData {
|
||||
return nil
|
||||
}
|
||||
|
||||
// mockConfigStoreForVK embeds the interface so unimplemented methods panic.
|
||||
// Only GetVirtualKeysPaginated is called in the non-from_memory path.
|
||||
type mockConfigStoreForVK struct {
|
||||
configstore.ConfigStore
|
||||
}
|
||||
|
||||
func (m *mockConfigStoreForVK) GetVirtualKeysPaginated(_ context.Context, _ configstore.VirtualKeyQueryParams) ([]configstoreTables.TableVirtualKey, int64, error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (m *mockConfigStoreForVK) GetVirtualKeys(_ context.Context) ([]configstoreTables.TableVirtualKey, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// TestGetVirtualKeys_PaginatedEndpoint_ResponseShape verifies the JSON response
|
||||
// from the paginated virtual keys endpoint contains all expected fields.
|
||||
func TestGetVirtualKeys_PaginatedEndpoint_ResponseShape(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := &GovernanceHandler{
|
||||
configStore: &mockConfigStoreForVK{},
|
||||
governanceManager: &mockGovernanceManagerForVK{},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/governance/virtual-keys?limit=10&offset=0")
|
||||
|
||||
h.getVirtualKeys(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != 200 {
|
||||
t.Fatalf("expected status 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to parse JSON response: %v", err)
|
||||
}
|
||||
|
||||
// Assert expected fields exist with correct types
|
||||
requiredFields := []struct {
|
||||
key string
|
||||
wantType string
|
||||
}{
|
||||
{"virtual_keys", "array"},
|
||||
{"total_count", "number"},
|
||||
{"count", "number"},
|
||||
{"limit", "number"},
|
||||
{"offset", "number"},
|
||||
}
|
||||
|
||||
for _, f := range requiredFields {
|
||||
val, ok := resp[f.key]
|
||||
if !ok {
|
||||
t.Errorf("response missing required field %q", f.key)
|
||||
continue
|
||||
}
|
||||
switch f.wantType {
|
||||
case "array":
|
||||
if _, ok := val.([]interface{}); !ok {
|
||||
// nil decodes as nil, which is fine — JSON null for empty array
|
||||
if val != nil {
|
||||
t.Errorf("field %q: expected array, got %T", f.key, val)
|
||||
}
|
||||
}
|
||||
case "number":
|
||||
if _, ok := val.(float64); !ok {
|
||||
t.Errorf("field %q: expected number, got %T", f.key, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify no unexpected extra top-level fields
|
||||
allowedKeys := map[string]bool{
|
||||
"virtual_keys": true,
|
||||
"total_count": true,
|
||||
"count": true,
|
||||
"limit": true,
|
||||
"offset": true,
|
||||
}
|
||||
for key := range resp {
|
||||
if !allowedKeys[key] {
|
||||
t.Errorf("unexpected field %q in response", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetVirtualKeys_PaginatedEndpoint_QueryParams verifies query parameters are
|
||||
// parsed and reflected in the response.
|
||||
func TestGetVirtualKeys_PaginatedEndpoint_QueryParams(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := &GovernanceHandler{
|
||||
configStore: &mockConfigStoreForVK{},
|
||||
governanceManager: &mockGovernanceManagerForVK{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
uri string
|
||||
wantLimit float64
|
||||
wantOffset float64
|
||||
}{
|
||||
{
|
||||
name: "explicit limit and offset",
|
||||
uri: "/api/governance/virtual-keys?limit=10&offset=5",
|
||||
wantLimit: 10,
|
||||
wantOffset: 5,
|
||||
},
|
||||
{
|
||||
name: "no params uses defaults",
|
||||
uri: "/api/governance/virtual-keys",
|
||||
wantLimit: 0,
|
||||
wantOffset: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI(tt.uri)
|
||||
|
||||
h.getVirtualKeys(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != 200 {
|
||||
t.Fatalf("expected status 200, got %d", ctx.Response.StatusCode())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to parse JSON: %v", err)
|
||||
}
|
||||
|
||||
if got := resp["limit"].(float64); got != tt.wantLimit {
|
||||
t.Errorf("limit: got %v, want %v", got, tt.wantLimit)
|
||||
}
|
||||
if got := resp["offset"].(float64); got != tt.wantOffset {
|
||||
t.Errorf("offset: got %v, want %v", got, tt.wantOffset)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure mockLogger satisfies schemas.Logger (already defined in middlewares_test.go
|
||||
// but we reference it here — same package, so no redeclaration needed).
|
||||
var _ schemas.Logger = (*mockLogger)(nil)
|
||||
|
||||
func TestBudgetRemovalRequestDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *UpdateBudgetRequest
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil request is not removal",
|
||||
req: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty object is removal",
|
||||
req: &UpdateBudgetRequest{},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "max limit present is not removal",
|
||||
req: &UpdateBudgetRequest{MaxLimit: bifrostFloat(10)},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "reset duration only is not removal",
|
||||
req: &UpdateBudgetRequest{ResetDuration: bifrostString("1h")},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "calendar aligned only is treated as removal",
|
||||
req: &UpdateBudgetRequest{CalendarAligned: bifrostBool(true)},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isBudgetRemovalRequest(tt.req); got != tt.want {
|
||||
t.Fatalf("isBudgetRemovalRequest() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitRemovalRequestDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *UpdateRateLimitRequest
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil request is not removal",
|
||||
req: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty object is removal",
|
||||
req: &UpdateRateLimitRequest{},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "token limit present is not removal",
|
||||
req: &UpdateRateLimitRequest{TokenMaxLimit: bifrostInt64(100)},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "request limit present is not removal",
|
||||
req: &UpdateRateLimitRequest{RequestMaxLimit: bifrostInt64(10)},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "durations only is not removal",
|
||||
req: &UpdateRateLimitRequest{
|
||||
TokenResetDuration: bifrostString("1h"),
|
||||
RequestResetDuration: bifrostString("1h"),
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isRateLimitRemovalRequest(tt.req); got != tt.want {
|
||||
t.Fatalf("isRateLimitRemovalRequest() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectProviderConfigDeleteIDs(t *testing.T) {
|
||||
budgetID := "budget-1"
|
||||
rateLimitID := "rate-limit-1"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config configstoreTables.TableVirtualKeyProviderConfig
|
||||
initialBudgetIDs []string
|
||||
initialRateIDs []string
|
||||
wantBudgetIDs []string
|
||||
wantRateIDs []string
|
||||
}{
|
||||
{
|
||||
name: "collects both IDs",
|
||||
config: configstoreTables.TableVirtualKeyProviderConfig{
|
||||
Budgets: []configstoreTables.TableBudget{{ID: budgetID}},
|
||||
RateLimitID: &rateLimitID,
|
||||
},
|
||||
wantBudgetIDs: []string{budgetID},
|
||||
wantRateIDs: []string{rateLimitID},
|
||||
},
|
||||
{
|
||||
name: "appends to existing slices",
|
||||
config: configstoreTables.TableVirtualKeyProviderConfig{
|
||||
Budgets: []configstoreTables.TableBudget{{ID: budgetID}},
|
||||
RateLimitID: &rateLimitID,
|
||||
},
|
||||
initialBudgetIDs: []string{"budget-0"},
|
||||
initialRateIDs: []string{"rate-limit-0"},
|
||||
wantBudgetIDs: []string{"budget-0", budgetID},
|
||||
wantRateIDs: []string{"rate-limit-0", rateLimitID},
|
||||
},
|
||||
{
|
||||
name: "ignores missing IDs",
|
||||
config: configstoreTables.TableVirtualKeyProviderConfig{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotBudgetIDs, gotRateIDs := collectProviderConfigDeleteIDs(tt.config, tt.initialBudgetIDs, tt.initialRateIDs)
|
||||
|
||||
if len(gotBudgetIDs) != len(tt.wantBudgetIDs) {
|
||||
t.Fatalf("budget IDs length = %d, want %d", len(gotBudgetIDs), len(tt.wantBudgetIDs))
|
||||
}
|
||||
for i := range gotBudgetIDs {
|
||||
if gotBudgetIDs[i] != tt.wantBudgetIDs[i] {
|
||||
t.Fatalf("budget IDs[%d] = %q, want %q", i, gotBudgetIDs[i], tt.wantBudgetIDs[i])
|
||||
}
|
||||
}
|
||||
|
||||
if len(gotRateIDs) != len(tt.wantRateIDs) {
|
||||
t.Fatalf("rate limit IDs length = %d, want %d", len(gotRateIDs), len(tt.wantRateIDs))
|
||||
}
|
||||
for i := range gotRateIDs {
|
||||
if gotRateIDs[i] != tt.wantRateIDs[i] {
|
||||
t.Fatalf("rate limit IDs[%d] = %q, want %q", i, gotRateIDs[i], tt.wantRateIDs[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func bifrostFloat(v float64) *float64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
func bifrostInt64(v int64) *int64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
func bifrostString(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
func bifrostBool(v bool) *bool {
|
||||
return &v
|
||||
}
|
||||
90
transports/bifrost-http/handlers/health.go
Normal file
90
transports/bifrost-http/handlers/health.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// HealthHandler manages HTTP requests for health checks.
|
||||
type HealthHandler struct {
|
||||
config *lib.Config
|
||||
}
|
||||
|
||||
// NewHealthHandler creates a new health handler instance.
|
||||
func NewHealthHandler(config *lib.Config) *HealthHandler {
|
||||
return &HealthHandler{
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the health-related routes.
|
||||
func (h *HealthHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
r.GET("/health", lib.ChainMiddlewares(h.getHealth, middlewares...))
|
||||
}
|
||||
|
||||
// getHealth handles GET /api/health - Get the health status of the server.
|
||||
func (h *HealthHandler) getHealth(ctx *fasthttp.RequestCtx) {
|
||||
// If DB pings are disabled, just return OK
|
||||
if h.config.ClientConfig.DisableDBPingsInHealth {
|
||||
SendJSON(ctx, map[string]any{"status": "ok", "components": map[string]any{"db_pings": "disabled"}})
|
||||
return
|
||||
}
|
||||
// Pinging config store
|
||||
reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
var errors []string
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
|
||||
if h.config.ConfigStore != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := h.config.ConfigStore.Ping(reqCtx); err != nil {
|
||||
mu.Lock()
|
||||
errors = append(errors, "config store not available")
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Pinging log store
|
||||
if h.config.LogsStore != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := h.config.LogsStore.Ping(reqCtx); err != nil {
|
||||
mu.Lock()
|
||||
errors = append(errors, "log store not available")
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Pinging vector store
|
||||
if h.config.VectorStore != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := h.config.VectorStore.Ping(reqCtx); err != nil {
|
||||
mu.Lock()
|
||||
errors = append(errors, "vector store not available")
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if len(errors) > 0 {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, errors[0])
|
||||
return
|
||||
}
|
||||
SendJSON(ctx, map[string]any{"status": "ok", "components": map[string]any{"db_pings": "ok"}})
|
||||
}
|
||||
3807
transports/bifrost-http/handlers/inference.go
Normal file
3807
transports/bifrost-http/handlers/inference.go
Normal file
File diff suppressed because it is too large
Load Diff
20
transports/bifrost-http/handlers/init.go
Normal file
20
transports/bifrost-http/handlers/init.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package handlers
|
||||
|
||||
import "github.com/maximhq/bifrost/core/schemas"
|
||||
|
||||
var version string
|
||||
var logger schemas.Logger
|
||||
|
||||
// SetLogger sets the logger for the application.
|
||||
func SetLogger(l schemas.Logger) {
|
||||
logger = l
|
||||
}
|
||||
|
||||
// SetVersion sets the version for the application.
|
||||
func SetVersion(v string) {
|
||||
version = v
|
||||
}
|
||||
|
||||
func GetVersion() string {
|
||||
return version
|
||||
}
|
||||
111
transports/bifrost-http/handlers/integrations.go
Normal file
111
transports/bifrost-http/handlers/integrations.go
Normal file
@@ -0,0 +1,111 @@
|
||||
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
|
||||
// This file contains integration management handlers for AI provider integrations.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/fasthttp/router"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/integrations"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
)
|
||||
|
||||
// IntegrationHandler manages HTTP requests for AI provider integrations
|
||||
type IntegrationHandler struct {
|
||||
extensions []integrations.ExtensionRouter
|
||||
wsResponses *WSResponsesHandler
|
||||
wsRealtime *WSRealtimeHandler
|
||||
webrtcRealtime *WebRTCRealtimeHandler
|
||||
realtimeClientSecrets *RealtimeClientSecretsHandler
|
||||
}
|
||||
|
||||
// NewIntegrationHandler creates a new integration handler instance.
|
||||
// WebSocket handlers may be nil if WebSocket support is not configured.
|
||||
func NewIntegrationHandler(client *bifrost.Bifrost, handlerStore lib.HandlerStore, wsResponses *WSResponsesHandler, wsRealtime *WSRealtimeHandler, webrtcRealtime *WebRTCRealtimeHandler, realtimeClientSecrets *RealtimeClientSecretsHandler) *IntegrationHandler {
|
||||
// Initialize all available integration routers
|
||||
extensions := []integrations.ExtensionRouter{
|
||||
integrations.NewOpenAIRouter(client, handlerStore, logger),
|
||||
integrations.NewAnthropicRouter(client, handlerStore, logger),
|
||||
integrations.NewGenAIRouter(client, handlerStore, logger),
|
||||
integrations.NewLiteLLMRouter(client, handlerStore, logger),
|
||||
integrations.NewCohereRouter(client, handlerStore, logger),
|
||||
integrations.NewLangChainRouter(client, handlerStore, logger),
|
||||
integrations.NewPydanticAIRouter(client, handlerStore, logger),
|
||||
integrations.NewBedrockRouter(client, handlerStore, logger),
|
||||
// passthrough routers
|
||||
integrations.NewGenAIPassthroughRouter(client, handlerStore, logger),
|
||||
integrations.NewOpenAIPassthroughRouter(client, handlerStore, logger),
|
||||
integrations.NewAnthropicPassthroughRouter(client, handlerStore, logger),
|
||||
integrations.NewAzurePassthroughRouter(client, handlerStore, logger),
|
||||
integrations.NewCursorRouter(client, handlerStore, logger),
|
||||
}
|
||||
|
||||
return &IntegrationHandler{
|
||||
extensions: extensions,
|
||||
wsResponses: wsResponses,
|
||||
wsRealtime: wsRealtime,
|
||||
webrtcRealtime: webrtcRealtime,
|
||||
realtimeClientSecrets: realtimeClientSecrets,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers all integration routes for AI provider compatibility endpoints
|
||||
func (h *IntegrationHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
// Register routes for each integration extension
|
||||
for _, extension := range h.extensions {
|
||||
extension.RegisterRoutes(r, middlewares...)
|
||||
}
|
||||
// Register WebSocket routes (base path + integration paths)
|
||||
if h.wsResponses != nil {
|
||||
h.wsResponses.RegisterRoutes(r, middlewares...)
|
||||
}
|
||||
if h.wsRealtime != nil {
|
||||
h.wsRealtime.RegisterRoutes(r, middlewares...)
|
||||
}
|
||||
if h.webrtcRealtime != nil {
|
||||
h.webrtcRealtime.RegisterRoutes(r, middlewares...)
|
||||
}
|
||||
if h.realtimeClientSecrets != nil {
|
||||
h.realtimeClientSecrets.RegisterRoutes(r, middlewares...)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *IntegrationHandler) Close() {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
if h.wsResponses != nil {
|
||||
h.wsResponses.Close()
|
||||
}
|
||||
if h.wsRealtime != nil {
|
||||
h.wsRealtime.Close()
|
||||
}
|
||||
if h.webrtcRealtime != nil {
|
||||
h.webrtcRealtime.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// SetLargePayloadHook sets the large payload detection hook on all integration routers
|
||||
// that support it. This is used by enterprise to inject large payload optimization.
|
||||
func (h *IntegrationHandler) SetLargePayloadHook(hook integrations.LargePayloadHook) {
|
||||
for _, extension := range h.extensions {
|
||||
if setter, ok := extension.(interface {
|
||||
SetLargePayloadHook(integrations.LargePayloadHook)
|
||||
}); ok {
|
||||
setter.SetLargePayloadHook(hook)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetLargeResponseHook sets the large response scanning hook on all integration routers
|
||||
// that support it. Enterprise uses this to inject Phase B usage extraction into the
|
||||
// response stream without embedding scanning logic in the OSS router.
|
||||
func (h *IntegrationHandler) SetLargeResponseHook(hook integrations.LargeResponseHook) {
|
||||
for _, extension := range h.extensions {
|
||||
if setter, ok := extension.(interface {
|
||||
SetLargeResponseHook(integrations.LargeResponseHook)
|
||||
}); ok {
|
||||
setter.SetLargeResponseHook(hook)
|
||||
}
|
||||
}
|
||||
}
|
||||
1653
transports/bifrost-http/handlers/logging.go
Normal file
1653
transports/bifrost-http/handlers/logging.go
Normal file
File diff suppressed because it is too large
Load Diff
1164
transports/bifrost-http/handlers/mcp.go
Normal file
1164
transports/bifrost-http/handlers/mcp.go
Normal file
File diff suppressed because it is too large
Load Diff
112
transports/bifrost-http/handlers/mcpinference.go
Normal file
112
transports/bifrost-http/handlers/mcpinference.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/fasthttp/router"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type MCPInferenceHandler struct {
|
||||
client *bifrost.Bifrost
|
||||
config *lib.Config
|
||||
}
|
||||
|
||||
// NewMCPInferenceHandler creates a new MCP inference handler instance
|
||||
func NewMCPInferenceHandler(client *bifrost.Bifrost, config *lib.Config) *MCPInferenceHandler {
|
||||
return &MCPInferenceHandler{
|
||||
client: client,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the MCP inference routes
|
||||
func (h *MCPInferenceHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
r.POST("/v1/mcp/tool/execute", lib.ChainMiddlewares(h.executeTool, middlewares...))
|
||||
}
|
||||
|
||||
// executeTool handles POST /v1/mcp/tool/execute - Execute MCP tool
|
||||
func (h *MCPInferenceHandler) executeTool(ctx *fasthttp.RequestCtx) {
|
||||
// Check format query parameter
|
||||
format := strings.ToLower(string(ctx.QueryArgs().Peek("format")))
|
||||
switch format {
|
||||
case "chat", "":
|
||||
h.executeChatMCPTool(ctx)
|
||||
case "responses":
|
||||
h.executeResponsesMCPTool(ctx)
|
||||
default:
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Invalid format value, must be 'chat' or 'responses'")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// executeChatMCPTool handles POST /v1/mcp/tool/execute?format=chat - Execute MCP tool
|
||||
func (h *MCPInferenceHandler) executeChatMCPTool(ctx *fasthttp.RequestCtx) {
|
||||
var req schemas.ChatAssistantMessageToolCall
|
||||
if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.Function.Name == nil || *req.Function.Name == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Tool function name is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Convert context
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
defer cancel() // Ensure cleanup on function exit
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
|
||||
// Execute MCP tool
|
||||
toolMessage, bifrostErr := h.client.ExecuteChatMCPTool(bifrostCtx, &req)
|
||||
if bifrostErr != nil {
|
||||
SendBifrostError(ctx, bifrostErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Send successful response
|
||||
SendJSON(ctx, toolMessage)
|
||||
}
|
||||
|
||||
// executeResponsesMCPTool handles POST /v1/mcp/tool/execute?format=responses - Execute MCP tool
|
||||
func (h *MCPInferenceHandler) executeResponsesMCPTool(ctx *fasthttp.RequestCtx) {
|
||||
var req schemas.ResponsesToolMessage
|
||||
if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.Name == nil || *req.Name == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Tool function name is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Convert context
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
defer cancel() // Ensure cleanup on function exit
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
|
||||
// Execute MCP tool
|
||||
toolMessage, bifrostErr := h.client.ExecuteResponsesMCPTool(bifrostCtx, &req)
|
||||
if bifrostErr != nil {
|
||||
SendBifrostError(ctx, bifrostErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Send successful response
|
||||
SendJSON(ctx, toolMessage)
|
||||
}
|
||||
568
transports/bifrost-http/handlers/mcpserver.go
Normal file
568
transports/bifrost-http/handlers/mcpserver.go
Normal file
@@ -0,0 +1,568 @@
|
||||
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
|
||||
// This file contains MCP (Model Context Protocol) server implementation for HTTP streaming.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/plugins/governance"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// MCPToolExecutor interface defines the method needed for executing MCP tools
|
||||
type MCPToolManager interface {
|
||||
GetAvailableMCPTools(ctx context.Context) []schemas.ChatTool
|
||||
ExecuteChatMCPTool(ctx context.Context, toolCall *schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, *schemas.BifrostError)
|
||||
ExecuteResponsesMCPTool(ctx context.Context, toolCall *schemas.ResponsesToolMessage) (*schemas.ResponsesMessage, *schemas.BifrostError)
|
||||
}
|
||||
|
||||
// MCPServerHandler manages HTTP requests for MCP server operations
|
||||
// It implements the MCP protocol over HTTP streaming (SSE) for MCP clients
|
||||
type MCPServerHandler struct {
|
||||
toolManager MCPToolManager
|
||||
globalMCPServer *server.MCPServer
|
||||
vkMCPServers map[string]*server.MCPServer // Map of vk value -> mcp server
|
||||
config *lib.Config
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewMCPServerHandler creates a new MCP server handler instance
|
||||
func NewMCPServerHandler(ctx context.Context, config *lib.Config, toolManager MCPToolManager) (*MCPServerHandler, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
if toolManager == nil {
|
||||
return nil, fmt.Errorf("tool manager is required")
|
||||
}
|
||||
|
||||
// Create MCP server instance using mcp-go
|
||||
globalMCPServer := server.NewMCPServer(
|
||||
"global",
|
||||
version,
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
|
||||
handler := &MCPServerHandler{
|
||||
toolManager: toolManager,
|
||||
globalMCPServer: globalMCPServer,
|
||||
config: config,
|
||||
vkMCPServers: make(map[string]*server.MCPServer),
|
||||
}
|
||||
|
||||
// Register per-request tool filter so x-bf-mcp-include-clients and x-bf-mcp-include-tools are respected on tools/list
|
||||
server.WithToolFilter(handler.makeIncludeClientsFilter())(handler.globalMCPServer)
|
||||
|
||||
// Register per-request tool filter so x-bf-mcp-include-clients and x-bf-mcp-include-tools are respected on tools/list
|
||||
server.WithToolFilter(handler.makeIncludeClientsFilter())(handler.globalMCPServer)
|
||||
|
||||
if err := handler.SyncAllMCPServers(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to sync all MCP servers: %w", err)
|
||||
}
|
||||
|
||||
return handler, nil
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the MCP server route
|
||||
func (h *MCPServerHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
// MCP server endpoint - supports both POST (JSON-RPC) and GET (SSE)
|
||||
r.POST("/mcp", lib.ChainMiddlewares(h.handleMCPServer, middlewares...))
|
||||
r.GET("/mcp", lib.ChainMiddlewares(h.handleMCPServerSSE, middlewares...))
|
||||
}
|
||||
|
||||
// handleMCPServer handles POST requests for MCP JSON-RPC 2.0 messages
|
||||
// injectMCPSessionIdentity sets the MCP gateway flag and, if a per-user OAuth
|
||||
// session exists, injects the session token and identity (VK / User ID) directly
|
||||
// into the BifrostContext. This avoids header-based identity propagation which
|
||||
// would be vulnerable to spoofing by upstream callers.
|
||||
//
|
||||
// Governance context keys are set here intentionally (bypassing governance plugin)
|
||||
// because in the MCP gateway path, identity is pre-authenticated via the OAuth session.
|
||||
func injectMCPSessionIdentity(bifrostCtx *schemas.BifrostContext, session *tables.TablePerUserOAuthSession) {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyIsMCPGateway, true)
|
||||
if session != nil {
|
||||
if session.AccessToken != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyMCPUserSession, session.AccessToken)
|
||||
}
|
||||
if session.VirtualKeyID != nil && *session.VirtualKeyID != "" && session.VirtualKey != nil && session.VirtualKey.Value != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, session.VirtualKey.Value)
|
||||
}
|
||||
if session.UserID != nil && *session.UserID != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyUserID, *session.UserID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *MCPServerHandler) handleMCPServer(ctx *fasthttp.RequestCtx) {
|
||||
mcpServer, session, err := h.getMCPServerForRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Convert context
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
defer cancel()
|
||||
|
||||
injectMCPSessionIdentity(bifrostCtx, session)
|
||||
|
||||
// Use mcp-go server to handle the request
|
||||
// HandleMessage processes JSON-RPC messages and returns appropriate responses
|
||||
response := mcpServer.HandleMessage(bifrostCtx, ctx.PostBody())
|
||||
|
||||
// Check if response is nil (notification - no response needed)
|
||||
if response == nil {
|
||||
ctx.SetStatusCode(fasthttp.StatusAccepted)
|
||||
return
|
||||
}
|
||||
|
||||
// Marshal and send response
|
||||
responseJSON, err := sonic.Marshal(response)
|
||||
if err != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to marshal MCP response: %v", err))
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to encode response: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.SetContentType("application/json")
|
||||
ctx.SetBody(responseJSON)
|
||||
}
|
||||
|
||||
// handleMCPServerSSE handles GET requests for MCP Server-Sent Events streaming
|
||||
func (h *MCPServerHandler) handleMCPServerSSE(ctx *fasthttp.RequestCtx) {
|
||||
_, session, err := h.getMCPServerForRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Set SSE headers
|
||||
ctx.SetContentType("text/event-stream")
|
||||
ctx.Response.Header.Set("Cache-Control", "no-cache")
|
||||
ctx.Response.Header.Set("Connection", "keep-alive")
|
||||
|
||||
// Convert context
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
|
||||
injectMCPSessionIdentity(bifrostCtx, session)
|
||||
|
||||
// Use SSEStreamReader to bypass fasthttp's internal pipe batching
|
||||
reader := lib.NewSSEStreamReader()
|
||||
ctx.Response.SetBodyStream(reader, -1)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
cancel()
|
||||
reader.Done()
|
||||
}()
|
||||
|
||||
// Send initial connection message
|
||||
initMessage := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "connection/opened",
|
||||
}
|
||||
if initJSON, err := sonic.Marshal(initMessage); err == nil {
|
||||
buf := make([]byte, 0, len(initJSON)+8)
|
||||
buf = append(buf, "data: "...)
|
||||
buf = append(buf, initJSON...)
|
||||
buf = append(buf, '\n', '\n')
|
||||
reader.Send(buf)
|
||||
}
|
||||
|
||||
// Wait for context cancellation (client disconnect or server-side cancel)
|
||||
<-(*bifrostCtx).Done()
|
||||
}()
|
||||
}
|
||||
|
||||
// Sync methods for MCP servers
|
||||
|
||||
func (h *MCPServerHandler) SyncAllMCPServers(ctx context.Context) error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
availableTools := h.toolManager.GetAvailableMCPTools(ctx)
|
||||
h.syncServer(h.globalMCPServer, availableTools, nil)
|
||||
logger.Debug("Synced global MCP server with %d tools", len(availableTools))
|
||||
|
||||
// initialize vkMCPServers map
|
||||
if h.config.ConfigStore != nil {
|
||||
virtualKeys, err := h.config.ConfigStore.GetVirtualKeys(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get virtual keys: %w", err)
|
||||
}
|
||||
h.vkMCPServers = make(map[string]*server.MCPServer)
|
||||
for i := range virtualKeys {
|
||||
vk := &virtualKeys[i]
|
||||
vkServer := server.NewMCPServer(
|
||||
vk.Name,
|
||||
version,
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
server.WithToolFilter(h.makeIncludeClientsFilter())(vkServer)
|
||||
h.vkMCPServers[vk.Value] = vkServer
|
||||
availableTools, toolFilter := h.fetchToolsForVK(vk)
|
||||
h.syncServer(h.vkMCPServers[vk.Value], availableTools, toolFilter)
|
||||
logger.Debug("Synced MCP server for virtual key '%s' with %d tools", vk.Name, len(availableTools))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MCPServerHandler) SyncVKMCPServer(vk *tables.TableVirtualKey) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
vkServer, ok := h.vkMCPServers[vk.Value]
|
||||
if !ok {
|
||||
// Add new server
|
||||
vkServer = server.NewMCPServer(
|
||||
vk.Name,
|
||||
version,
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
server.WithToolFilter(h.makeIncludeClientsFilter())(vkServer)
|
||||
h.vkMCPServers[vk.Value] = vkServer
|
||||
}
|
||||
availableTools, toolFilter := h.fetchToolsForVK(vk)
|
||||
h.syncServer(vkServer, availableTools, toolFilter)
|
||||
h.vkMCPServers[vk.Value] = vkServer
|
||||
logger.Debug("Synced MCP server for virtual key '%s' with %d tools", vk.Name, len(availableTools))
|
||||
}
|
||||
|
||||
func (h *MCPServerHandler) DeleteVKMCPServer(vkValue string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
delete(h.vkMCPServers, vkValue)
|
||||
}
|
||||
|
||||
func (h *MCPServerHandler) syncServer(server *server.MCPServer, availableTools []schemas.ChatTool, toolFilter []string) {
|
||||
// Clear existing tools
|
||||
toolMap := server.ListTools()
|
||||
for toolName, _ := range toolMap {
|
||||
server.DeleteTools(toolName)
|
||||
}
|
||||
|
||||
// Register tools from all connected clients
|
||||
for _, tool := range availableTools {
|
||||
// Only process function tools (skip custom tools)
|
||||
if tool.Function == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Capture tool name for closure
|
||||
toolName := tool.Function.Name
|
||||
|
||||
handler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
// Inject tool filter into execution context if present
|
||||
if toolFilter != nil {
|
||||
ctx = context.WithValue(ctx, schemas.MCPContextKeyIncludeTools, toolFilter)
|
||||
}
|
||||
// Convert to Bifrost tool call format
|
||||
toolCallType := "function"
|
||||
toolCallID := fmt.Sprintf("mcp-%s", toolName)
|
||||
argsJSON, jsonErr := sonic.Marshal(request.GetArguments())
|
||||
if jsonErr != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("Failed to marshal tool arguments: %v", jsonErr)), nil
|
||||
}
|
||||
toolCall := schemas.ChatAssistantMessageToolCall{
|
||||
ID: &toolCallID,
|
||||
Type: &toolCallType,
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: &toolName,
|
||||
Arguments: string(argsJSON),
|
||||
},
|
||||
}
|
||||
|
||||
// Execute the tool via tool executor
|
||||
toolMessage, err := h.toolManager.ExecuteChatMCPTool(ctx, &toolCall)
|
||||
if err != nil {
|
||||
if err.ExtraFields.MCPAuthRequired != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf(
|
||||
"Authentication required for %s. Open this URL to connect your account: %s",
|
||||
err.ExtraFields.MCPAuthRequired.MCPClientName, err.ExtraFields.MCPAuthRequired.AuthorizeURL,
|
||||
)), nil
|
||||
}
|
||||
return mcp.NewToolResultError(fmt.Sprintf("Tool execution failed: %v", bifrost.GetErrorMessage(err))), nil
|
||||
}
|
||||
|
||||
// Extract content from tool message
|
||||
var resultText string
|
||||
if toolMessage != nil && toolMessage.Content != nil {
|
||||
// Handle ContentStr (string content)
|
||||
if toolMessage.Content.ContentStr != nil {
|
||||
resultText = *toolMessage.Content.ContentStr
|
||||
} else if toolMessage.Content.ContentBlocks != nil {
|
||||
// Handle ContentBlocks (structured content)
|
||||
for _, block := range toolMessage.Content.ContentBlocks {
|
||||
if block.Type == schemas.ChatContentBlockTypeText && block.Text != nil {
|
||||
resultText += *block.Text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return result using mcp-go helper
|
||||
return mcp.NewToolResultText(resultText), nil
|
||||
}
|
||||
|
||||
// Convert description from *string to string
|
||||
description := ""
|
||||
if tool.Function.Description != nil {
|
||||
description = *tool.Function.Description
|
||||
}
|
||||
|
||||
// Convert Parameters to mcp.ToolInputSchema
|
||||
var inputSchema mcp.ToolInputSchema
|
||||
if tool.Function.Parameters != nil {
|
||||
inputSchema.Type = tool.Function.Parameters.Type
|
||||
if tool.Function.Parameters.Properties != nil {
|
||||
// Convert *map[string]interface{} to map[string]any
|
||||
props := make(map[string]any)
|
||||
tool.Function.Parameters.Properties.Range(func(key string, value interface{}) bool {
|
||||
props[key] = value
|
||||
return true
|
||||
})
|
||||
inputSchema.Properties = props
|
||||
}
|
||||
if tool.Function.Parameters.Required != nil {
|
||||
inputSchema.Required = tool.Function.Parameters.Required
|
||||
}
|
||||
} else {
|
||||
// Default to empty object schema if no parameters
|
||||
inputSchema.Type = "object"
|
||||
inputSchema.Properties = make(map[string]any)
|
||||
}
|
||||
|
||||
// Map Bifrost annotations back to MCP tool annotations
|
||||
var toolAnnotation mcp.ToolAnnotation
|
||||
if tool.Annotations != nil {
|
||||
toolAnnotation = mcp.ToolAnnotation{
|
||||
Title: tool.Annotations.Title,
|
||||
ReadOnlyHint: tool.Annotations.ReadOnlyHint,
|
||||
DestructiveHint: tool.Annotations.DestructiveHint,
|
||||
IdempotentHint: tool.Annotations.IdempotentHint,
|
||||
OpenWorldHint: tool.Annotations.OpenWorldHint,
|
||||
}
|
||||
}
|
||||
|
||||
// Register tool with the server
|
||||
server.AddTool(mcp.Tool{
|
||||
Name: toolName,
|
||||
Description: description,
|
||||
InputSchema: inputSchema,
|
||||
Annotations: toolAnnotation,
|
||||
}, handler)
|
||||
}
|
||||
}
|
||||
|
||||
// fetchToolsForVK fetches the tools for a given virtual key value.
|
||||
// vkValue is the virtual key value for the server, if empty, all tools will be fetched for global mcp server.
|
||||
// Returns the list of available tools and the tool filter to be applied during execution.
|
||||
func (h *MCPServerHandler) fetchToolsForVK(vk *tables.TableVirtualKey) ([]schemas.ChatTool, []string) {
|
||||
ctx := context.Background()
|
||||
var toolFilter []string
|
||||
|
||||
executeOnlyTools := make([]string, 0)
|
||||
|
||||
// Build a lookup of AllowOnAllVirtualKeys clients: clientID -> clientName.
|
||||
// Explicit VK MCPConfigs always take precedence over AllowOnAllVirtualKeys.
|
||||
allowAllVKsClients := h.config.GetAllowOnAllVirtualKeysClients()
|
||||
if allowAllVKsClients == nil {
|
||||
allowAllVKsClients = make(map[string]string)
|
||||
}
|
||||
|
||||
// Process explicit VK MCPConfigs first.
|
||||
handledClients := make(map[string]bool)
|
||||
for _, vkMcpConfig := range vk.MCPConfigs {
|
||||
clientID := vkMcpConfig.MCPClient.ClientID
|
||||
if _, isAllowAll := allowAllVKsClients[clientID]; isAllowAll {
|
||||
// Explicit config exists — it takes precedence; mark handled regardless of tool list.
|
||||
handledClients[clientID] = true
|
||||
}
|
||||
if vkMcpConfig.ToolsToExecute.IsEmpty() {
|
||||
continue
|
||||
}
|
||||
if vkMcpConfig.ToolsToExecute.IsUnrestricted() {
|
||||
executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", vkMcpConfig.MCPClient.Name))
|
||||
continue
|
||||
}
|
||||
for _, tool := range vkMcpConfig.ToolsToExecute {
|
||||
if tool != "" {
|
||||
// Add the tool - client config filtering will be handled by mcp.go
|
||||
// Note: Use '-' separator for individual tools (wildcard uses '-*' after client name, e.g., "client-*")
|
||||
executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-%s", vkMcpConfig.MCPClient.Name, tool))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For AllowOnAllVirtualKeys clients with no explicit VK config, allow all their tools.
|
||||
for clientID, clientName := range allowAllVKsClients {
|
||||
if !handledClients[clientID] {
|
||||
executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", clientName))
|
||||
}
|
||||
}
|
||||
|
||||
// Always set the include-tools filter (empty = deny-all when no MCPConfigs and no AllowOnAllVirtualKeys clients)
|
||||
ctx = context.WithValue(ctx, schemas.MCPContextKeyIncludeTools, executeOnlyTools)
|
||||
toolFilter = executeOnlyTools
|
||||
|
||||
return h.toolManager.GetAvailableMCPTools(ctx), toolFilter
|
||||
}
|
||||
|
||||
// makeIncludeClientsFilter returns a ToolFilterFunc that dynamically filters the tools/list
|
||||
// response based on the x-bf-mcp-include-clients and x-bf-mcp-include-tools request headers.
|
||||
// When neither header is present the filter is a no-op, preserving existing behaviour.
|
||||
func (h *MCPServerHandler) makeIncludeClientsFilter() server.ToolFilterFunc {
|
||||
return func(ctx context.Context, tools []mcp.Tool) []mcp.Tool {
|
||||
if ctx.Value(schemas.MCPContextKeyIncludeClients) == nil && ctx.Value(schemas.MCPContextKeyIncludeTools) == nil {
|
||||
return tools
|
||||
}
|
||||
allowed := h.toolManager.GetAvailableMCPTools(ctx)
|
||||
allowedNames := make(map[string]bool, len(allowed))
|
||||
for _, t := range allowed {
|
||||
if t.Function != nil {
|
||||
allowedNames[t.Function.Name] = true
|
||||
}
|
||||
}
|
||||
result := make([]mcp.Tool, 0, len(tools))
|
||||
for _, tool := range tools {
|
||||
if allowedNames[tool.Name] {
|
||||
result = append(result, tool)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// Utility methods
|
||||
|
||||
func (h *MCPServerHandler) getMCPServerForRequest(ctx *fasthttp.RequestCtx) (*server.MCPServer, *tables.TablePerUserOAuthSession, error) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
h.config.Mu.RLock()
|
||||
enforceVK := h.config.ClientConfig.EnforceAuthOnInference
|
||||
h.config.Mu.RUnlock()
|
||||
|
||||
vk := getVKFromRequest(ctx)
|
||||
|
||||
// Check for Bifrost per-user OAuth Bearer token (not a VK)
|
||||
userOauthSession, sessionErr := h.getPerUserOAuthSession(ctx)
|
||||
if sessionErr != nil {
|
||||
return nil, nil, fmt.Errorf("failed to look up OAuth session: %w", sessionErr)
|
||||
}
|
||||
|
||||
// If per_user_oauth MCP clients are configured and no valid auth, return 401 with discovery
|
||||
if clients := h.config.GetPerUserOAuthMCPClients(); len(clients) > 0 && userOauthSession == nil && vk == "" {
|
||||
scheme := "http"
|
||||
if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
host := string(ctx.Host())
|
||||
resourceMetadataURL := fmt.Sprintf("%s://%s/.well-known/oauth-protected-resource", scheme, host)
|
||||
ctx.Response.Header.Set("WWW-Authenticate",
|
||||
fmt.Sprintf(`Bearer resource_metadata="%s"`, resourceMetadataURL))
|
||||
return nil, nil, fmt.Errorf("oauth authentication required for mcp access")
|
||||
}
|
||||
|
||||
if userOauthSession != nil {
|
||||
if !enforceVK && (userOauthSession.VirtualKeyID == nil || *userOauthSession.VirtualKeyID == "") {
|
||||
return h.globalMCPServer, userOauthSession, nil
|
||||
}
|
||||
|
||||
if userOauthSession.VirtualKeyID == nil || *userOauthSession.VirtualKeyID == "" || userOauthSession.VirtualKey == nil {
|
||||
return nil, nil, fmt.Errorf("virtual key required in oauth session to access mcp server, please re-authenticate with a virtual key")
|
||||
}
|
||||
|
||||
vkServer, ok := h.vkMCPServers[userOauthSession.VirtualKey.Value]
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("virtual key not found")
|
||||
}
|
||||
|
||||
return vkServer, userOauthSession, nil
|
||||
}
|
||||
|
||||
// Return global MCP server if not enforcing virtual key header and no virtual key is provided
|
||||
if !enforceVK && vk == "" {
|
||||
return h.globalMCPServer, nil, nil
|
||||
}
|
||||
|
||||
if vk == "" {
|
||||
return nil, nil, fmt.Errorf("virtual key header required to access mcp server")
|
||||
}
|
||||
|
||||
vkServer, ok := h.vkMCPServers[vk]
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("virtual key not found")
|
||||
}
|
||||
|
||||
return vkServer, nil, nil
|
||||
}
|
||||
|
||||
// getPerUserOAuthSession extracts and validates a Bifrost-issued per-user OAuth
|
||||
// token from the Authorization header. Returns the session if valid, nil otherwise.
|
||||
func (h *MCPServerHandler) getPerUserOAuthSession(ctx *fasthttp.RequestCtx) (*tables.TablePerUserOAuthSession, error) {
|
||||
authHeader := strings.TrimSpace(string(ctx.Request.Header.Peek("Authorization")))
|
||||
if authHeader == "" || !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
||||
return nil, nil
|
||||
}
|
||||
token := strings.TrimSpace(authHeader[7:])
|
||||
if token == "" || strings.HasPrefix(strings.ToLower(token), governance.VirtualKeyPrefix) {
|
||||
return nil, nil // It's a virtual key, not a per-user OAuth token
|
||||
}
|
||||
|
||||
if h.config.ConfigStore == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
session, err := h.config.ConfigStore.GetPerUserOAuthSessionByAccessToken(ctx, token)
|
||||
if err != nil {
|
||||
logger.Warn("[mcp/auth] GetPerUserOAuthSessionByAccessToken error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
if session == nil {
|
||||
logger.Debug("[mcp/auth] Session not found for token")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check expiry
|
||||
if session.ExpiresAt.Before(time.Now()) {
|
||||
logger.Debug("[mcp/auth] Session expired: session_id=%s expires_at=%v", session.ID, session.ExpiresAt)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func getVKFromRequest(ctx *fasthttp.RequestCtx) string {
|
||||
if value := strings.TrimSpace(string(ctx.Request.Header.Peek(string(schemas.BifrostContextKeyVirtualKey)))); value != "" {
|
||||
return value
|
||||
}
|
||||
|
||||
authHeader := strings.TrimSpace(string(ctx.Request.Header.Peek("Authorization")))
|
||||
if authHeader != "" {
|
||||
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
||||
token := strings.TrimSpace(authHeader[7:])
|
||||
if token != "" && strings.HasPrefix(strings.ToLower(token), governance.VirtualKeyPrefix) {
|
||||
return token
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if apiKey := strings.TrimSpace(string(ctx.Request.Header.Peek("x-api-key"))); apiKey != "" {
|
||||
if strings.HasPrefix(strings.ToLower(apiKey), governance.VirtualKeyPrefix) {
|
||||
return apiKey
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
1092
transports/bifrost-http/handlers/middlewares.go
Normal file
1092
transports/bifrost-http/handlers/middlewares.go
Normal file
File diff suppressed because it is too large
Load Diff
2022
transports/bifrost-http/handlers/middlewares_test.go
Normal file
2022
transports/bifrost-http/handlers/middlewares_test.go
Normal file
File diff suppressed because it is too large
Load Diff
320
transports/bifrost-http/handlers/oauth2.go
Normal file
320
transports/bifrost-http/handlers/oauth2.go
Normal file
@@ -0,0 +1,320 @@
|
||||
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
|
||||
// This file contains OAuth 2.0 authentication flow handlers.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/oauth2"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// OAuth2Handler manages HTTP requests for OAuth2 operations
|
||||
type OAuthHandler struct {
|
||||
client *bifrost.Bifrost
|
||||
store *lib.Config
|
||||
oauthProvider *oauth2.OAuth2Provider
|
||||
}
|
||||
|
||||
// NewOAuthHandler creates a new OAuth handler instance
|
||||
func NewOAuthHandler(oauthProvider *oauth2.OAuth2Provider, client *bifrost.Bifrost, store *lib.Config) *OAuthHandler {
|
||||
return &OAuthHandler{
|
||||
client: client,
|
||||
store: store,
|
||||
oauthProvider: oauthProvider,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers all OAuth-related routes
|
||||
func (h *OAuthHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
r.GET("/api/oauth/callback", lib.ChainMiddlewares(h.handleOAuthCallback, middlewares...))
|
||||
r.GET("/api/oauth/config/{id}/status", lib.ChainMiddlewares(h.getOAuthConfigStatus, middlewares...))
|
||||
r.DELETE("/api/oauth/config/{id}", lib.ChainMiddlewares(h.revokeOAuthConfig, middlewares...))
|
||||
}
|
||||
|
||||
// handleOAuthCallback handles the OAuth provider callback
|
||||
// GET /api/oauth/callback?state=xxx&code=yyy&error=zzz
|
||||
func (h *OAuthHandler) handleOAuthCallback(ctx *fasthttp.RequestCtx) {
|
||||
state := string(ctx.QueryArgs().Peek("state"))
|
||||
code := string(ctx.QueryArgs().Peek("code"))
|
||||
errorParam := string(ctx.QueryArgs().Peek("error"))
|
||||
errorDescription := string(ctx.QueryArgs().Peek("error_description"))
|
||||
|
||||
// Handle authorization denial
|
||||
if errorParam != "" {
|
||||
h.handleCallbackError(ctx, state, errorParam, errorDescription)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required parameters
|
||||
if state == "" || code == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Missing required parameters: state and code")
|
||||
return
|
||||
}
|
||||
|
||||
// Try per-user OAuth runtime flow first (state from oauth_user_sessions table).
|
||||
// This handles the case where an end-user authenticates during inference.
|
||||
sessionToken, perUserErr := h.oauthProvider.CompleteUserOAuthFlow(context.Background(), state, code)
|
||||
if perUserErr != nil && !errors.Is(perUserErr, schemas.ErrOAuth2NotPerUserSession) {
|
||||
// Real per-user error (not "state not found") — don't fall through to admin flow
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Per-user OAuth flow failed: %v", perUserErr))
|
||||
return
|
||||
}
|
||||
if perUserErr == nil && sessionToken != "" {
|
||||
// Consent flow: session token is a flow proxy ("flow:<flowID>:<mcpClientID>").
|
||||
// Redirect back to the MCPs consent page so the user can continue.
|
||||
if strings.HasPrefix(sessionToken, "flow:") {
|
||||
rest := strings.TrimPrefix(sessionToken, "flow:")
|
||||
flowID := strings.SplitN(rest, ":", 2)[0]
|
||||
mcpsURL := fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID))
|
||||
ctx.Redirect(mcpsURL, fasthttp.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Per-user runtime OAuth flow completed — show success page.
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
ctx.SetContentType("text/html")
|
||||
ctx.SetBodyString(oauthSuccessPage(`
|
||||
if (window.opener) {
|
||||
window.opener.postMessage({ type: 'oauth_success' }, window.location.origin);
|
||||
window.close();
|
||||
}
|
||||
`, "Authorization Successful", "You can close this tab."))
|
||||
return
|
||||
}
|
||||
|
||||
// Fall through to standard OAuth flow (handles both admin test logins for
|
||||
// per_user_oauth setup and regular server-level OAuth).
|
||||
if err := h.oauthProvider.CompleteOAuthFlow(context.Background(), state, code); err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("OAuth flow completion failed: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect to success page (or close popup)
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
ctx.SetContentType("text/html")
|
||||
ctx.SetBodyString(oauthSuccessPage(`
|
||||
if (window.opener) {
|
||||
window.opener.postMessage({ type: 'oauth_success' }, window.location.origin);
|
||||
window.close();
|
||||
}
|
||||
`, "Authorization Successful", "OAuth authorization successful! You can close this window."))
|
||||
}
|
||||
|
||||
// handleCallbackError handles OAuth callback errors
|
||||
func (h *OAuthHandler) handleCallbackError(ctx *fasthttp.RequestCtx, state, errorParam, errorDescription string) {
|
||||
// Update OAuth config status to failed if state is provided
|
||||
if state != "" {
|
||||
oauthConfig, err := h.store.ConfigStore.GetOauthConfigByState(context.Background(), state)
|
||||
if err == nil && oauthConfig != nil {
|
||||
oauthConfig.Status = "failed"
|
||||
h.store.ConfigStore.UpdateOauthConfig(context.Background(), oauthConfig)
|
||||
}
|
||||
}
|
||||
|
||||
// Show error page
|
||||
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||
ctx.SetContentType("text/html")
|
||||
errorMsg := errorParam
|
||||
if errorDescription != "" {
|
||||
errorMsg = fmt.Sprintf("%s: %s", errorParam, errorDescription)
|
||||
}
|
||||
// JSON-encode for safe embedding in JavaScript context (prevents JS injection)
|
||||
jsEscaped, _ := json.Marshal(errorMsg)
|
||||
// HTML-escape for safe embedding in HTML body (prevents HTML injection)
|
||||
htmlEscaped := html.EscapeString(errorMsg)
|
||||
ctx.SetBodyString(oauthErrorPage(string(jsEscaped), htmlEscaped))
|
||||
}
|
||||
|
||||
// getOAuthConfigStatus returns the current status of an OAuth config
|
||||
// GET /api/oauth/config/{id}/status
|
||||
func (h *OAuthHandler) getOAuthConfigStatus(ctx *fasthttp.RequestCtx) {
|
||||
configID := ctx.UserValue("id").(string)
|
||||
|
||||
oauthConfig, err := h.store.ConfigStore.GetOauthConfigByID(context.Background(), configID)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get OAuth config: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if oauthConfig == nil {
|
||||
SendError(ctx, fasthttp.StatusNotFound, "OAuth config not found")
|
||||
return
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": oauthConfig.ID,
|
||||
"status": oauthConfig.Status,
|
||||
"created_at": oauthConfig.CreatedAt,
|
||||
"expires_at": oauthConfig.ExpiresAt,
|
||||
}
|
||||
|
||||
if oauthConfig.Status == "authorized" && oauthConfig.TokenID != nil {
|
||||
response["token_id"] = *oauthConfig.TokenID
|
||||
|
||||
// Get token metadata
|
||||
token, err := h.store.ConfigStore.GetOauthTokenByID(context.Background(), *oauthConfig.TokenID)
|
||||
if err == nil && token != nil {
|
||||
response["token_expires_at"] = token.ExpiresAt
|
||||
response["token_scopes"] = token.Scopes
|
||||
}
|
||||
}
|
||||
|
||||
SendJSON(ctx, response)
|
||||
}
|
||||
|
||||
// revokeOAuthConfig revokes an OAuth configuration and its associated token
|
||||
// DELETE /api/oauth/config/{id}
|
||||
func (h *OAuthHandler) revokeOAuthConfig(ctx *fasthttp.RequestCtx) {
|
||||
configID := ctx.UserValue("id").(string)
|
||||
|
||||
if err := h.oauthProvider.RevokeToken(context.Background(), configID); err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to revoke OAuth token: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
SendJSON(ctx, map[string]interface{}{
|
||||
"message": "OAuth token revoked successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// OAuthInitiationRequest represents the request to initiate an OAuth flow
|
||||
type OAuthInitiationRequest struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
AuthorizeURL string `json:"authorize_url"`
|
||||
TokenURL string `json:"token_url"`
|
||||
RegistrationURL string `json:"registration_url"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
Scopes []string `json:"scopes"`
|
||||
ServerURL string `json:"server_url"` // For OAuth discovery
|
||||
}
|
||||
|
||||
// InitiateOAuthFlow initiates an OAuth flow and returns the authorization URL
|
||||
// This is called internally by the MCP client creation endpoint
|
||||
func (h *OAuthHandler) InitiateOAuthFlow(ctx context.Context, req OAuthInitiationRequest) (*schemas.OAuth2FlowInitiation, error) {
|
||||
var registrationURL *string
|
||||
if req.RegistrationURL != "" {
|
||||
registrationURL = &req.RegistrationURL
|
||||
}
|
||||
|
||||
config := &schemas.OAuth2Config{
|
||||
ClientID: req.ClientID,
|
||||
ClientSecret: req.ClientSecret,
|
||||
AuthorizeURL: req.AuthorizeURL,
|
||||
TokenURL: req.TokenURL,
|
||||
RegistrationURL: registrationURL,
|
||||
RedirectURI: req.RedirectURI,
|
||||
Scopes: req.Scopes,
|
||||
ServerURL: req.ServerURL, // MCP server URL for OAuth discovery
|
||||
}
|
||||
|
||||
return h.oauthProvider.InitiateOAuthFlow(ctx, config)
|
||||
}
|
||||
|
||||
// StorePendingMCPClient stores an MCP client config in the database while waiting for OAuth completion
|
||||
// This supports multi-instance deployments where OAuth callback may hit a different server instance
|
||||
func (h *OAuthHandler) StorePendingMCPClient(oauthConfigID string, mcpClientConfig schemas.MCPClientConfig) error {
|
||||
return h.oauthProvider.StorePendingMCPClient(oauthConfigID, mcpClientConfig)
|
||||
}
|
||||
|
||||
// GetPendingMCPClient retrieves a pending MCP client config by oauth_config_id
|
||||
func (h *OAuthHandler) GetPendingMCPClient(oauthConfigID string) (*schemas.MCPClientConfig, error) {
|
||||
return h.oauthProvider.GetPendingMCPClient(oauthConfigID)
|
||||
}
|
||||
|
||||
// GetPendingMCPClientByState retrieves a pending MCP client config by OAuth state token
|
||||
func (h *OAuthHandler) GetPendingMCPClientByState(state string) (*schemas.MCPClientConfig, string, error) {
|
||||
return h.oauthProvider.GetPendingMCPClientByState(state)
|
||||
}
|
||||
|
||||
// RemovePendingMCPClient removes a pending MCP client after OAuth completion.
|
||||
func (h *OAuthHandler) RemovePendingMCPClient(oauthConfigID string) error {
|
||||
return h.oauthProvider.RemovePendingMCPClient(oauthConfigID)
|
||||
}
|
||||
|
||||
// GetAccessToken retrieves the access token for a given oauth_config_id.
|
||||
// Used during per-user OAuth setup to get the admin's temporary token for verification.
|
||||
func (h *OAuthHandler) GetAccessToken(ctx context.Context, oauthConfigID string) (string, error) {
|
||||
return h.oauthProvider.GetAccessToken(ctx, oauthConfigID)
|
||||
}
|
||||
|
||||
// oauthSuccessPage renders a Bifrost-themed success HTML page.
|
||||
// extraScript is injected verbatim into a <script> tag (caller is responsible for safety).
|
||||
func oauthSuccessPage(extraScript, title, message string) string {
|
||||
return fmt.Sprintf(`<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>%s</title>
|
||||
<style>%s
|
||||
.icon{font-size:2.5rem;margin-bottom:16px}
|
||||
.msg{font-size:0.9rem;color:oklch(0.552 0.016 285.938);margin-top:8px}
|
||||
</style>
|
||||
<script>%s</script>
|
||||
</head>
|
||||
<body>
|
||||
<div class="card" style="text-align:center">
|
||||
<div class="icon">✓</div>
|
||||
<h1>%s</h1>
|
||||
<p class="msg">%s</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`, html.EscapeString(title), bifrostPageCSS, extraScript, html.EscapeString(title), html.EscapeString(message))
|
||||
}
|
||||
|
||||
// oauthErrorPage renders a Bifrost-themed error HTML page.
|
||||
// jsEscapedError must already be JSON-encoded (with quotes) for safe JS embedding.
|
||||
// htmlError must already be HTML-escaped for safe body embedding.
|
||||
func oauthErrorPage(jsEscapedError, htmlError string) string {
|
||||
return fmt.Sprintf(`<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authorization Failed</title>
|
||||
<style>%s
|
||||
.icon{font-size:2.5rem;margin-bottom:16px;color:oklch(0.50 0.18 27)}
|
||||
.err-msg{font-size:0.9rem;color:oklch(0.552 0.016 285.938);margin-top:8px}
|
||||
.hint{font-size:0.8rem;color:oklch(0.65 0.01 286);margin-top:16px}
|
||||
</style>
|
||||
<script>
|
||||
if (window.opener) {
|
||||
window.opener.postMessage({ type: 'oauth_failed', error: %s }, window.location.origin);
|
||||
window.close();
|
||||
}
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div class="card" style="text-align:center">
|
||||
<div class="icon">✗</div>
|
||||
<h1>Authorization Failed</h1>
|
||||
<p class="err-msg">%s</p>
|
||||
<p class="hint">You can close this window.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`, bifrostPageCSS, jsEscapedError, htmlError)
|
||||
}
|
||||
|
||||
// jsEscapeString returns a JSON-encoded string (with quotes) safe for embedding in JavaScript.
|
||||
func jsEscapeString(s string) string {
|
||||
b, _ := json.Marshal(s)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// RevokeToken revokes the OAuth token for a given oauth_config_id.
|
||||
// Used during per-user OAuth setup to discard the admin's temporary token after verification.
|
||||
func (h *OAuthHandler) RevokeToken(ctx context.Context, oauthConfigID string) error {
|
||||
return h.oauthProvider.RevokeToken(ctx, oauthConfigID)
|
||||
}
|
||||
643
transports/bifrost-http/handlers/oauth2_consent.go
Normal file
643
transports/bifrost-http/handlers/oauth2_consent.go
Normal file
@@ -0,0 +1,643 @@
|
||||
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
|
||||
// This file implements the per-user OAuth consent flow — the intermediate screens
|
||||
// shown between the MCP client's authorize request and the final authorization code
|
||||
// issuance. The flow is:
|
||||
//
|
||||
// 1. GET /oauth/consent?flow_id=xxx → VK input page (HTML)
|
||||
// 2. POST /api/oauth/per-user/consent/vk → validate VK, update PendingFlow, redirect
|
||||
// 3. GET /oauth/consent/mcps?flow_id=xxx → MCPs page (HTML, server-rendered)
|
||||
// 4. POST /api/oauth/per-user/consent/submit → create session + code, redirect to client
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"html"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/google/uuid"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// ConsentHandler manages the per-user OAuth consent flow screens.
|
||||
type ConsentHandler struct {
|
||||
store *lib.Config
|
||||
}
|
||||
|
||||
// NewConsentHandler creates a new consent handler instance.
|
||||
func NewConsentHandler(store *lib.Config) *ConsentHandler {
|
||||
return &ConsentHandler{store: store}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the consent flow routes.
|
||||
// All routes are public — no auth middleware — since they are part of the OAuth
|
||||
// flow for unauthenticated users acquiring credentials.
|
||||
func (h *ConsentHandler) RegisterRoutes(r *router.Router) {
|
||||
// HTML pages (GET, served by Go)
|
||||
r.GET("/oauth/consent", h.handleIdentityPage)
|
||||
r.GET("/oauth/consent/mcps", h.handleMCPsPage)
|
||||
|
||||
// API actions (POST)
|
||||
// NOTE: All state-mutating endpoints use POST. CSRF protection relies on the
|
||||
// SameSite=Lax browser-binding cookie (__bifrost_flow_secret) combined with
|
||||
// the flow_id — SameSite=Lax blocks cross-site POST, and the cookie is
|
||||
// HttpOnly+Secure. This is sufficient for the threat model here.
|
||||
r.POST("/api/oauth/per-user/consent/vk", h.handleSubmitVK)
|
||||
r.POST("/api/oauth/per-user/consent/user-id", h.handleSubmitUserID)
|
||||
r.POST("/api/oauth/per-user/consent/skip", h.handleSkip)
|
||||
r.POST("/api/oauth/per-user/consent/submit", h.handleSubmit)
|
||||
}
|
||||
|
||||
// ---------- HTML pages ----------
|
||||
|
||||
// handleIdentityPage renders the identity selection page with three options:
|
||||
// User ID, Virtual Key, or skip (lazy auth when tools are called).
|
||||
// GET /oauth/consent?flow_id=xxx[&error=xxx]
|
||||
func (h *ConsentHandler) handleIdentityPage(ctx *fasthttp.RequestCtx) {
|
||||
flowID := string(ctx.QueryArgs().Peek("flow_id"))
|
||||
errorMsg := string(ctx.QueryArgs().Peek("error"))
|
||||
|
||||
if flowID == "" {
|
||||
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||
ctx.SetBodyString("Missing flow_id")
|
||||
return
|
||||
}
|
||||
|
||||
if h.store.ConfigStore == nil {
|
||||
ctx.SetStatusCode(fasthttp.StatusServiceUnavailable)
|
||||
ctx.SetBodyString("Config store unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
|
||||
if err != nil {
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
ctx.SetBodyString("Failed to load consent flow.")
|
||||
return
|
||||
}
|
||||
if flow == nil || time.Now().After(flow.ExpiresAt) {
|
||||
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||
ctx.SetBodyString("Invalid or expired consent flow. Please restart the authentication process.")
|
||||
return
|
||||
}
|
||||
if !validateFlowBrowserSecret(ctx, flow) {
|
||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||
ctx.SetBodyString("Flow does not belong to this browser session. Please restart the authentication process.")
|
||||
return
|
||||
}
|
||||
|
||||
h.store.Mu.RLock()
|
||||
enforceVK := h.store.ClientConfig.EnforceAuthOnInference
|
||||
h.store.Mu.RUnlock()
|
||||
|
||||
safeFlowID := html.EscapeString(flowID)
|
||||
safeError := html.EscapeString(errorMsg)
|
||||
|
||||
errorBanner := ""
|
||||
if safeError != "" {
|
||||
errorBanner = fmt.Sprintf(`<div class="error-banner">%s</div>`, safeError)
|
||||
}
|
||||
|
||||
skipOption := ""
|
||||
if !enforceVK {
|
||||
skipOption = fmt.Sprintf(`
|
||||
<div class="option">
|
||||
<span class="option-title">Skip for now</span>
|
||||
<span class="option-desc">Connect to services when a tool is called</span>
|
||||
<form action="/api/oauth/per-user/consent/skip" method="POST" style="margin-top:10px">
|
||||
<input type="hidden" name="flow_id" value="%s">
|
||||
<button type="submit" class="btn btn-ghost">Skip</button>
|
||||
</form>
|
||||
</div>`, safeFlowID)
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
ctx.SetContentType("text/html; charset=utf-8")
|
||||
ctx.SetBodyString(fmt.Sprintf(`<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Connect to Bifrost</title>
|
||||
<style>
|
||||
%s
|
||||
.option{border:1px solid oklch(0.92 0.004 286.32);border-radius:0.5rem;padding:16px 18px;margin-bottom:10px}
|
||||
.option-title{display:block;font-size:0.9rem;font-weight:600;color:oklch(0.141 0.005 285.823);margin-bottom:2px}
|
||||
.option-desc{display:block;font-size:0.8rem;color:oklch(0.552 0.016 285.938);margin-bottom:12px}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="card">
|
||||
<h1>Connect to Bifrost</h1>
|
||||
<p class="subtitle">Choose how to identify yourself for this session.</p>
|
||||
<p style="font-size:0.75rem;color:oklch(0.65 0.01 286);margin-bottom:18px">This setup page expires in 15 minutes.</p>
|
||||
%s
|
||||
<div class="option">
|
||||
<span class="option-title">User ID</span>
|
||||
<span class="option-desc">Use a stable identifier — access all available services</span>
|
||||
<form action="/api/oauth/per-user/consent/user-id" method="POST">
|
||||
<input type="hidden" name="flow_id" value="%s">
|
||||
<label for="user_id">User ID</label>
|
||||
<input type="text" id="user_id" name="user_id" placeholder="e.g. alice" autocomplete="off" spellcheck="false" autocapitalize="none" autocorrect="off">
|
||||
<button type="submit" class="btn btn-primary">Continue with User ID</button>
|
||||
</form>
|
||||
</div>
|
||||
<div class="option">
|
||||
<span class="option-title">Virtual Key</span>
|
||||
<span class="option-desc">Use a VK — access services within your key's limits</span>
|
||||
<form action="/api/oauth/per-user/consent/vk" method="POST">
|
||||
<input type="hidden" name="flow_id" value="%s">
|
||||
<label for="vk">Virtual Key</label>
|
||||
<input type="password" id="vk" name="vk" placeholder="sk-bf-..." autocomplete="off" spellcheck="false" autocapitalize="none">
|
||||
<button type="submit" class="btn btn-primary">Continue with Virtual Key</button>
|
||||
</form>
|
||||
</div>
|
||||
%s
|
||||
</div>
|
||||
</body>
|
||||
</html>`, bifrostPageCSS, errorBanner, safeFlowID, safeFlowID, skipOption))
|
||||
}
|
||||
|
||||
// handleMCPsPage renders the MCP authentication list page.
|
||||
// GET /oauth/consent/mcps?flow_id=xxx
|
||||
func (h *ConsentHandler) handleMCPsPage(ctx *fasthttp.RequestCtx) {
|
||||
flowID := string(ctx.QueryArgs().Peek("flow_id"))
|
||||
|
||||
if flowID == "" {
|
||||
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||
ctx.SetBodyString("Missing flow_id")
|
||||
return
|
||||
}
|
||||
|
||||
if h.store.ConfigStore == nil {
|
||||
ctx.SetStatusCode(fasthttp.StatusServiceUnavailable)
|
||||
ctx.SetBodyString("Config store unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
|
||||
if err != nil {
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
ctx.SetBodyString("Failed to load consent flow.")
|
||||
return
|
||||
}
|
||||
if flow == nil || time.Now().After(flow.ExpiresAt) {
|
||||
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||
ctx.SetBodyString("Invalid or expired consent flow. Please restart the authentication process.")
|
||||
return
|
||||
}
|
||||
if !validateFlowBrowserSecret(ctx, flow) {
|
||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||
ctx.SetBodyString("Flow does not belong to this browser session. Please restart the authentication process.")
|
||||
return
|
||||
}
|
||||
|
||||
// Find which MCP clients the user has already authed.
|
||||
// Check both: tokens stored under the flow proxy (connected during this flow)
|
||||
// and tokens already stored under the VK/user identity (connected in a prior flow).
|
||||
completedTokens, err := h.store.ConfigStore.GetOauthUserTokensByGatewaySessionID(ctx, flowID)
|
||||
if err != nil {
|
||||
completedTokens = nil // non-fatal; just show no checkmarks
|
||||
}
|
||||
completedMCPs := make(map[string]bool, len(completedTokens))
|
||||
for _, t := range completedTokens {
|
||||
completedMCPs[t.MCPClientID] = true
|
||||
}
|
||||
|
||||
// Per_user_oauth MCP clients visible to this identity — sorted for deterministic rendering.
|
||||
// When a VK is set on the flow, only show clients that VK is allowed to use.
|
||||
perUserClients := h.store.GetPerUserOAuthMCPClientsForVirtualKey(ctx, strVal(flow.VirtualKeyID))
|
||||
clientIDs := make([]string, 0, len(perUserClients))
|
||||
for id := range perUserClients {
|
||||
clientIDs = append(clientIDs, id)
|
||||
}
|
||||
sort.Strings(clientIDs)
|
||||
|
||||
safeFlowID := html.EscapeString(flowID)
|
||||
|
||||
// Determine if user skipped identity selection.
|
||||
isSkipped := strVal(flow.VirtualKeyID) == "" && strVal(flow.UserID) == ""
|
||||
|
||||
// Build MCP rows — only show connect buttons if user has an identity.
|
||||
var mcpRows strings.Builder
|
||||
if isSkipped {
|
||||
mcpRows.WriteString(`<p style="color:#6b7280;font-size:14px;">You skipped identity selection. Services will be connected when you first use their tools. Since no identity is attached, your connections will only persist as long as the service keeps the OAuth token active — they will not be remembered across sessions.</p>`)
|
||||
} else {
|
||||
for _, clientID := range clientIDs {
|
||||
clientName := perUserClients[clientID]
|
||||
safeName := html.EscapeString(clientName)
|
||||
|
||||
// Also check if a token already exists under the user's identity (e.g. from a prior LLM gateway auth).
|
||||
alreadyConnected := completedMCPs[clientID]
|
||||
if !alreadyConnected && (strVal(flow.VirtualKeyID) != "" || strVal(flow.UserID) != "") {
|
||||
existing, tokenErr := h.store.ConfigStore.GetOauthUserTokenByIdentity(ctx, strVal(flow.VirtualKeyID), strVal(flow.UserID), "", clientID)
|
||||
if tokenErr != nil {
|
||||
logger.Warn("[consent/mcps] failed to check existing token: mcp_client_id=%s err=%v", clientID, tokenErr)
|
||||
}
|
||||
alreadyConnected = existing != nil
|
||||
}
|
||||
|
||||
if alreadyConnected {
|
||||
mcpRows.WriteString(fmt.Sprintf(`
|
||||
<div class="mcp-row">
|
||||
<div class="mcp-name">%s</div>
|
||||
<span class="badge connected">✓ Connected</span>
|
||||
</div>`, safeName))
|
||||
} else {
|
||||
connectURL := fmt.Sprintf("/api/oauth/per-user/upstream/authorize?mcp_client_id=%s&flow_id=%s",
|
||||
url.QueryEscape(clientID), url.QueryEscape(flowID))
|
||||
mcpRows.WriteString(fmt.Sprintf(`
|
||||
<div class="mcp-row">
|
||||
<div class="mcp-name">%s</div>
|
||||
<a class="badge connect" href="%s">Connect</a>
|
||||
</div>`, safeName, html.EscapeString(connectURL)))
|
||||
}
|
||||
}
|
||||
if len(perUserClients) == 0 {
|
||||
mcpRows.WriteString(`<p style="color:#6b7280;font-size:14px;">No MCP services require authentication.</p>`)
|
||||
}
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
ctx.SetContentType("text/html; charset=utf-8")
|
||||
ctx.SetBodyString(fmt.Sprintf(`<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Connect Your Apps — Bifrost</title>
|
||||
<style>
|
||||
%s
|
||||
.mcp-row{display:flex;align-items:center;justify-content:space-between;padding:12px 0;border-bottom:1px solid oklch(0.92 0.004 286.32)}
|
||||
.mcp-row:last-of-type{border-bottom:none}
|
||||
.mcp-name{font-size:0.9rem;font-weight:500;color:oklch(0.141 0.005 285.823)}
|
||||
.badge{font-size:0.8rem;font-weight:500;padding:4px 12px;border-radius:20px;text-decoration:none;display:inline-block}
|
||||
.badge.connected{background:oklch(0.95 0.05 160);color:oklch(0.35 0.08 160)}
|
||||
.badge.connect{background:oklch(0.5081 0.1049 165.61);color:oklch(0.985 0 0);cursor:pointer;
|
||||
padding:8px 18px;border-radius:0.5rem;font-weight:500;
|
||||
transition:background .15s}
|
||||
.badge.connect:hover{background:oklch(0.43 0.1049 165.61)}
|
||||
.mcp-list{margin-bottom:4px}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="card">
|
||||
<h1>Connect Your Apps</h1>
|
||||
<p class="subtitle">Authenticate with the services below to enable their tools.</p>
|
||||
<p style="font-size:0.75rem;color:oklch(0.65 0.01 286);margin-bottom:18px">This setup page expires in 15 minutes.</p>
|
||||
<div class="mcp-list">%s</div>
|
||||
<form action="/api/oauth/per-user/consent/submit" method="POST" style="margin-top:24px">
|
||||
<input type="hidden" name="flow_id" value="%s">
|
||||
<button type="submit" class="btn btn-primary">Finish Setup</button>
|
||||
</form>
|
||||
<div style="text-align:center;margin-top:12px">
|
||||
<a href="/oauth/consent?flow_id=%s" style="font-size:0.8rem;color:oklch(0.552 0.016 285.938);text-decoration:none">Change identity</a>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>`, bifrostPageCSS, mcpRows.String(), safeFlowID, safeFlowID))
|
||||
}
|
||||
|
||||
// ---------- API action handlers ----------
|
||||
|
||||
// handleSubmitVK validates the submitted Virtual Key, links it to the pending flow,
|
||||
// and redirects to the MCPs page.
|
||||
// POST /api/oauth/per-user/consent/vk (form: flow_id, vk)
|
||||
func (h *ConsentHandler) handleSubmitVK(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
flowID := string(ctx.FormValue("flow_id"))
|
||||
vkValue := strings.TrimSpace(string(ctx.FormValue("vk")))
|
||||
|
||||
if flowID == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow")
|
||||
return
|
||||
}
|
||||
if flow == nil || time.Now().After(flow.ExpiresAt) {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Invalid or expired consent flow")
|
||||
return
|
||||
}
|
||||
if !validateFlowBrowserSecret(ctx, flow) {
|
||||
SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session")
|
||||
return
|
||||
}
|
||||
|
||||
if vkValue == "" {
|
||||
redirectToIdentityPage(ctx, flowID, "Please enter a Virtual Key.")
|
||||
return
|
||||
}
|
||||
|
||||
vk, err := h.store.ConfigStore.GetVirtualKeyByValue(ctx, vkValue)
|
||||
if err != nil {
|
||||
redirectToIdentityPage(ctx, flowID, "Failed to validate Virtual Key. Please try again.")
|
||||
return
|
||||
}
|
||||
if vk == nil || !vk.IsActive {
|
||||
redirectToIdentityPage(ctx, flowID, "Virtual Key not found or inactive. Please check and try again.")
|
||||
return
|
||||
}
|
||||
|
||||
flow.VirtualKeyID = &vk.ID
|
||||
flow.UserID = nil // Clear other identity to keep selection exclusive
|
||||
if err := h.store.ConfigStore.UpdatePerUserOAuthPendingFlow(ctx, flow); err != nil {
|
||||
redirectToIdentityPage(ctx, flowID, "Failed to save Virtual Key. Please try again.")
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Redirect(fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID)), fasthttp.StatusFound)
|
||||
}
|
||||
|
||||
// handleSubmitUserID links a user-supplied User ID to the pending flow and proceeds to MCPs page.
|
||||
// SECURITY: The User ID is self-declared (typed in a form) with no server-side verification.
|
||||
// This matches the trust model of X-Bf-User-Id in the LLM gateway path. Deployments requiring
|
||||
// verified identity should use Virtual Keys or an auth layer in front of Bifrost.
|
||||
// POST /api/oauth/per-user/consent/user-id (form: flow_id, user_id)
|
||||
func (h *ConsentHandler) handleSubmitUserID(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
flowID := string(ctx.FormValue("flow_id"))
|
||||
userID := strings.TrimSpace(string(ctx.FormValue("user_id")))
|
||||
|
||||
if flowID == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow")
|
||||
return
|
||||
}
|
||||
if flow == nil || time.Now().After(flow.ExpiresAt) {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Invalid or expired consent flow")
|
||||
return
|
||||
}
|
||||
if !validateFlowBrowserSecret(ctx, flow) {
|
||||
SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session")
|
||||
return
|
||||
}
|
||||
|
||||
if userID == "" {
|
||||
redirectToIdentityPage(ctx, flowID, "Please enter a User ID.")
|
||||
return
|
||||
}
|
||||
if len(userID) > 255 {
|
||||
redirectToIdentityPage(ctx, flowID, "User ID is too long (max 255 characters).")
|
||||
return
|
||||
}
|
||||
|
||||
if userID != "" {
|
||||
flow.UserID = &userID
|
||||
}
|
||||
flow.VirtualKeyID = nil // Clear other identity to keep selection exclusive
|
||||
if err := h.store.ConfigStore.UpdatePerUserOAuthPendingFlow(ctx, flow); err != nil {
|
||||
redirectToIdentityPage(ctx, flowID, "Failed to save User ID. Please try again.")
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Redirect(fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID)), fasthttp.StatusFound)
|
||||
}
|
||||
|
||||
// handleSkip skips identity selection and proceeds directly to the MCPs page.
|
||||
// Upstream services will be connected lazily when tools are first called.
|
||||
// POST /api/oauth/per-user/consent/skip (form: flow_id)
|
||||
func (h *ConsentHandler) handleSkip(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
flowID := string(ctx.FormValue("flow_id"))
|
||||
if flowID == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow")
|
||||
return
|
||||
}
|
||||
if flow == nil || time.Now().After(flow.ExpiresAt) {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Invalid or expired consent flow")
|
||||
return
|
||||
}
|
||||
if !validateFlowBrowserSecret(ctx, flow) {
|
||||
SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session")
|
||||
return
|
||||
}
|
||||
|
||||
h.store.Mu.RLock()
|
||||
enforceVK := h.store.ClientConfig.EnforceAuthOnInference
|
||||
h.store.Mu.RUnlock()
|
||||
|
||||
if enforceVK {
|
||||
redirectToIdentityPage(ctx, flowID, "An identity (Virtual Key or User ID) is required. Please choose one to continue.")
|
||||
return
|
||||
}
|
||||
|
||||
// Clear any previously selected identity so skip truly resets the flow.
|
||||
if strVal(flow.VirtualKeyID) != "" || strVal(flow.UserID) != "" {
|
||||
flow.VirtualKeyID = nil
|
||||
flow.UserID = nil
|
||||
if err := h.store.ConfigStore.UpdatePerUserOAuthPendingFlow(ctx, flow); err != nil {
|
||||
redirectToIdentityPage(ctx, flowID, "Failed to clear identity. Please try again.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Skip goes straight to MCPs page; no identity means only lazy auth is available.
|
||||
ctx.Redirect(fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID)), fasthttp.StatusFound)
|
||||
}
|
||||
|
||||
// handleSubmit finalises the consent flow:
|
||||
// 1. Creates a real Bifrost session (TablePerUserOAuthSession)
|
||||
// 2. Migrates upstream tokens from the flow proxy to the real session
|
||||
// 3. Issues a TablePerUserOAuthCode
|
||||
// 4. Deletes the PendingFlow
|
||||
// 5. Redirects to the original MCP client callback URL with code + state
|
||||
//
|
||||
// POST /api/oauth/per-user/consent/submit (form: flow_id)
|
||||
func (h *ConsentHandler) handleSubmit(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
flowID := string(ctx.FormValue("flow_id"))
|
||||
if flowID == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required")
|
||||
return
|
||||
}
|
||||
flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow")
|
||||
return
|
||||
}
|
||||
if flow == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Invalid consent flow")
|
||||
return
|
||||
}
|
||||
if time.Now().After(flow.ExpiresAt) {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Consent flow has expired. Please restart the authentication process.")
|
||||
return
|
||||
}
|
||||
if !validateFlowBrowserSecret(ctx, flow) {
|
||||
SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session")
|
||||
return
|
||||
}
|
||||
|
||||
// Server-side enforcement: reject if identity is required but not provided.
|
||||
h.store.Mu.RLock()
|
||||
enforceAuth := h.store.ClientConfig.EnforceAuthOnInference
|
||||
h.store.Mu.RUnlock()
|
||||
if enforceAuth && strVal(flow.VirtualKeyID) == "" && strVal(flow.UserID) == "" {
|
||||
redirectToIdentityPage(ctx, flowID, "An identity (Virtual Key or User ID) is required. Please choose one to continue.")
|
||||
return
|
||||
}
|
||||
|
||||
// 1. Generate session credentials.
|
||||
accessToken, err := generateOpaqueToken(32)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate session token")
|
||||
return
|
||||
}
|
||||
refreshToken, err := generateOpaqueToken(32)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate refresh token")
|
||||
return
|
||||
}
|
||||
|
||||
session := &tables.TablePerUserOAuthSession{
|
||||
ID: uuid.New().String(),
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ClientID: flow.ClientID,
|
||||
VirtualKeyID: flow.VirtualKeyID,
|
||||
UserID: flow.UserID,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
// 2. Generate authorization code.
|
||||
code, err := generateOpaqueToken(32)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate authorization code")
|
||||
return
|
||||
}
|
||||
codeRecord := &tables.TablePerUserOAuthCode{
|
||||
ID: uuid.New().String(),
|
||||
Code: code,
|
||||
ClientID: flow.ClientID,
|
||||
RedirectURI: flow.RedirectURI,
|
||||
CodeChallenge: flow.CodeChallenge,
|
||||
SessionID: session.ID, // Links token endpoint to this session so it can return the same access token
|
||||
// Scopes intentionally omitted: the consent flow has no scope selection step.
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||
}
|
||||
|
||||
// 3. Atomically consume the pending flow, create session, and create auth code.
|
||||
// If another concurrent request already consumed the flow, rowsAffected will be 0.
|
||||
rowsAffected, err := h.store.ConfigStore.FinalizePerUserOAuthConsent(ctx, flowID, session, codeRecord)
|
||||
if err != nil {
|
||||
if errors.Is(err, schemas.ErrPerUserOAuthPendingFlowExpired) {
|
||||
SendError(ctx, fasthttp.StatusGone, "Consent flow has expired. Please restart the authentication process.")
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to finalize consent flow")
|
||||
return
|
||||
}
|
||||
if rowsAffected == 0 {
|
||||
SendError(ctx, fasthttp.StatusConflict, "Consent flow has already been submitted")
|
||||
return
|
||||
}
|
||||
logger.Debug("[consent/submit] session created: session_id=%s flow_id=%s", session.ID, flowID)
|
||||
|
||||
// 4. Migrate upstream tokens from flow proxy sessions to real session (non-fatal).
|
||||
if err := h.store.ConfigStore.TransferOauthUserTokensFromGatewaySession(ctx, flowID, accessToken, strVal(flow.VirtualKeyID), strVal(flow.UserID)); err != nil {
|
||||
// Non-fatal: tokens can be re-acquired on first tool use.
|
||||
logger.Warn("[consent/submit] failed to transfer upstream tokens: flow_id=%s err=%v", flowID, err)
|
||||
}
|
||||
|
||||
// 5. Redirect to MCP client callback with code + original state.
|
||||
redirectURL, err := url.Parse(flow.RedirectURI)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Invalid redirect URI in pending flow")
|
||||
return
|
||||
}
|
||||
q := redirectURL.Query()
|
||||
q.Set("code", code)
|
||||
if flow.State != "" {
|
||||
q.Set("state", flow.State)
|
||||
}
|
||||
redirectURL.RawQuery = q.Encode()
|
||||
|
||||
ctx.Redirect(redirectURL.String(), fasthttp.StatusFound)
|
||||
}
|
||||
|
||||
// ---------- helpers ----------
|
||||
|
||||
// bifrostPageCSS is the shared inline CSS for all Go-rendered consent/callback pages.
|
||||
// It mirrors Bifrost's UI design tokens: teal primary, zinc palette, Geist font stack.
|
||||
const bifrostPageCSS = `
|
||||
*,*::before,*::after{box-sizing:border-box;margin:0;padding:0}
|
||||
body{font-family:"Geist",system-ui,-apple-system,sans-serif;font-size:0.95rem;
|
||||
line-height:1.5;background:#f4f4f5;color:oklch(0.141 0.005 285.823);
|
||||
display:flex;align-items:center;justify-content:center;min-height:100vh;
|
||||
-webkit-font-smoothing:antialiased}
|
||||
.card{background:#fff;border:1px solid oklch(0.92 0.004 286.32);border-radius:12px;
|
||||
padding:40px;width:100%;max-width:480px}
|
||||
h1{font-size:1.25rem;font-weight:600;color:oklch(0.141 0.005 285.823);margin-bottom:6px}
|
||||
.subtitle{font-size:0.825rem;color:oklch(0.552 0.016 285.938);line-height:1.5;margin-bottom:24px}
|
||||
label{display:block;font-size:0.825rem;font-weight:500;color:oklch(0.141 0.005 285.823);margin-bottom:5px}
|
||||
input[type=text],input[type=password]{width:100%;padding:8px 12px;border:1px solid oklch(0.92 0.004 286.32);
|
||||
border-radius:0.5rem;font-size:0.875rem;outline:none;
|
||||
transition:border-color .15s,box-shadow .15s;margin-bottom:10px;
|
||||
background:#fff;color:oklch(0.141 0.005 285.823)}
|
||||
input[type=text]:focus,input[type=password]:focus{border-color:oklch(0.5081 0.1049 165.61);
|
||||
box-shadow:0 0 0 3px oklch(0.5081 0.1049 165.61 / 0.15)}
|
||||
.btn{display:block;width:100%;padding:9px 16px;border-radius:0.5rem;font-size:0.875rem;
|
||||
font-weight:500;cursor:pointer;border:none;text-align:center;text-decoration:none;
|
||||
transition:background .15s;font-family:inherit}
|
||||
.btn-primary{background:oklch(0.5081 0.1049 165.61);color:oklch(0.985 0 0)}
|
||||
.btn-primary:hover{background:oklch(0.43 0.1049 165.61)}
|
||||
.btn-ghost{background:transparent;border:1px solid oklch(0.92 0.004 286.32);
|
||||
color:oklch(0.552 0.016 285.938);display:inline-block;width:auto;padding:8px 16px}
|
||||
.btn-ghost:hover{background:#f4f4f5}
|
||||
.error-banner{background:oklch(0.97 0.02 27);border:1px solid oklch(0.88 0.06 27);
|
||||
border-radius:0.5rem;padding:12px 14px;margin-bottom:18px;
|
||||
color:oklch(0.50 0.18 27);font-size:0.825rem}
|
||||
`
|
||||
|
||||
// redirectToIdentityPage redirects to the identity selection page with an error message.
|
||||
func redirectToIdentityPage(ctx *fasthttp.RequestCtx, flowID, errorMsg string) {
|
||||
u := fmt.Sprintf("/oauth/consent?flow_id=%s&error=%s",
|
||||
url.QueryEscape(flowID), url.QueryEscape(errorMsg))
|
||||
ctx.Redirect(u, fasthttp.StatusFound)
|
||||
}
|
||||
|
||||
// strVal safely dereferences a *string, returning "" for nil.
|
||||
func strVal(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return *s
|
||||
}
|
||||
93
transports/bifrost-http/handlers/oauth2_metadata.go
Normal file
93
transports/bifrost-http/handlers/oauth2_metadata.go
Normal file
@@ -0,0 +1,93 @@
|
||||
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
|
||||
// This file implements OAuth 2.0 metadata discovery endpoints per RFC 9728
|
||||
// (Protected Resource Metadata) and RFC 8414 (Authorization Server Metadata).
|
||||
// These endpoints enable MCP-spec-compliant clients (like Claude Code) to
|
||||
// automatically discover Bifrost's OAuth configuration and authenticate.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// OAuthMetadataHandler serves OAuth 2.0 discovery metadata endpoints.
|
||||
// It provides the Protected Resource Metadata (RFC 9728) and Authorization
|
||||
// Server Metadata (RFC 8414) that MCP clients use to discover how to
|
||||
// authenticate with Bifrost's MCP server endpoint.
|
||||
type OAuthMetadataHandler struct {
|
||||
store *lib.Config
|
||||
}
|
||||
|
||||
// NewOAuthMetadataHandler creates a new OAuth metadata handler instance.
|
||||
func NewOAuthMetadataHandler(store *lib.Config) *OAuthMetadataHandler {
|
||||
return &OAuthMetadataHandler{store: store}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the well-known metadata discovery routes.
|
||||
// These routes do NOT go through auth middleware since they must be
|
||||
// accessible to unauthenticated clients during OAuth discovery.
|
||||
func (h *OAuthMetadataHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
// RFC 9728: Protected Resource Metadata
|
||||
r.GET("/.well-known/oauth-protected-resource", lib.ChainMiddlewares(h.handleProtectedResourceMetadata, middlewares...))
|
||||
// RFC 8414: Authorization Server Metadata
|
||||
r.GET("/.well-known/oauth-authorization-server", lib.ChainMiddlewares(h.handleAuthorizationServerMetadata, middlewares...))
|
||||
}
|
||||
|
||||
// handleProtectedResourceMetadata serves the Protected Resource Metadata
|
||||
// document per RFC 9728. MCP clients fetch this after receiving a 401 response
|
||||
// to discover which authorization server(s) protect the MCP resource.
|
||||
//
|
||||
// GET /.well-known/oauth-protected-resource
|
||||
func (h *OAuthMetadataHandler) handleProtectedResourceMetadata(ctx *fasthttp.RequestCtx) {
|
||||
if clients := h.store.GetPerUserOAuthMCPClients(); len(clients) == 0 {
|
||||
sendStringError(ctx, fasthttp.StatusNotFound, "Not Found")
|
||||
return
|
||||
}
|
||||
scheme := "http"
|
||||
if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
host := string(ctx.Host())
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
|
||||
SendJSON(ctx, map[string]interface{}{
|
||||
"resource": baseURL + "/mcp",
|
||||
"authorization_servers": []string{baseURL},
|
||||
"scopes_supported": []string{"mcp:read", "mcp:write"},
|
||||
"bearer_methods_supported": []string{"header"},
|
||||
})
|
||||
}
|
||||
|
||||
// handleAuthorizationServerMetadata serves the Authorization Server Metadata
|
||||
// document per RFC 8414. MCP clients use this to discover Bifrost's OAuth
|
||||
// endpoints (authorize, token, register) and supported capabilities.
|
||||
//
|
||||
// GET /.well-known/oauth-authorization-server
|
||||
func (h *OAuthMetadataHandler) handleAuthorizationServerMetadata(ctx *fasthttp.RequestCtx) {
|
||||
if clients := h.store.GetPerUserOAuthMCPClients(); len(clients) == 0 {
|
||||
sendStringError(ctx, fasthttp.StatusNotFound, "Not Found")
|
||||
return
|
||||
}
|
||||
scheme := "http"
|
||||
if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
host := string(ctx.Host())
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
|
||||
SendJSON(ctx, map[string]interface{}{
|
||||
"issuer": baseURL,
|
||||
"authorization_endpoint": baseURL + "/api/oauth/per-user/authorize",
|
||||
"token_endpoint": baseURL + "/api/oauth/per-user/token",
|
||||
"registration_endpoint": baseURL + "/api/oauth/per-user/register",
|
||||
"response_types_supported": []string{"code"},
|
||||
"grant_types_supported": []string{"authorization_code"},
|
||||
"code_challenge_methods_supported": []string{"S256"},
|
||||
"token_endpoint_auth_methods_supported": []string{"none"},
|
||||
"scopes_supported": []string{"mcp:read", "mcp:write"},
|
||||
})
|
||||
}
|
||||
577
transports/bifrost-http/handlers/oauth2_per_user.go
Normal file
577
transports/bifrost-http/handlers/oauth2_per_user.go
Normal file
@@ -0,0 +1,577 @@
|
||||
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
|
||||
// This file implements Bifrost's OAuth 2.1 Authorization Server for per-user MCP
|
||||
// authentication. It provides Dynamic Client Registration (RFC 7591), Authorization
|
||||
// Code flow with PKCE, and token issuance. MCP clients (Claude Code, IDEs) use
|
||||
// these endpoints to authenticate users before accessing Bifrost's /mcp endpoint.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/google/uuid"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// PerUserOAuthHandler implements Bifrost's OAuth 2.1 Authorization Server.
|
||||
// It handles dynamic client registration, authorization code issuance with PKCE,
|
||||
// and token exchange for MCP per-user authentication.
|
||||
type PerUserOAuthHandler struct {
|
||||
store *lib.Config
|
||||
}
|
||||
|
||||
// NewPerUserOAuthHandler creates a new per-user OAuth handler instance.
|
||||
func NewPerUserOAuthHandler(store *lib.Config) *PerUserOAuthHandler {
|
||||
return &PerUserOAuthHandler{store: store}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the per-user OAuth authorization server routes.
|
||||
// These routes do NOT go through auth middleware since they are part of the
|
||||
// OAuth flow that unauthenticated clients use to obtain tokens.
|
||||
func (h *PerUserOAuthHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
r.POST("/api/oauth/per-user/register", lib.ChainMiddlewares(h.handleDynamicClientRegistration, middlewares...))
|
||||
r.GET("/api/oauth/per-user/authorize", lib.ChainMiddlewares(h.handleAuthorize, middlewares...))
|
||||
r.POST("/api/oauth/per-user/token", lib.ChainMiddlewares(h.handleToken, middlewares...))
|
||||
r.GET("/api/oauth/per-user/upstream/authorize", lib.ChainMiddlewares(h.handleUpstreamAuthorize, middlewares...))
|
||||
}
|
||||
|
||||
// handleDynamicClientRegistration handles OAuth 2.0 Dynamic Client Registration
|
||||
// per RFC 7591. MCP clients register themselves to obtain a client_id.
|
||||
//
|
||||
// POST /api/oauth/per-user/register
|
||||
func (h *PerUserOAuthHandler) handleDynamicClientRegistration(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "OAuth registration unavailable: config store is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
if len(h.store.GetPerUserOAuthMCPClients()) == 0 {
|
||||
sendStringError(ctx, fasthttp.StatusNotFound, "Not found")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
ClientName string `json:"client_name"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
ResponseTypes []string `json:"response_types"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(ctx.PostBody(), &req); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid registration request: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.RedirectURIs) == 0 {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "redirect_uris is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Generate client_id
|
||||
clientID := uuid.New().String()
|
||||
|
||||
// Serialize arrays
|
||||
redirectURIsJSON, _ := json.Marshal(req.RedirectURIs)
|
||||
grantTypes := req.GrantTypes
|
||||
if len(grantTypes) == 0 {
|
||||
grantTypes = []string{"authorization_code"}
|
||||
}
|
||||
grantTypesJSON, _ := json.Marshal(grantTypes)
|
||||
|
||||
client := &tables.TablePerUserOAuthClient{
|
||||
ID: uuid.New().String(),
|
||||
ClientID: clientID,
|
||||
ClientName: req.ClientName,
|
||||
RedirectURIs: string(redirectURIsJSON),
|
||||
GrantTypes: string(grantTypesJSON),
|
||||
}
|
||||
|
||||
if err := h.store.ConfigStore.CreatePerUserOAuthClient(ctx, client); err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to register client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Return RFC 7591 response
|
||||
ctx.SetStatusCode(fasthttp.StatusCreated)
|
||||
SendJSON(ctx, map[string]interface{}{
|
||||
"client_id": clientID,
|
||||
"client_name": req.ClientName,
|
||||
"redirect_uris": req.RedirectURIs,
|
||||
"grant_types": grantTypes,
|
||||
"response_types": req.ResponseTypes,
|
||||
"token_endpoint_auth_method": "none",
|
||||
})
|
||||
}
|
||||
|
||||
// handleAuthorize handles the OAuth 2.1 authorization endpoint.
|
||||
// Instead of issuing a code immediately, it validates the request parameters,
|
||||
// creates a PendingFlow record, and redirects the user to the consent screen.
|
||||
// The code is only issued after the user completes the consent flow (VK + MCP auths).
|
||||
//
|
||||
// GET /api/oauth/per-user/authorize?response_type=code&client_id=xxx&redirect_uri=xxx&code_challenge=xxx&code_challenge_method=S256[&state=xxx]
|
||||
func (h *PerUserOAuthHandler) handleAuthorize(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "OAuth authorization unavailable: config store is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
if len(h.store.GetPerUserOAuthMCPClients()) == 0 {
|
||||
sendStringError(ctx, fasthttp.StatusNotFound, "Not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Extract parameters
|
||||
responseType := string(ctx.QueryArgs().Peek("response_type"))
|
||||
clientID := string(ctx.QueryArgs().Peek("client_id"))
|
||||
redirectURI := string(ctx.QueryArgs().Peek("redirect_uri"))
|
||||
state := string(ctx.QueryArgs().Peek("state"))
|
||||
codeChallenge := string(ctx.QueryArgs().Peek("code_challenge"))
|
||||
codeChallengeMethod := string(ctx.QueryArgs().Peek("code_challenge_method"))
|
||||
|
||||
// Validate required parameters
|
||||
if responseType != "code" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "response_type must be 'code'")
|
||||
return
|
||||
}
|
||||
if clientID == "" || redirectURI == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "client_id and redirect_uri are required")
|
||||
return
|
||||
}
|
||||
if codeChallenge == "" || codeChallengeMethod != "S256" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "PKCE is required: code_challenge and code_challenge_method=S256")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate client exists and redirect_uri is registered
|
||||
client, err := h.store.ConfigStore.GetPerUserOAuthClientByClientID(ctx, clientID)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to validate client: %v", err))
|
||||
return
|
||||
}
|
||||
if client == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Unknown client_id")
|
||||
return
|
||||
}
|
||||
var allowedURIs []string
|
||||
json.Unmarshal([]byte(client.RedirectURIs), &allowedURIs)
|
||||
uriAllowed := false
|
||||
for _, allowed := range allowedURIs {
|
||||
if allowed == redirectURI {
|
||||
uriAllowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !uriAllowed {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "redirect_uri not registered for this client")
|
||||
return
|
||||
}
|
||||
|
||||
// Generate a browser-binding secret so only the initiating browser can resume this flow.
|
||||
browserSecret, err := generateOpaqueToken(32)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate browser secret")
|
||||
return
|
||||
}
|
||||
browserSecretHash := fmt.Sprintf("%x", sha256.Sum256([]byte(browserSecret)))
|
||||
|
||||
// Create a PendingFlow to carry OAuth params through the consent screen.
|
||||
flow := &tables.TablePerUserOAuthPendingFlow{
|
||||
ID: uuid.New().String(),
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
CodeChallenge: codeChallenge,
|
||||
State: state,
|
||||
BrowserSecretHash: browserSecretHash,
|
||||
ExpiresAt: time.Now().Add(15 * time.Minute),
|
||||
}
|
||||
if err := h.store.ConfigStore.CreatePerUserOAuthPendingFlow(ctx, flow); err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create pending flow: %v", err))
|
||||
return
|
||||
}
|
||||
logger.Debug("[oauth/authorize] PendingFlow created: flow_id=%s client_id=%s", flow.ID, clientID)
|
||||
|
||||
// Set HttpOnly cookie binding this flow to the current browser.
|
||||
var cookie fasthttp.Cookie
|
||||
cookie.SetKey("__bifrost_flow_secret")
|
||||
cookie.SetValue(browserSecret)
|
||||
cookie.SetPath("/")
|
||||
cookie.SetHTTPOnly(true)
|
||||
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
|
||||
isSecure := ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https"
|
||||
cookie.SetSecure(isSecure)
|
||||
cookie.SetMaxAge(15 * 60) // 15 minutes, matching flow TTL
|
||||
ctx.Response.Header.SetCookie(&cookie)
|
||||
|
||||
// Redirect to consent screen with flow_id (relative path — stays on current origin).
|
||||
consentURL := fmt.Sprintf("/oauth/consent?flow_id=%s", url.QueryEscape(flow.ID))
|
||||
ctx.Redirect(consentURL, fasthttp.StatusFound)
|
||||
}
|
||||
|
||||
// handleToken handles the OAuth 2.1 token endpoint.
|
||||
// It validates the authorization code + PKCE verifier and issues access/refresh tokens.
|
||||
//
|
||||
// POST /api/oauth/per-user/token
|
||||
func (h *PerUserOAuthHandler) handleToken(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "OAuth token endpoint unavailable: config store is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
if len(h.store.GetPerUserOAuthMCPClients()) == 0 {
|
||||
sendStringError(ctx, fasthttp.StatusNotFound, "Not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse form-encoded body
|
||||
grantType := string(ctx.FormValue("grant_type"))
|
||||
code := string(ctx.FormValue("code"))
|
||||
redirectURI := string(ctx.FormValue("redirect_uri"))
|
||||
clientID := string(ctx.FormValue("client_id"))
|
||||
codeVerifier := string(ctx.FormValue("code_verifier"))
|
||||
|
||||
if grantType != "authorization_code" {
|
||||
sendOAuthError(ctx, fasthttp.StatusBadRequest, "unsupported_grant_type", "Only authorization_code grant is supported")
|
||||
return
|
||||
}
|
||||
|
||||
if code == "" || codeVerifier == "" {
|
||||
sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_request", "code and code_verifier are required")
|
||||
return
|
||||
}
|
||||
|
||||
// Atomically claim authorization code (prevents concurrent redemption)
|
||||
codeRecord, err := h.store.ConfigStore.ClaimPerUserOAuthCode(ctx, code)
|
||||
if err != nil {
|
||||
sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to validate code")
|
||||
return
|
||||
}
|
||||
if codeRecord == nil {
|
||||
sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "Invalid or already used authorization code")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate code is not expired
|
||||
if time.Now().After(codeRecord.ExpiresAt) {
|
||||
sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "Authorization code expired")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate client_id if provided — some public clients omit it (RFC 6749 §4.1.3 allows
|
||||
// omitting client_id when the client is not authenticating with the server).
|
||||
// The code record already binds the code to the correct client, so this is safe.
|
||||
if clientID != "" && codeRecord.ClientID != clientID {
|
||||
logger.Debug("[oauth/token] client_id mismatch: code_client=%s request_client=%s", codeRecord.ClientID, clientID)
|
||||
sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "client_id mismatch")
|
||||
return
|
||||
}
|
||||
// Use the client_id from the code record as the authoritative value.
|
||||
clientID = codeRecord.ClientID
|
||||
|
||||
// Validate redirect_uri matches
|
||||
if redirectURI != "" && codeRecord.RedirectURI != redirectURI {
|
||||
logger.Debug("[oauth/token] redirect_uri mismatch: code=%s request=%s", codeRecord.RedirectURI, redirectURI)
|
||||
sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "redirect_uri mismatch")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate PKCE: SHA256(code_verifier) must match code_challenge
|
||||
verifierHash := sha256.Sum256([]byte(codeVerifier))
|
||||
computedChallenge := base64.RawURLEncoding.EncodeToString(verifierHash[:])
|
||||
if computedChallenge != codeRecord.CodeChallenge {
|
||||
logger.Debug("[oauth/token] PKCE verification failed")
|
||||
sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "PKCE verification failed")
|
||||
return
|
||||
}
|
||||
|
||||
// If the code was issued by the consent flow (handleSubmit), the session already exists
|
||||
// with the upstream tokens transferred to it. Reuse that session's access token so the
|
||||
// client receives the token that the upstream (Notion, GitHub, etc.) tokens are linked to.
|
||||
var accessToken string
|
||||
var expiresAt time.Time
|
||||
|
||||
if codeRecord.SessionID != "" {
|
||||
existingSession, err := h.store.ConfigStore.GetPerUserOAuthSessionByID(ctx, codeRecord.SessionID)
|
||||
if err != nil {
|
||||
logger.Info("[oauth/token] Failed to load existing session: session_id=%s err=%v", codeRecord.SessionID, err)
|
||||
sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to load session")
|
||||
return
|
||||
}
|
||||
if existingSession == nil {
|
||||
logger.Info("[oauth/token] Existing session not found: session_id=%s", codeRecord.SessionID)
|
||||
sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Session not found")
|
||||
return
|
||||
}
|
||||
if !existingSession.ExpiresAt.After(time.Now()) {
|
||||
sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "Session expired")
|
||||
return
|
||||
}
|
||||
accessToken = existingSession.AccessToken
|
||||
expiresAt = existingSession.ExpiresAt
|
||||
logger.Debug("[oauth/token] reusing consent session: session_id=%s", existingSession.ID)
|
||||
} else {
|
||||
// Fallback: no linked session (legacy path) — create a new one.
|
||||
var newAccessToken, newRefreshToken string
|
||||
newAccessToken, err = generateOpaqueToken(32)
|
||||
if err != nil {
|
||||
sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to generate access token")
|
||||
return
|
||||
}
|
||||
newRefreshToken, err = generateOpaqueToken(32)
|
||||
if err != nil {
|
||||
sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to generate refresh token")
|
||||
return
|
||||
}
|
||||
expiresAt = time.Now().Add(24 * time.Hour)
|
||||
newSession := &tables.TablePerUserOAuthSession{
|
||||
ID: uuid.New().String(),
|
||||
AccessToken: newAccessToken,
|
||||
RefreshToken: newRefreshToken,
|
||||
ClientID: clientID,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
if err := h.store.ConfigStore.CreatePerUserOAuthSession(ctx, newSession); err != nil {
|
||||
sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to create session")
|
||||
return
|
||||
}
|
||||
accessToken = newAccessToken
|
||||
logger.Debug("[oauth/token] created new session (legacy path): session_id=%s", newSession.ID)
|
||||
}
|
||||
// Return OAuth token response
|
||||
ctx.SetContentType("application/json")
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
SendJSON(ctx, map[string]interface{}{
|
||||
"access_token": accessToken,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": int(time.Until(expiresAt).Seconds()),
|
||||
"scope": codeRecord.Scopes,
|
||||
})
|
||||
}
|
||||
|
||||
// sendOAuthError sends an OAuth 2.0 error response per RFC 6749 Section 5.2.
|
||||
func sendOAuthError(ctx *fasthttp.RequestCtx, statusCode int, errorCode, description string) {
|
||||
ctx.SetContentType("application/json")
|
||||
ctx.SetStatusCode(statusCode)
|
||||
resp, _ := json.Marshal(map[string]string{
|
||||
"error": errorCode,
|
||||
"error_description": description,
|
||||
})
|
||||
ctx.SetBody(resp)
|
||||
}
|
||||
|
||||
func sendStringError(ctx *fasthttp.RequestCtx, statusCode int, message string) {
|
||||
ctx.SetContentType("text/plain")
|
||||
ctx.SetStatusCode(statusCode)
|
||||
ctx.SetBodyString(message)
|
||||
}
|
||||
|
||||
// generateOpaqueToken generates a cryptographically secure random token.
|
||||
// validateFlowBrowserSecret checks that the request carries the __bifrost_flow_secret
|
||||
// cookie matching the hash stored on the pending flow. Returns true if valid.
|
||||
func validateFlowBrowserSecret(ctx *fasthttp.RequestCtx, flow *tables.TablePerUserOAuthPendingFlow) bool {
|
||||
if flow.BrowserSecretHash == "" {
|
||||
// Legacy flow without browser binding — allow for backwards compatibility.
|
||||
return true
|
||||
}
|
||||
secret := ctx.Request.Header.Cookie("__bifrost_flow_secret")
|
||||
if len(secret) == 0 {
|
||||
return false
|
||||
}
|
||||
hash := fmt.Sprintf("%x", sha256.Sum256(secret))
|
||||
return hash == flow.BrowserSecretHash
|
||||
}
|
||||
|
||||
func generateOpaqueToken(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// handleUpstreamAuthorize handles the upstream OAuth proxy for per-user OAuth.
|
||||
// When a user needs to authenticate with an upstream MCP server (e.g., Notion),
|
||||
// this endpoint redirects them to the upstream provider's OAuth authorize URL.
|
||||
// After the user authenticates, the callback stores their upstream token linked
|
||||
// to either their Bifrost session (runtime flow) or a PendingFlow (consent flow).
|
||||
//
|
||||
// Runtime flow: GET /api/oauth/per-user/upstream/authorize?mcp_client_id=xxx&session=xxx
|
||||
// Consent flow: GET /api/oauth/per-user/upstream/authorize?mcp_client_id=xxx&flow_id=xxx
|
||||
func (h *PerUserOAuthHandler) handleUpstreamAuthorize(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "OAuth upstream authorization unavailable: config store is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
mcpClientID := string(ctx.QueryArgs().Peek("mcp_client_id"))
|
||||
sessionID := string(ctx.QueryArgs().Peek("session"))
|
||||
flowID := string(ctx.QueryArgs().Peek("flow_id"))
|
||||
|
||||
if mcpClientID == "" || (sessionID == "" && flowID == "") {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "mcp_client_id and either session or flow_id are required")
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve identity depending on whether this is a runtime session or a consent flow.
|
||||
var virtualKeyID, userID, proxySessionToken, gatewaySessionID string
|
||||
if flowID != "" {
|
||||
// Consent flow: use the pending flow for identity and proxy token.
|
||||
flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
|
||||
if err != nil || flow == nil || time.Now().After(flow.ExpiresAt) {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, "Invalid or expired consent flow")
|
||||
return
|
||||
}
|
||||
if !validateFlowBrowserSecret(ctx, flow) {
|
||||
SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session")
|
||||
return
|
||||
}
|
||||
if strVal(flow.VirtualKeyID) != "" {
|
||||
virtualKeyID = *flow.VirtualKeyID
|
||||
}
|
||||
if strVal(flow.UserID) != "" {
|
||||
userID = *flow.UserID
|
||||
}
|
||||
// Use a prefixed flow token so the callback can detect the consent path.
|
||||
// Include mcpClientID to avoid unique constraint violations when multiple
|
||||
// MCP services are connected in the same consent flow.
|
||||
proxySessionToken = "flow:" + flowID + ":" + mcpClientID
|
||||
gatewaySessionID = flowID
|
||||
} else {
|
||||
// Runtime flow: validate the existing Bifrost session.
|
||||
bifrostSession, err := h.store.ConfigStore.GetPerUserOAuthSessionByID(ctx, sessionID)
|
||||
if err != nil || bifrostSession == nil {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, "Invalid or expired session")
|
||||
return
|
||||
}
|
||||
if !bifrostSession.ExpiresAt.After(time.Now()) {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, "Invalid or expired session")
|
||||
return
|
||||
}
|
||||
virtualKeyID = strVal(bifrostSession.VirtualKeyID)
|
||||
userID = strVal(bifrostSession.UserID)
|
||||
proxySessionToken = "runtime:" + sessionID + ":" + mcpClientID
|
||||
gatewaySessionID = sessionID
|
||||
}
|
||||
|
||||
// Look up the MCP client config to get the template OAuth config.
|
||||
mcpClient, err := h.store.ConfigStore.GetMCPClientByID(ctx, mcpClientID)
|
||||
if err != nil || mcpClient == nil {
|
||||
SendError(ctx, fasthttp.StatusNotFound, "MCP client not found")
|
||||
return
|
||||
}
|
||||
if mcpClient.AuthType != string(schemas.MCPAuthTypePerUserOauth) {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "MCP client does not use per-user OAuth")
|
||||
return
|
||||
}
|
||||
if mcpClient.OauthConfigID == nil || *mcpClient.OauthConfigID == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "MCP client has no OAuth configuration")
|
||||
return
|
||||
}
|
||||
|
||||
// Load template OAuth config (has upstream authorize_url, client_id, etc.)
|
||||
templateConfig, err := h.store.ConfigStore.GetOauthConfigByID(ctx, *mcpClient.OauthConfigID)
|
||||
if err != nil || templateConfig == nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load OAuth template config")
|
||||
return
|
||||
}
|
||||
|
||||
// Generate PKCE challenge for upstream.
|
||||
codeVerifier, err := generateOpaqueToken(32)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate PKCE verifier")
|
||||
return
|
||||
}
|
||||
verifierHash := sha256.Sum256([]byte(codeVerifier))
|
||||
codeChallenge := base64.RawURLEncoding.EncodeToString(verifierHash[:])
|
||||
|
||||
// Generate state for upstream.
|
||||
state, err := generateOpaqueToken(32)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate state token")
|
||||
return
|
||||
}
|
||||
|
||||
// Build redirect URI (Bifrost's callback endpoint).
|
||||
scheme := "http"
|
||||
if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
host := string(ctx.Host())
|
||||
redirectURI := fmt.Sprintf("%s://%s/api/oauth/callback", scheme, host)
|
||||
var vkId *string
|
||||
if virtualKeyID != "" {
|
||||
vkId = &virtualKeyID
|
||||
}
|
||||
var uid *string
|
||||
if userID != "" {
|
||||
uid = &userID
|
||||
}
|
||||
// Store upstream OAuth session linking state → MCP client + identity.
|
||||
upstreamSession := &tables.TableOauthUserSession{
|
||||
ID: uuid.New().String(),
|
||||
MCPClientID: mcpClientID,
|
||||
OauthConfigID: *mcpClient.OauthConfigID,
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
SessionToken: proxySessionToken, // "runtime:xxx" for runtime flow; "flow:xxx" for consent flow
|
||||
GatewaySessionID: gatewaySessionID,
|
||||
VirtualKeyID: vkId,
|
||||
UserID: uid,
|
||||
Status: "pending",
|
||||
ExpiresAt: time.Now().Add(15 * time.Minute),
|
||||
}
|
||||
logger.Debug("[oauth/upstream-authorize] creating upstream session: mcp_client=%s flow=%s", mcpClientID, proxySessionToken)
|
||||
if err := h.store.ConfigStore.CreateOauthUserSession(ctx, upstreamSession); err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create upstream OAuth session: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Parse scopes from template config.
|
||||
var scopes []string
|
||||
if templateConfig.Scopes != "" {
|
||||
json.Unmarshal([]byte(templateConfig.Scopes), &scopes)
|
||||
}
|
||||
|
||||
// Build upstream authorize URL with PKCE.
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", templateConfig.ClientID)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
if len(scopes) > 0 {
|
||||
params.Set("scope", strings.Join(scopes, " "))
|
||||
}
|
||||
|
||||
baseURL, err := url.Parse(templateConfig.AuthorizeURL)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Invalid upstream authorize URL")
|
||||
return
|
||||
}
|
||||
existing := baseURL.Query()
|
||||
for k, vals := range params {
|
||||
for _, v := range vals {
|
||||
existing.Set(k, v)
|
||||
}
|
||||
}
|
||||
baseURL.RawQuery = existing.Encode()
|
||||
ctx.Redirect(baseURL.String(), fasthttp.StatusFound)
|
||||
}
|
||||
|
||||
// Ensure unused imports are referenced.
|
||||
var _ = html.EscapeString
|
||||
var _ configstore.ConfigStore
|
||||
491
transports/bifrost-http/handlers/plugins.go
Normal file
491
transports/bifrost-http/handlers/plugins.go
Normal file
@@ -0,0 +1,491 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/framework/plugins"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type PluginsLoader interface {
|
||||
ReloadPlugin(ctx context.Context, name string, path *string, pluginConfig any, placement *schemas.PluginPlacement, order *int) error
|
||||
RemovePlugin(ctx context.Context, name string) error
|
||||
GetPluginStatus(ctx context.Context) map[string]schemas.PluginStatus
|
||||
}
|
||||
|
||||
// PluginsHandler is the handler for the plugins API
|
||||
type PluginsHandler struct {
|
||||
configStore configstore.ConfigStore
|
||||
pluginsLoader PluginsLoader
|
||||
}
|
||||
|
||||
// NewPluginsHandler creates a new PluginsHandler
|
||||
func NewPluginsHandler(pluginsLoader PluginsLoader, configStore configstore.ConfigStore) *PluginsHandler {
|
||||
return &PluginsHandler{
|
||||
pluginsLoader: pluginsLoader,
|
||||
configStore: configStore,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
// CreatePluginRequest is the request body for creating a plugin
|
||||
type CreatePluginRequest struct {
|
||||
Name string `json:"name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Config map[string]any `json:"config"`
|
||||
Path *string `json:"path"`
|
||||
Placement *schemas.PluginPlacement `json:"placement,omitempty"`
|
||||
Order *int `json:"order,omitempty"`
|
||||
}
|
||||
|
||||
// UpdatePluginRequest is the request body for updating a plugin
|
||||
type UpdatePluginRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Path *string `json:"path"`
|
||||
Config map[string]any `json:"config"`
|
||||
Placement *schemas.PluginPlacement `json:"placement,omitempty"`
|
||||
Order *int `json:"order,omitempty"`
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the routes for the PluginsHandler
|
||||
func (h *PluginsHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
r.GET("/api/plugins", lib.ChainMiddlewares(h.getPlugins, middlewares...))
|
||||
r.GET("/api/plugins/{name}", lib.ChainMiddlewares(h.getPlugin, middlewares...))
|
||||
r.POST("/api/plugins", lib.ChainMiddlewares(h.createPlugin, middlewares...))
|
||||
r.PUT("/api/plugins/{name}", lib.ChainMiddlewares(h.updatePlugin, middlewares...))
|
||||
r.DELETE("/api/plugins/{name}", lib.ChainMiddlewares(h.deletePlugin, middlewares...))
|
||||
}
|
||||
|
||||
type PluginResponse struct {
|
||||
Name string `json:"name"`
|
||||
ActualName string `json:"actualName"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Config any `json:"config"`
|
||||
IsCustom bool `json:"isCustom"`
|
||||
Path *string `json:"path"`
|
||||
Placement *schemas.PluginPlacement `json:"placement,omitempty"`
|
||||
Order *int `json:"order,omitempty"`
|
||||
Status schemas.PluginStatus `json:"status"`
|
||||
}
|
||||
|
||||
// buildPluginResponse constructs a PluginResponse with status for a given TablePlugin.
|
||||
func (h *PluginsHandler) buildPluginResponse(ctx context.Context, plugin *configstoreTables.TablePlugin) PluginResponse {
|
||||
pluginStatus := schemas.PluginStatus{
|
||||
Name: plugin.Name,
|
||||
Status: schemas.PluginStatusUninitialized,
|
||||
Logs: []string{},
|
||||
}
|
||||
if !plugin.Enabled {
|
||||
pluginStatus.Status = schemas.PluginStatusDisabled
|
||||
} else {
|
||||
for _, status := range h.pluginsLoader.GetPluginStatus(ctx) {
|
||||
if plugin.Name == status.Name {
|
||||
pluginStatus = status
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return PluginResponse{
|
||||
Name: plugin.Name,
|
||||
ActualName: pluginStatus.Name,
|
||||
Enabled: plugin.Enabled,
|
||||
Config: plugin.Config,
|
||||
IsCustom: plugin.IsCustom,
|
||||
Path: plugin.Path,
|
||||
Placement: plugin.Placement,
|
||||
Order: plugin.Order,
|
||||
Status: pluginStatus,
|
||||
}
|
||||
}
|
||||
|
||||
// getPlugins gets all plugins
|
||||
func (h *PluginsHandler) getPlugins(ctx *fasthttp.RequestCtx) {
|
||||
if h.configStore == nil {
|
||||
pluginStatus := h.pluginsLoader.GetPluginStatus(ctx)
|
||||
finalPlugins := []PluginResponse{}
|
||||
for name, pluginStatus := range pluginStatus {
|
||||
finalPlugins = append(finalPlugins, PluginResponse{
|
||||
Name: pluginStatus.Name,
|
||||
ActualName: name,
|
||||
Enabled: true,
|
||||
Config: map[string]any{},
|
||||
IsCustom: true,
|
||||
Path: nil,
|
||||
Status: pluginStatus,
|
||||
})
|
||||
}
|
||||
SendJSON(ctx, map[string]any{
|
||||
"plugins": finalPlugins,
|
||||
"count": len(finalPlugins),
|
||||
})
|
||||
return
|
||||
}
|
||||
plugins, err := h.configStore.GetPlugins(ctx)
|
||||
if err != nil {
|
||||
logger.Error("failed to get plugins: %v", err)
|
||||
SendError(ctx, 500, "Failed to retrieve plugins")
|
||||
return
|
||||
}
|
||||
// Fetching status
|
||||
pluginStatuses := h.pluginsLoader.GetPluginStatus(ctx)
|
||||
// Creating ephemeral struct for the plugins
|
||||
finalPlugins := []PluginResponse{}
|
||||
|
||||
// Iterating over plugin status to get the plugin info
|
||||
for _, plugin := range plugins {
|
||||
pluginStatus := schemas.PluginStatus{
|
||||
Name: plugin.Name,
|
||||
Status: schemas.PluginStatusUninitialized,
|
||||
Logs: []string{},
|
||||
}
|
||||
if !plugin.Enabled {
|
||||
pluginStatus.Status = schemas.PluginStatusDisabled
|
||||
}
|
||||
for _, status := range pluginStatuses {
|
||||
if plugin.Name == status.Name {
|
||||
pluginStatus = status
|
||||
break
|
||||
}
|
||||
}
|
||||
finalPlugins = append(finalPlugins, PluginResponse{
|
||||
Name: plugin.Name,
|
||||
ActualName: pluginStatus.Name,
|
||||
Enabled: plugin.Enabled,
|
||||
Config: plugin.Config,
|
||||
IsCustom: plugin.IsCustom,
|
||||
Path: plugin.Path,
|
||||
Placement: plugin.Placement,
|
||||
Order: plugin.Order,
|
||||
Status: pluginStatus,
|
||||
})
|
||||
}
|
||||
// Creating ephemeral struct
|
||||
SendJSON(ctx, map[string]any{
|
||||
"plugins": finalPlugins,
|
||||
"count": len(finalPlugins),
|
||||
})
|
||||
}
|
||||
|
||||
// getPlugin gets a plugin by name
|
||||
func (h *PluginsHandler) getPlugin(ctx *fasthttp.RequestCtx) {
|
||||
if h.configStore == nil {
|
||||
pluginStatus := h.pluginsLoader.GetPluginStatus(ctx)
|
||||
pluginInfo := PluginResponse{}
|
||||
for name, pluginStatus := range pluginStatus {
|
||||
if pluginStatus.Name == ctx.UserValue("name") {
|
||||
pluginInfo = PluginResponse{
|
||||
Name: pluginStatus.Name,
|
||||
ActualName: name,
|
||||
Enabled: true,
|
||||
Config: map[string]any{},
|
||||
IsCustom: true,
|
||||
Path: nil,
|
||||
Status: pluginStatus,
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
SendJSON(ctx, pluginInfo)
|
||||
return
|
||||
}
|
||||
// Safely validate the "name" parameter
|
||||
nameValue := ctx.UserValue("name")
|
||||
if nameValue == nil {
|
||||
logger.Warn("missing required 'name' parameter in request")
|
||||
SendError(ctx, 400, "Missing required 'name' parameter")
|
||||
return
|
||||
}
|
||||
|
||||
name, ok := nameValue.(string)
|
||||
if !ok {
|
||||
logger.Warn("invalid 'name' parameter type, expected string but got %T", nameValue)
|
||||
SendError(ctx, 400, "Invalid 'name' parameter type, expected string")
|
||||
return
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
logger.Warn("empty 'name' parameter provided")
|
||||
SendError(ctx, 400, "Empty 'name' parameter not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
plugin, err := h.configStore.GetPlugin(ctx, name)
|
||||
if err != nil {
|
||||
if errors.Is(err, configstore.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, "Plugin not found")
|
||||
return
|
||||
}
|
||||
logger.Error("failed to get plugin: %v", err)
|
||||
SendError(ctx, 500, "Failed to retrieve plugin")
|
||||
return
|
||||
}
|
||||
SendJSON(ctx, plugin)
|
||||
}
|
||||
|
||||
// createPlugin creates a new plugin
|
||||
func (h *PluginsHandler) createPlugin(ctx *fasthttp.RequestCtx) {
|
||||
if h.configStore == nil {
|
||||
SendError(ctx, 400, "Plugins creation is not supported when configstore is disabled")
|
||||
return
|
||||
}
|
||||
var request CreatePluginRequest
|
||||
if err := json.Unmarshal(ctx.PostBody(), &request); err != nil {
|
||||
logger.Error("failed to unmarshal create plugin request: %v", err)
|
||||
SendError(ctx, 400, "Invalid request body")
|
||||
return
|
||||
}
|
||||
// Validate required fields
|
||||
if request.Name == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Plugin name is required")
|
||||
return
|
||||
}
|
||||
// Validate placement value
|
||||
if request.Placement != nil && *request.Placement != "" &&
|
||||
*request.Placement != schemas.PluginPlacementPreBuiltin &&
|
||||
*request.Placement != schemas.PluginPlacementPostBuiltin {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Invalid placement value. Must be 'pre_builtin' or 'post_builtin'")
|
||||
return
|
||||
}
|
||||
if request.Placement != nil && *request.Placement == "" {
|
||||
request.Placement = nil
|
||||
}
|
||||
// Normalize empty path to nil (treat empty string as built-in plugin)
|
||||
if request.Path != nil && *request.Path == "" {
|
||||
request.Path = nil
|
||||
}
|
||||
// Check if plugin already exists
|
||||
existingPlugin, err := h.configStore.GetPlugin(ctx, request.Name)
|
||||
if err == nil && existingPlugin != nil {
|
||||
SendError(ctx, fasthttp.StatusConflict, "Plugin already exists")
|
||||
return
|
||||
}
|
||||
// Determine if this is a built-in or custom plugin
|
||||
isBuiltin := lib.IsBuiltinPlugin(request.Name)
|
||||
// Built-in plugins should not have a path
|
||||
if isBuiltin && request.Path != nil {
|
||||
request.Path = nil
|
||||
}
|
||||
// Create DB entry first to avoid orphaned in-memory state if DB write fails
|
||||
if err := h.configStore.CreatePlugin(ctx, &configstoreTables.TablePlugin{
|
||||
Name: request.Name,
|
||||
Enabled: request.Enabled,
|
||||
Config: request.Config,
|
||||
Path: request.Path,
|
||||
IsCustom: !isBuiltin,
|
||||
Placement: request.Placement,
|
||||
Order: request.Order,
|
||||
}); err != nil {
|
||||
logger.Error("failed to create plugin: %v", err)
|
||||
SendError(ctx, 500, "Failed to create plugin")
|
||||
return
|
||||
}
|
||||
|
||||
// Reload the plugin into memory if it's enabled
|
||||
if request.Enabled {
|
||||
if err := h.pluginsLoader.ReloadPlugin(ctx, request.Name, request.Path, request.Config, request.Placement, request.Order); err != nil {
|
||||
logger.Error("failed to load plugin: %v", err)
|
||||
if rbErr := h.configStore.DeletePlugin(ctx, request.Name); rbErr != nil {
|
||||
logger.Error("failed to rollback plugin creation: %v", rbErr)
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Plugin created in database but failed to load: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
plugin, err := h.configStore.GetPlugin(ctx, request.Name)
|
||||
if err != nil {
|
||||
logger.Error("failed to get plugin: %v", err)
|
||||
SendError(ctx, 500, "Failed to retrieve plugin")
|
||||
return
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusCreated)
|
||||
SendJSON(ctx, map[string]any{
|
||||
"message": "Plugin created successfully",
|
||||
"plugin": h.buildPluginResponse(ctx, plugin),
|
||||
})
|
||||
}
|
||||
|
||||
// updatePlugin updates an existing plugin
|
||||
func (h *PluginsHandler) updatePlugin(ctx *fasthttp.RequestCtx) {
|
||||
if h.configStore == nil {
|
||||
SendError(ctx, 400, "Plugins update is not supported when configstore is disabled")
|
||||
return
|
||||
}
|
||||
// Safely validate the "name" parameter
|
||||
nameValue := ctx.UserValue("name")
|
||||
if nameValue == nil {
|
||||
logger.Warn("missing required 'name' parameter in update plugin request")
|
||||
SendError(ctx, 400, "Missing required 'name' parameter")
|
||||
return
|
||||
}
|
||||
|
||||
name, ok := nameValue.(string)
|
||||
if !ok {
|
||||
logger.Warn("invalid 'name' parameter type in update plugin request, expected string but got %T", nameValue)
|
||||
SendError(ctx, 400, "Invalid 'name' parameter type, expected string")
|
||||
return
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
logger.Warn("empty 'name' parameter provided in update plugin request")
|
||||
SendError(ctx, 400, "Empty 'name' parameter not allowed")
|
||||
return
|
||||
}
|
||||
var plugin *configstoreTables.TablePlugin
|
||||
var err error
|
||||
// Check if plugin exists
|
||||
_, err = h.configStore.GetPlugin(ctx, name)
|
||||
if err != nil {
|
||||
// If doesn't exist, create it
|
||||
if errors.Is(err, configstore.ErrNotFound) {
|
||||
plugin = &configstoreTables.TablePlugin{
|
||||
Name: name,
|
||||
Enabled: false,
|
||||
Config: map[string]any{},
|
||||
Path: nil,
|
||||
IsCustom: false,
|
||||
}
|
||||
if err := h.configStore.CreatePlugin(ctx, plugin); err != nil {
|
||||
logger.Error("failed to create plugin: %v", err)
|
||||
SendError(ctx, 500, "Failed to create plugin")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
logger.Error("failed to get plugin: %v", err)
|
||||
SendError(ctx, 500, "Failed to update plugin")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Unmarshalling the request body
|
||||
var request UpdatePluginRequest
|
||||
if err := json.Unmarshal(ctx.PostBody(), &request); err != nil {
|
||||
logger.Error("failed to unmarshal update plugin request: %v", err)
|
||||
SendError(ctx, 400, "Invalid request body")
|
||||
return
|
||||
}
|
||||
// Validate placement value
|
||||
if request.Placement != nil && *request.Placement != "" &&
|
||||
*request.Placement != schemas.PluginPlacementPreBuiltin &&
|
||||
*request.Placement != schemas.PluginPlacementPostBuiltin {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Invalid placement value. Must be 'pre_builtin' or 'post_builtin'")
|
||||
return
|
||||
}
|
||||
if request.Placement != nil && *request.Placement == "" {
|
||||
request.Placement = nil
|
||||
}
|
||||
// Normalize empty path to nil (treat empty string as built-in plugin)
|
||||
if request.Path != nil && *request.Path == "" {
|
||||
request.Path = nil
|
||||
}
|
||||
// Determine if this is a built-in plugin
|
||||
isBuiltin := lib.IsBuiltinPlugin(name)
|
||||
// Built-in plugins should not have a path
|
||||
if isBuiltin && request.Path != nil {
|
||||
request.Path = nil
|
||||
}
|
||||
// Updating the plugin
|
||||
if err := h.configStore.UpdatePlugin(ctx, &configstoreTables.TablePlugin{
|
||||
Name: name,
|
||||
Enabled: request.Enabled,
|
||||
Config: request.Config,
|
||||
Path: request.Path,
|
||||
IsCustom: !isBuiltin,
|
||||
Placement: request.Placement,
|
||||
Order: request.Order,
|
||||
}); err != nil {
|
||||
logger.Error("failed to update plugin: %v", err)
|
||||
SendError(ctx, 500, "Failed to update plugin")
|
||||
return
|
||||
}
|
||||
plugin, err = h.configStore.GetPlugin(ctx, name)
|
||||
if err != nil {
|
||||
if errors.Is(err, configstore.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, "Plugin not found")
|
||||
return
|
||||
}
|
||||
logger.Error("failed to get plugin: %v", err)
|
||||
SendError(ctx, 500, "Failed to retrieve plugin")
|
||||
return
|
||||
}
|
||||
// We reload the plugin if its enabled, otherwise we stop it
|
||||
if request.Enabled {
|
||||
if err := h.pluginsLoader.ReloadPlugin(ctx, name, request.Path, request.Config, request.Placement, request.Order); err != nil {
|
||||
logger.Error("failed to load plugin: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Plugin updated in database but failed to load: %v", err))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
ctx.SetUserValue(PluginDisabledKey, true)
|
||||
if err := h.pluginsLoader.RemovePlugin(ctx, name); err != nil {
|
||||
if !errors.Is(err, plugins.ErrPluginNotFound) {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Plugin updated in database but failed to stop: %v", err))
|
||||
return
|
||||
}
|
||||
// If not found then we don't need to do anything
|
||||
}
|
||||
}
|
||||
|
||||
SendJSON(ctx, map[string]interface{}{
|
||||
"message": "Plugin updated successfully",
|
||||
"plugin": h.buildPluginResponse(ctx, plugin),
|
||||
})
|
||||
}
|
||||
|
||||
// deletePlugin deletes an existing plugin
|
||||
func (h *PluginsHandler) deletePlugin(ctx *fasthttp.RequestCtx) {
|
||||
if h.configStore == nil {
|
||||
SendError(ctx, 400, "Plugins deletion is not supported when configstore is disabled")
|
||||
return
|
||||
}
|
||||
// Safely validate the "name" parameter
|
||||
nameValue := ctx.UserValue("name")
|
||||
if nameValue == nil {
|
||||
logger.Warn("missing required 'name' parameter in delete plugin request")
|
||||
SendError(ctx, 400, "Missing required 'name' parameter")
|
||||
return
|
||||
}
|
||||
|
||||
name, ok := nameValue.(string)
|
||||
if !ok {
|
||||
logger.Warn("invalid 'name' parameter type in delete plugin request, expected string but got %T", nameValue)
|
||||
SendError(ctx, 400, "Invalid 'name' parameter type, expected string")
|
||||
return
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
logger.Warn("empty 'name' parameter provided in delete plugin request")
|
||||
SendError(ctx, 400, "Empty 'name' parameter not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.configStore.DeletePlugin(ctx, name); err != nil {
|
||||
if errors.Is(err, configstore.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, "Plugin not found")
|
||||
return
|
||||
}
|
||||
logger.Error("failed to delete plugin: %v", err)
|
||||
SendError(ctx, 500, "Failed to delete plugin")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.pluginsLoader.RemovePlugin(ctx, name); err != nil {
|
||||
if !errors.Is(err, plugins.ErrPluginNotFound) {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Plugin deleted in database but failed to stop: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
SendJSON(ctx, map[string]interface{}{
|
||||
"message": "Plugin deleted successfully",
|
||||
})
|
||||
}
|
||||
149
transports/bifrost-http/handlers/pricing_override_test.go
Normal file
149
transports/bifrost-http/handlers/pricing_override_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/framework/modelcatalog"
|
||||
"github.com/maximhq/bifrost/plugins/governance"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type pricingOverrideTestGovernanceManager struct{}
|
||||
|
||||
func (pricingOverrideTestGovernanceManager) GetGovernanceData(ctx context.Context) *governance.GovernanceData {
|
||||
return nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) ReloadVirtualKey(context.Context, string) (*configstoreTables.TableVirtualKey, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) RemoveVirtualKey(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) ReloadTeam(context.Context, string) (*configstoreTables.TableTeam, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) RemoveTeam(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) ReloadCustomer(context.Context, string) (*configstoreTables.TableCustomer, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) RemoveCustomer(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) ReloadModelConfig(context.Context, string) (*configstoreTables.TableModelConfig, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) RemoveModelConfig(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) ReloadProvider(context.Context, schemas.ModelProvider) (*configstoreTables.TableProvider, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) RemoveProvider(context.Context, schemas.ModelProvider) error {
|
||||
return nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) ReloadRoutingRule(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) RemoveRoutingRule(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) UpsertPricingOverride(context.Context, *configstoreTables.TablePricingOverride) error {
|
||||
return nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) DeletePricingOverride(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupPricingOverrideHandlerStore(t *testing.T) configstore.ConfigStore {
|
||||
t.Helper()
|
||||
|
||||
dbPath := t.TempDir() + "/config.db"
|
||||
store, err := configstore.NewConfigStore(context.Background(), &configstore.Config{
|
||||
Enabled: true,
|
||||
Type: configstore.ConfigStoreTypeSQLite,
|
||||
Config: &configstore.SQLiteConfig{
|
||||
Path: dbPath,
|
||||
},
|
||||
}, &mockLogger{})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove(dbPath)
|
||||
})
|
||||
return store
|
||||
}
|
||||
|
||||
func newTestRequestCtx(body string) *fasthttp.RequestCtx {
|
||||
var req fasthttp.Request
|
||||
req.SetBodyString(body)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 12345}, nil)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func TestUpdatePricingOverride_ReplacesFullBody(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
store := setupPricingOverrideHandlerStore(t)
|
||||
handler := &GovernanceHandler{
|
||||
configStore: store,
|
||||
governanceManager: pricingOverrideTestGovernanceManager{},
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
override := configstoreTables.TablePricingOverride{
|
||||
ID: "override-1",
|
||||
Name: "Original",
|
||||
ScopeKind: string(modelcatalog.ScopeKindGlobal),
|
||||
MatchType: string(modelcatalog.MatchTypeExact),
|
||||
Pattern: "gpt-4.1",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
PricingPatchJSON: `{"input_cost_per_token":1,"output_cost_per_token":2}`,
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
}
|
||||
require.NoError(t, store.CreatePricingOverride(context.Background(), &override))
|
||||
|
||||
// Patch replaces in full: send only input_cost_per_token.
|
||||
// output_cost_per_token must be absent from the stored patch afterwards,
|
||||
// confirming full-replace (not merge) semantics.
|
||||
body := `{
|
||||
"name":"Updated",
|
||||
"scope_kind":"global",
|
||||
"match_type":"exact",
|
||||
"pattern":"gpt-4.1",
|
||||
"request_types":["chat_completion"],
|
||||
"patch":{"input_cost_per_token":1.5}
|
||||
}`
|
||||
ctx := newTestRequestCtx(body)
|
||||
ctx.SetUserValue("id", override.ID)
|
||||
|
||||
handler.updatePricingOverride(ctx)
|
||||
|
||||
require.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
|
||||
stored, err := store.GetPricingOverrideByID(context.Background(), override.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated", stored.Name)
|
||||
|
||||
var patch modelcatalog.PricingOptions
|
||||
require.NoError(t, json.Unmarshal([]byte(stored.PricingPatchJSON), &patch))
|
||||
// Sent field must reflect the new value.
|
||||
require.NotNil(t, patch.InputCostPerToken)
|
||||
assert.Equal(t, 1.5, *patch.InputCostPerToken)
|
||||
// Omitted field must be cleared — patch is always fully replaced, not merged.
|
||||
assert.Nil(t, patch.OutputCostPerToken)
|
||||
assert.Empty(t, stored.ConfigHash)
|
||||
}
|
||||
1121
transports/bifrost-http/handlers/prompts.go
Normal file
1121
transports/bifrost-http/handlers/prompts.go
Normal file
File diff suppressed because it is too large
Load Diff
495
transports/bifrost-http/handlers/provider_keys.go
Normal file
495
transports/bifrost-http/handlers/provider_keys.go
Normal file
@@ -0,0 +1,495 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/google/uuid"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// ListProviderKeysResponse represents the response for listing keys for a provider.
|
||||
type ListProviderKeysResponse struct {
|
||||
Keys []schemas.Key `json:"keys"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
|
||||
func (h *ProviderHandler) listProviderKeys(ctx *fasthttp.RequestCtx) {
|
||||
provider, err := getProviderFromCtx(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
keys, err := h.inMemoryStore.GetProviderKeysRedacted(provider)
|
||||
if err != nil {
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider keys: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
SendJSON(ctx, ListProviderKeysResponse{Keys: keys, Total: len(keys)})
|
||||
}
|
||||
|
||||
func (h *ProviderHandler) getProviderKey(ctx *fasthttp.RequestCtx) {
|
||||
provider, err := getProviderFromCtx(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
keyID, err := getKeyIDFromCtx(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
key, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
SendJSON(ctx, key)
|
||||
}
|
||||
|
||||
func (h *ProviderHandler) createProviderKey(ctx *fasthttp.RequestCtx) {
|
||||
provider, err := getProviderFromCtx(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
var key schemas.Key
|
||||
if err := sonic.Unmarshal(ctx.PostBody(), &key); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid JSON: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
providerConfig, err := h.inMemoryStore.GetProviderConfigRaw(provider)
|
||||
if err != nil {
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider config: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.IsKeyLess {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Cannot add keys to a keyless provider")
|
||||
return
|
||||
}
|
||||
|
||||
baseProvider := provider
|
||||
if providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.BaseProviderType != "" {
|
||||
baseProvider = providerConfig.CustomProviderConfig.BaseProviderType
|
||||
}
|
||||
|
||||
if !bifrost.CanProviderKeyValueBeEmpty(baseProvider) && key.Value.GetValue() == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Key value must not be empty")
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateProviderKeyURL(provider, key); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := key.BlacklistedModels.Validate(); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid blacklisted_models: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := key.Aliases.Validate(); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid aliases: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if key.ID == "" {
|
||||
key.ID = uuid.NewString()
|
||||
}
|
||||
if key.Enabled == nil {
|
||||
key.Enabled = bifrost.Ptr(true)
|
||||
}
|
||||
|
||||
if err := h.inMemoryStore.AddProviderKey(ctx, provider, key); err != nil {
|
||||
logger.Warn("Failed to create key for provider %s: %v", provider, err)
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err))
|
||||
return
|
||||
}
|
||||
if errors.Is(err, lib.ErrAlreadyExists) {
|
||||
SendError(ctx, fasthttp.StatusConflict, err.Error())
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create provider key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.attemptModelDiscovery(ctx, provider, providerConfig.CustomProviderConfig); err != nil {
|
||||
logger.Warn("Model discovery failed for provider %s after key create: %v", provider, err)
|
||||
}
|
||||
|
||||
redactedKey, err := h.inMemoryStore.GetProviderKeyRedacted(provider, key.ID)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get created provider key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
SendJSON(ctx, redactedKey)
|
||||
}
|
||||
|
||||
func (h *ProviderHandler) updateProviderKey(ctx *fasthttp.RequestCtx) {
|
||||
provider, err := getProviderFromCtx(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
keyID, err := getKeyIDFromCtx(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var updateKey schemas.Key
|
||||
if err := sonic.Unmarshal(ctx.PostBody(), &updateKey); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid JSON: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
providerConfig, err := h.inMemoryStore.GetProviderConfigRaw(provider)
|
||||
if err != nil {
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider config: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.IsKeyLess {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Cannot update keys on a keyless provider")
|
||||
return
|
||||
}
|
||||
|
||||
oldRawKey, err := h.inMemoryStore.GetProviderKeyRaw(provider, keyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
oldRedactedKey, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
updateKey.ID = keyID
|
||||
mergedKey := h.mergeUpdatedKey(*oldRawKey, *oldRedactedKey, updateKey)
|
||||
|
||||
baseProvider := provider
|
||||
if providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.BaseProviderType != "" {
|
||||
baseProvider = providerConfig.CustomProviderConfig.BaseProviderType
|
||||
}
|
||||
|
||||
if !bifrost.CanProviderKeyValueBeEmpty(baseProvider) && mergedKey.Value.GetValue() == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Key value must not be empty")
|
||||
return
|
||||
}
|
||||
|
||||
if err := mergedKey.BlacklistedModels.Validate(); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid blacklisted_models: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := mergedKey.Aliases.Validate(); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid aliases: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateProviderKeyURL(provider, mergedKey); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.inMemoryStore.UpdateProviderKey(ctx, provider, keyID, mergedKey); err != nil {
|
||||
logger.Warn("Failed to update key %s for provider %s: %v", keyID, provider, err)
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to update provider key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.attemptModelDiscovery(ctx, provider, providerConfig.CustomProviderConfig); err != nil {
|
||||
logger.Warn("Model discovery failed for provider %s after key update: %v", provider, err)
|
||||
}
|
||||
|
||||
redactedKey, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get updated provider key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
SendJSON(ctx, redactedKey)
|
||||
}
|
||||
|
||||
func (h *ProviderHandler) deleteProviderKey(ctx *fasthttp.RequestCtx) {
|
||||
provider, err := getProviderFromCtx(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
keyID, err := getKeyIDFromCtx(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
providerConfig, err := h.inMemoryStore.GetProviderConfigRaw(provider)
|
||||
if err != nil {
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider config: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.IsKeyLess {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Cannot delete keys on a keyless provider")
|
||||
return
|
||||
}
|
||||
|
||||
redactedKey, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.inMemoryStore.RemoveProviderKey(ctx, provider, keyID); err != nil {
|
||||
logger.Warn("Failed to delete key %s for provider %s: %v", keyID, provider, err)
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to delete provider key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.attemptModelDiscovery(ctx, provider, providerConfig.CustomProviderConfig); err != nil {
|
||||
logger.Warn("Model discovery failed for provider %s after key delete: %v", provider, err)
|
||||
}
|
||||
|
||||
SendJSON(ctx, redactedKey)
|
||||
}
|
||||
|
||||
// mergeUpdatedKey merges an updated key with the old raw/redacted versions,
|
||||
// preserving real values for fields that were sent back in redacted form.
|
||||
func (h *ProviderHandler) mergeUpdatedKey(oldRawKey, oldRedactedKey, updateKey schemas.Key) schemas.Key {
|
||||
mergedKey := updateKey
|
||||
|
||||
if updateKey.Value.IsRedacted() && updateKey.Value.Equals(&oldRedactedKey.Value) {
|
||||
mergedKey.Value = oldRawKey.Value
|
||||
}
|
||||
|
||||
if updateKey.AzureKeyConfig != nil && oldRedactedKey.AzureKeyConfig != nil && oldRawKey.AzureKeyConfig != nil {
|
||||
if updateKey.AzureKeyConfig.Endpoint.IsRedacted() &&
|
||||
updateKey.AzureKeyConfig.Endpoint.Equals(&oldRedactedKey.AzureKeyConfig.Endpoint) {
|
||||
mergedKey.AzureKeyConfig.Endpoint = oldRawKey.AzureKeyConfig.Endpoint
|
||||
}
|
||||
if updateKey.AzureKeyConfig.APIVersion != nil &&
|
||||
oldRedactedKey.AzureKeyConfig.APIVersion != nil &&
|
||||
oldRawKey.AzureKeyConfig != nil &&
|
||||
updateKey.AzureKeyConfig.APIVersion.IsRedacted() &&
|
||||
updateKey.AzureKeyConfig.APIVersion.Equals(oldRedactedKey.AzureKeyConfig.APIVersion) {
|
||||
mergedKey.AzureKeyConfig.APIVersion = oldRawKey.AzureKeyConfig.APIVersion
|
||||
}
|
||||
if updateKey.AzureKeyConfig.ClientID != nil &&
|
||||
oldRedactedKey.AzureKeyConfig.ClientID != nil &&
|
||||
oldRawKey.AzureKeyConfig != nil &&
|
||||
updateKey.AzureKeyConfig.ClientID.IsRedacted() &&
|
||||
updateKey.AzureKeyConfig.ClientID.Equals(oldRedactedKey.AzureKeyConfig.ClientID) {
|
||||
mergedKey.AzureKeyConfig.ClientID = oldRawKey.AzureKeyConfig.ClientID
|
||||
}
|
||||
if updateKey.AzureKeyConfig.ClientSecret != nil &&
|
||||
oldRedactedKey.AzureKeyConfig.ClientSecret != nil &&
|
||||
oldRawKey.AzureKeyConfig != nil &&
|
||||
updateKey.AzureKeyConfig.ClientSecret.IsRedacted() &&
|
||||
updateKey.AzureKeyConfig.ClientSecret.Equals(oldRedactedKey.AzureKeyConfig.ClientSecret) {
|
||||
mergedKey.AzureKeyConfig.ClientSecret = oldRawKey.AzureKeyConfig.ClientSecret
|
||||
}
|
||||
if updateKey.AzureKeyConfig.TenantID != nil &&
|
||||
oldRedactedKey.AzureKeyConfig.TenantID != nil &&
|
||||
oldRawKey.AzureKeyConfig != nil &&
|
||||
updateKey.AzureKeyConfig.TenantID.IsRedacted() &&
|
||||
updateKey.AzureKeyConfig.TenantID.Equals(oldRedactedKey.AzureKeyConfig.TenantID) {
|
||||
mergedKey.AzureKeyConfig.TenantID = oldRawKey.AzureKeyConfig.TenantID
|
||||
}
|
||||
}
|
||||
|
||||
if updateKey.VertexKeyConfig != nil && oldRedactedKey.VertexKeyConfig != nil && oldRawKey.VertexKeyConfig != nil {
|
||||
if updateKey.VertexKeyConfig.ProjectID.IsRedacted() &&
|
||||
updateKey.VertexKeyConfig.ProjectID.Equals(&oldRedactedKey.VertexKeyConfig.ProjectID) {
|
||||
mergedKey.VertexKeyConfig.ProjectID = oldRawKey.VertexKeyConfig.ProjectID
|
||||
}
|
||||
if updateKey.VertexKeyConfig.ProjectNumber.IsRedacted() &&
|
||||
updateKey.VertexKeyConfig.ProjectNumber.Equals(&oldRedactedKey.VertexKeyConfig.ProjectNumber) {
|
||||
mergedKey.VertexKeyConfig.ProjectNumber = oldRawKey.VertexKeyConfig.ProjectNumber
|
||||
}
|
||||
if updateKey.VertexKeyConfig.Region.IsRedacted() &&
|
||||
updateKey.VertexKeyConfig.Region.Equals(&oldRedactedKey.VertexKeyConfig.Region) {
|
||||
mergedKey.VertexKeyConfig.Region = oldRawKey.VertexKeyConfig.Region
|
||||
}
|
||||
if updateKey.VertexKeyConfig.AuthCredentials.IsRedacted() &&
|
||||
updateKey.VertexKeyConfig.AuthCredentials.Equals(&oldRedactedKey.VertexKeyConfig.AuthCredentials) {
|
||||
mergedKey.VertexKeyConfig.AuthCredentials = oldRawKey.VertexKeyConfig.AuthCredentials
|
||||
}
|
||||
}
|
||||
|
||||
if updateKey.BedrockKeyConfig != nil && oldRedactedKey.BedrockKeyConfig != nil && oldRawKey.BedrockKeyConfig != nil {
|
||||
if updateKey.BedrockKeyConfig.AccessKey.IsRedacted() &&
|
||||
updateKey.BedrockKeyConfig.AccessKey.Equals(&oldRedactedKey.BedrockKeyConfig.AccessKey) {
|
||||
mergedKey.BedrockKeyConfig.AccessKey = oldRawKey.BedrockKeyConfig.AccessKey
|
||||
}
|
||||
if updateKey.BedrockKeyConfig.SecretKey.IsRedacted() &&
|
||||
updateKey.BedrockKeyConfig.SecretKey.Equals(&oldRedactedKey.BedrockKeyConfig.SecretKey) {
|
||||
mergedKey.BedrockKeyConfig.SecretKey = oldRawKey.BedrockKeyConfig.SecretKey
|
||||
}
|
||||
if updateKey.BedrockKeyConfig.SessionToken != nil &&
|
||||
oldRedactedKey.BedrockKeyConfig.SessionToken != nil &&
|
||||
oldRawKey.BedrockKeyConfig != nil &&
|
||||
updateKey.BedrockKeyConfig.SessionToken.IsRedacted() &&
|
||||
updateKey.BedrockKeyConfig.SessionToken.Equals(oldRedactedKey.BedrockKeyConfig.SessionToken) {
|
||||
mergedKey.BedrockKeyConfig.SessionToken = oldRawKey.BedrockKeyConfig.SessionToken
|
||||
}
|
||||
if updateKey.BedrockKeyConfig.Region != nil &&
|
||||
oldRedactedKey.BedrockKeyConfig.Region != nil &&
|
||||
oldRawKey.BedrockKeyConfig != nil &&
|
||||
updateKey.BedrockKeyConfig.Region.IsRedacted() &&
|
||||
updateKey.BedrockKeyConfig.Region.Equals(oldRedactedKey.BedrockKeyConfig.Region) {
|
||||
mergedKey.BedrockKeyConfig.Region = oldRawKey.BedrockKeyConfig.Region
|
||||
}
|
||||
if updateKey.BedrockKeyConfig.ARN != nil &&
|
||||
oldRedactedKey.BedrockKeyConfig.ARN != nil &&
|
||||
oldRawKey.BedrockKeyConfig != nil &&
|
||||
updateKey.BedrockKeyConfig.ARN.IsRedacted() &&
|
||||
updateKey.BedrockKeyConfig.ARN.Equals(oldRedactedKey.BedrockKeyConfig.ARN) {
|
||||
mergedKey.BedrockKeyConfig.ARN = oldRawKey.BedrockKeyConfig.ARN
|
||||
}
|
||||
if updateKey.BedrockKeyConfig.RoleARN != nil &&
|
||||
oldRedactedKey.BedrockKeyConfig.RoleARN != nil &&
|
||||
oldRawKey.BedrockKeyConfig != nil &&
|
||||
updateKey.BedrockKeyConfig.RoleARN.IsRedacted() &&
|
||||
updateKey.BedrockKeyConfig.RoleARN.Equals(oldRedactedKey.BedrockKeyConfig.RoleARN) {
|
||||
mergedKey.BedrockKeyConfig.RoleARN = oldRawKey.BedrockKeyConfig.RoleARN
|
||||
}
|
||||
if updateKey.BedrockKeyConfig.ExternalID != nil &&
|
||||
oldRedactedKey.BedrockKeyConfig.ExternalID != nil &&
|
||||
oldRawKey.BedrockKeyConfig != nil &&
|
||||
updateKey.BedrockKeyConfig.ExternalID.IsRedacted() &&
|
||||
updateKey.BedrockKeyConfig.ExternalID.Equals(oldRedactedKey.BedrockKeyConfig.ExternalID) {
|
||||
mergedKey.BedrockKeyConfig.ExternalID = oldRawKey.BedrockKeyConfig.ExternalID
|
||||
}
|
||||
if updateKey.BedrockKeyConfig.RoleSessionName != nil &&
|
||||
oldRedactedKey.BedrockKeyConfig.RoleSessionName != nil &&
|
||||
oldRawKey.BedrockKeyConfig != nil &&
|
||||
updateKey.BedrockKeyConfig.RoleSessionName.IsRedacted() &&
|
||||
updateKey.BedrockKeyConfig.RoleSessionName.Equals(oldRedactedKey.BedrockKeyConfig.RoleSessionName) {
|
||||
mergedKey.BedrockKeyConfig.RoleSessionName = oldRawKey.BedrockKeyConfig.RoleSessionName
|
||||
}
|
||||
}
|
||||
|
||||
if updateKey.VLLMKeyConfig != nil && oldRedactedKey.VLLMKeyConfig != nil && oldRawKey.VLLMKeyConfig != nil {
|
||||
if updateKey.VLLMKeyConfig.URL.IsRedacted() &&
|
||||
updateKey.VLLMKeyConfig.URL.Equals(&oldRedactedKey.VLLMKeyConfig.URL) {
|
||||
mergedKey.VLLMKeyConfig.URL = oldRawKey.VLLMKeyConfig.URL
|
||||
}
|
||||
}
|
||||
|
||||
// ReplicateKeyConfig has no sensitive fields — pass through as-is
|
||||
if updateKey.ReplicateKeyConfig == nil && oldRawKey.ReplicateKeyConfig != nil {
|
||||
mergedKey.ReplicateKeyConfig = oldRawKey.ReplicateKeyConfig
|
||||
}
|
||||
|
||||
if updateKey.OllamaKeyConfig != nil && oldRedactedKey.OllamaKeyConfig != nil && oldRawKey.OllamaKeyConfig != nil {
|
||||
if updateKey.OllamaKeyConfig.URL.IsRedacted() &&
|
||||
updateKey.OllamaKeyConfig.URL.Equals(&oldRedactedKey.OllamaKeyConfig.URL) {
|
||||
mergedKey.OllamaKeyConfig.URL = oldRawKey.OllamaKeyConfig.URL
|
||||
}
|
||||
}
|
||||
|
||||
if updateKey.SGLKeyConfig != nil && oldRedactedKey.SGLKeyConfig != nil && oldRawKey.SGLKeyConfig != nil {
|
||||
if updateKey.SGLKeyConfig.URL.IsRedacted() &&
|
||||
updateKey.SGLKeyConfig.URL.Equals(&oldRedactedKey.SGLKeyConfig.URL) {
|
||||
mergedKey.SGLKeyConfig.URL = oldRawKey.SGLKeyConfig.URL
|
||||
}
|
||||
}
|
||||
|
||||
mergedKey.ConfigHash = oldRawKey.ConfigHash
|
||||
mergedKey.Status = oldRawKey.Status
|
||||
|
||||
return mergedKey
|
||||
}
|
||||
|
||||
func getKeyIDFromCtx(ctx *fasthttp.RequestCtx) (string, error) {
|
||||
keyValue := ctx.UserValue("key_id")
|
||||
if keyValue == nil {
|
||||
return "", fmt.Errorf("missing key_id parameter")
|
||||
}
|
||||
|
||||
keyID, ok := keyValue.(string)
|
||||
if !ok || keyID == "" {
|
||||
return "", fmt.Errorf("invalid key_id parameter")
|
||||
}
|
||||
|
||||
decoded, err := url.PathUnescape(keyID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid key_id parameter encoding: %v", err)
|
||||
}
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// validateProviderKeyURL checks that Ollama/SGL keys have a server URL configured.
|
||||
func validateProviderKeyURL(provider schemas.ModelProvider, key schemas.Key) error {
|
||||
switch provider {
|
||||
case schemas.Ollama:
|
||||
if key.OllamaKeyConfig == nil || !key.OllamaKeyConfig.URL.IsDefined() {
|
||||
return fmt.Errorf("ollama_key_config.url is required for Ollama keys")
|
||||
}
|
||||
case schemas.SGL:
|
||||
if key.SGLKeyConfig == nil || !key.SGLKeyConfig.URL.IsDefined() {
|
||||
return fmt.Errorf("sgl_key_config.url is required for SGL keys")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
1097
transports/bifrost-http/handlers/providers.go
Normal file
1097
transports/bifrost-http/handlers/providers.go
Normal file
File diff suppressed because it is too large
Load Diff
550
transports/bifrost-http/handlers/providers_test.go
Normal file
550
transports/bifrost-http/handlers/providers_test.go
Normal file
@@ -0,0 +1,550 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/framework/modelcatalog"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// mockModelsManager returns stable filtered and unfiltered model lists for handler tests.
|
||||
type mockModelsManager struct {
|
||||
filtered map[schemas.ModelProvider][]string
|
||||
unfiltered map[schemas.ModelProvider][]string
|
||||
reloadCalls []schemas.ModelProvider
|
||||
reloadErr error
|
||||
}
|
||||
|
||||
func (m *mockModelsManager) ReloadProvider(_ context.Context, provider schemas.ModelProvider) (*configstoreTables.TableProvider, error) {
|
||||
m.reloadCalls = append(m.reloadCalls, provider)
|
||||
if m.reloadErr != nil {
|
||||
return nil, m.reloadErr
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockModelsManager) RemoveProvider(_ context.Context, _ schemas.ModelProvider) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockModelsManager) GetModelsForProvider(provider schemas.ModelProvider) []string {
|
||||
models := m.filtered[provider]
|
||||
result := make([]string, len(models))
|
||||
copy(result, models)
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *mockModelsManager) GetUnfilteredModelsForProvider(provider schemas.ModelProvider) []string {
|
||||
models := m.unfiltered[provider]
|
||||
result := make([]string, len(models))
|
||||
copy(result, models)
|
||||
return result
|
||||
}
|
||||
|
||||
// providerHandlerForTest builds a handler with fixed provider config and model sets.
|
||||
func providerHandlerForTest(provider schemas.ModelProvider, keys []schemas.Key, filtered, unfiltered []string) *ProviderHandler {
|
||||
return &ProviderHandler{
|
||||
inMemoryStore: &lib.Config{
|
||||
Providers: map[schemas.ModelProvider]configstore.ProviderConfig{
|
||||
provider: {
|
||||
Keys: keys,
|
||||
},
|
||||
},
|
||||
},
|
||||
modelsManager: &mockModelsManager{
|
||||
filtered: map[schemas.ModelProvider][]string{
|
||||
provider: filtered,
|
||||
},
|
||||
unfiltered: map[schemas.ModelProvider][]string{
|
||||
provider: unfiltered,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddProvider_ReloadsRuntimeEvenWhenModelDiscoveryIsSkipped(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
lib.SetLogger(&mockLogger{})
|
||||
|
||||
modelsManager := &mockModelsManager{}
|
||||
h := &ProviderHandler{
|
||||
inMemoryStore: &lib.Config{Providers: map[schemas.ModelProvider]configstore.ProviderConfig{}},
|
||||
modelsManager: modelsManager,
|
||||
}
|
||||
|
||||
body, err := sonic.Marshal(providerCreatePayload{
|
||||
Provider: "mock-openai",
|
||||
CustomProviderConfig: &schemas.CustomProviderConfig{
|
||||
BaseProviderType: schemas.OpenAI,
|
||||
IsKeyLess: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(fasthttp.MethodPost)
|
||||
ctx.Request.SetRequestURI("/api/providers")
|
||||
ctx.Request.SetBody(body)
|
||||
|
||||
h.addProvider(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
if len(modelsManager.reloadCalls) != 1 || modelsManager.reloadCalls[0] != "mock-openai" {
|
||||
t.Fatalf("expected provider reload for mock-openai, got %#v", modelsManager.reloadCalls)
|
||||
}
|
||||
if _, exists := h.inMemoryStore.Providers["mock-openai"]; !exists {
|
||||
t.Fatalf("expected provider to be added to in-memory store")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddProvider_ReturnsErrorWhenRuntimeReloadFails(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
lib.SetLogger(&mockLogger{})
|
||||
|
||||
modelsManager := &mockModelsManager{reloadErr: context.DeadlineExceeded}
|
||||
h := &ProviderHandler{
|
||||
inMemoryStore: &lib.Config{Providers: map[schemas.ModelProvider]configstore.ProviderConfig{}},
|
||||
modelsManager: modelsManager,
|
||||
}
|
||||
|
||||
body, err := sonic.Marshal(providerCreatePayload{
|
||||
Provider: "mock-openai",
|
||||
CustomProviderConfig: &schemas.CustomProviderConfig{
|
||||
BaseProviderType: schemas.OpenAI,
|
||||
IsKeyLess: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(fasthttp.MethodPost)
|
||||
ctx.Request.SetRequestURI("/api/providers")
|
||||
ctx.Request.SetBody(body)
|
||||
|
||||
h.addProvider(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
if len(modelsManager.reloadCalls) != 1 || modelsManager.reloadCalls[0] != "mock-openai" {
|
||||
t.Fatalf("expected single provider reload for mock-openai, got %#v", modelsManager.reloadCalls)
|
||||
}
|
||||
var bifrostErr schemas.BifrostError
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &bifrostErr); err != nil {
|
||||
t.Fatalf("failed to unmarshal error response: %v", err)
|
||||
}
|
||||
if bifrostErr.Error == nil || bifrostErr.Error.Message == "" {
|
||||
t.Fatalf("expected error message in response, got %#v", bifrostErr)
|
||||
}
|
||||
if bifrostErr.Error.Message != "Failed to initialize provider after add: context deadline exceeded" {
|
||||
t.Fatalf("unexpected error message: %q", bifrostErr.Error.Message)
|
||||
}
|
||||
if _, exists := h.inMemoryStore.Providers["mock-openai"]; exists {
|
||||
t.Fatalf("expected provider rollback after reload failure")
|
||||
}
|
||||
}
|
||||
|
||||
// boolPtr keeps pointer-valued key fixtures inline without pulling in pointer helpers.
|
||||
func boolPtr(v bool) *bool {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestListModels_UnknownKeysDoNotFilter(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{{ID: "key-a"}},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models?provider=openai&keys=missing")
|
||||
|
||||
h.listModels(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 2 {
|
||||
t.Fatalf("expected total=2, got %d", resp.Total)
|
||||
}
|
||||
if len(resp.Models) != 2 {
|
||||
t.Fatalf("expected all models to be returned, got %#v", resp.Models)
|
||||
}
|
||||
for _, model := range resp.Models {
|
||||
if len(model.AccessibleByKeys) != 0 {
|
||||
t.Fatalf("expected no accessible_by_keys annotations, got %#v", resp.Models)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModels_ReturnsExactAccessibleByKeysAndSkipsDisabledKeys(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{
|
||||
{ID: "key-a", Models: []string{"gpt-4o"}},
|
||||
{ID: "key-b", Models: []string{"gpt-4o", "gpt-4o-mini"}},
|
||||
{ID: "key-disabled", Enabled: boolPtr(false)},
|
||||
},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models?provider=openai&keys=key-a,key-b,key-disabled")
|
||||
|
||||
h.listModels(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 2 {
|
||||
t.Fatalf("expected total=2, got %d", resp.Total)
|
||||
}
|
||||
|
||||
got := map[string][]string{}
|
||||
for _, model := range resp.Models {
|
||||
got[model.Name] = model.AccessibleByKeys
|
||||
}
|
||||
|
||||
if len(got["gpt-4o"]) != 2 || got["gpt-4o"][0] != "key-a" || got["gpt-4o"][1] != "key-b" {
|
||||
t.Fatalf("expected gpt-4o to be accessible by [key-a key-b], got %#v", got["gpt-4o"])
|
||||
}
|
||||
if len(got["gpt-4o-mini"]) != 1 || got["gpt-4o-mini"][0] != "key-b" {
|
||||
t.Fatalf("expected gpt-4o-mini to be accessible by [key-b], got %#v", got["gpt-4o-mini"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModels_AppliesQueryAndLimitAfterFiltering(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{{ID: "key-a"}},
|
||||
[]string{"gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet"},
|
||||
[]string{"gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet"},
|
||||
)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models?provider=openai&query=gpt&limit=1")
|
||||
|
||||
h.listModels(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 2 {
|
||||
t.Fatalf("expected total=2 after query filtering, got %d", resp.Total)
|
||||
}
|
||||
if len(resp.Models) != 1 {
|
||||
t.Fatalf("expected limit to truncate response to 1 model, got %#v", resp.Models)
|
||||
}
|
||||
if resp.Models[0].Name != "gpt-4o" {
|
||||
t.Fatalf("expected first filtered model to be gpt-4o, got %#v", resp.Models[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModels_UnfilteredIgnoresKeys(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{
|
||||
{ID: "key-b", Models: []string{"gpt-4o-mini"}},
|
||||
},
|
||||
[]string{"gpt-4o"},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models?provider=openai&keys=key-b&unfiltered=true")
|
||||
|
||||
h.listModels(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 2 || len(resp.Models) != 2 {
|
||||
t.Fatalf("expected both unfiltered models, got %#v", resp.Models)
|
||||
}
|
||||
|
||||
for _, model := range resp.Models {
|
||||
if len(model.AccessibleByKeys) != 0 {
|
||||
t.Fatalf("expected no accessible_by_keys when unfiltered bypasses key filtering, got %#v", resp.Models)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModels_UnfilteredWithoutKeysReturnsAllUnfilteredModels(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{
|
||||
{ID: "key-b", Models: []string{"gpt-4o-mini"}},
|
||||
},
|
||||
[]string{"gpt-4o"},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models?provider=openai&unfiltered=true")
|
||||
|
||||
h.listModels(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 2 || len(resp.Models) != 2 {
|
||||
t.Fatalf("expected both unfiltered models, got %#v", resp.Models)
|
||||
}
|
||||
|
||||
for _, model := range resp.Models {
|
||||
if len(model.AccessibleByKeys) != 0 {
|
||||
t.Fatalf("expected no accessible_by_keys when no key filter is requested, got %#v", resp.Models)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModelDetails_ErrorsWhenModelCatalogUnavailable(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{{ID: "key-a"}},
|
||||
[]string{"gpt-4o"},
|
||||
[]string{"gpt-4o"},
|
||||
)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models/details?provider=openai")
|
||||
|
||||
h.listModelDetails(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModelDetails_UnknownKeysDoNotFilter(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{{ID: "key-a"}},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
)
|
||||
h.inMemoryStore.ModelCatalog = &modelcatalog.ModelCatalog{}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models/details?provider=openai&keys=missing")
|
||||
|
||||
h.listModelDetails(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelDetailsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 2 || len(resp.Models) != 2 {
|
||||
t.Fatalf("expected all models when keys are unknown, got %#v", resp.Models)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModelDetails_SkipsUnknownKeysAndFiltersWithValid(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{{ID: "key-a", Models: []string{"gpt-4o"}}},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
)
|
||||
h.inMemoryStore.ModelCatalog = &modelcatalog.ModelCatalog{}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models/details?provider=openai&keys=key-a,missing")
|
||||
|
||||
h.listModelDetails(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelDetailsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 1 || len(resp.Models) != 1 {
|
||||
t.Fatalf("expected 1 model filtered by valid key, got %#v", resp.Models)
|
||||
}
|
||||
if resp.Models[0].Name != "gpt-4o" {
|
||||
t.Fatalf("expected gpt-4o, got %s", resp.Models[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModelDetails_SkipsDisabledKeysAndFiltersWithValid(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{
|
||||
{ID: "key-a", Models: []string{"gpt-4o"}},
|
||||
{ID: "key-disabled", Enabled: boolPtr(false)},
|
||||
},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
)
|
||||
h.inMemoryStore.ModelCatalog = &modelcatalog.ModelCatalog{}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models/details?provider=openai&keys=key-a,key-disabled")
|
||||
|
||||
h.listModelDetails(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelDetailsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 1 || len(resp.Models) != 1 {
|
||||
t.Fatalf("expected 1 model filtered by valid key, got %#v", resp.Models)
|
||||
}
|
||||
if resp.Models[0].Name != "gpt-4o" {
|
||||
t.Fatalf("expected gpt-4o, got %s", resp.Models[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModelDetails_UnfilteredIgnoresKeys(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{
|
||||
{ID: "key-b", Models: []string{"gpt-4o-mini"}},
|
||||
},
|
||||
[]string{"gpt-4o"},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
)
|
||||
h.inMemoryStore.ModelCatalog = &modelcatalog.ModelCatalog{}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models/details?provider=openai&keys=key-b&unfiltered=true")
|
||||
|
||||
h.listModelDetails(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelDetailsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 2 || len(resp.Models) != 2 {
|
||||
t.Fatalf("expected all unfiltered models when unfiltered=true, got %#v", resp.Models)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModels_UsesCatalogAwareAliasMatchingForKeyAllowlist(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{
|
||||
{ID: "key-a", Models: []string{"gpt-4o-2024-08-06"}},
|
||||
},
|
||||
[]string{"gpt-4o"},
|
||||
[]string{"gpt-4o"},
|
||||
)
|
||||
h.inMemoryStore.ModelCatalog = modelcatalog.NewTestCatalog(map[string]string{
|
||||
"gpt-4o-2024-08-06": "gpt-4o",
|
||||
})
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models?provider=openai&keys=key-a")
|
||||
|
||||
h.listModels(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 1 || len(resp.Models) != 1 || resp.Models[0].Name != "gpt-4o" {
|
||||
t.Fatalf("expected gpt-4o to be matched through alias allowlist, got %#v", resp.Models)
|
||||
}
|
||||
}
|
||||
419
transports/bifrost-http/handlers/realtime_client_secrets.go
Normal file
419
transports/bifrost-http/handlers/realtime_client_secrets.go
Normal file
@@ -0,0 +1,419 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"mime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/plugins/governance"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/integrations"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// RealtimeClientSecretsHandler exposes OpenAI-compatible HTTP routes for
|
||||
// minting short-lived Realtime client secrets.
|
||||
type RealtimeClientSecretsHandler struct {
|
||||
client *bifrost.Bifrost
|
||||
config *lib.Config
|
||||
handlerStore lib.HandlerStore
|
||||
routeSpecs map[string]schemas.RealtimeSessionRoute
|
||||
}
|
||||
|
||||
func NewRealtimeClientSecretsHandler(client *bifrost.Bifrost, config *lib.Config) *RealtimeClientSecretsHandler {
|
||||
return &RealtimeClientSecretsHandler{
|
||||
client: client,
|
||||
config: config,
|
||||
handlerStore: config,
|
||||
routeSpecs: make(map[string]schemas.RealtimeSessionRoute),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *RealtimeClientSecretsHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
handler := lib.ChainMiddlewares(h.handleRequest, middlewares...)
|
||||
for _, route := range h.realtimeSessionRoutes() {
|
||||
h.routeSpecs[route.Path] = route
|
||||
r.POST(route.Path, handler)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *RealtimeClientSecretsHandler) findGovernancePlugin() governance.BaseGovernancePlugin {
|
||||
basePlugins := h.config.BasePlugins.Load()
|
||||
if basePlugins == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, plugin := range *basePlugins {
|
||||
if governancePlugin, ok := plugin.(governance.BaseGovernancePlugin); ok {
|
||||
return governancePlugin
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *RealtimeClientSecretsHandler) handleRequest(ctx *fasthttp.RequestCtx) {
|
||||
if !isJSONContentType(string(ctx.Request.Header.ContentType())) {
|
||||
SendBifrostError(ctx, newRealtimeClientSecretHandlerError(
|
||||
fasthttp.StatusBadRequest,
|
||||
"invalid_request_error",
|
||||
"Content-Type must be application/json",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
body := append([]byte(nil), ctx.Request.Body()...)
|
||||
route, ok := h.routeSpecs[string(ctx.Path())]
|
||||
if !ok {
|
||||
SendBifrostError(ctx, newRealtimeClientSecretHandlerError(
|
||||
fasthttp.StatusNotFound,
|
||||
"invalid_request_error",
|
||||
"unsupported realtime client secret route",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
providerKey, model, normalizedBody, err := resolveRealtimeClientSecretTarget(route, body)
|
||||
if err != nil {
|
||||
SendBifrostError(ctx, err)
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(
|
||||
ctx,
|
||||
h.handlerStore.ShouldAllowDirectKeys(),
|
||||
h.config.GetHeaderMatcher(),
|
||||
h.config.GetMCPHeaderCombinedAllowlist(),
|
||||
)
|
||||
defer cancel()
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest)
|
||||
if route.DefaultProvider == schemas.OpenAI {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyIntegrationType, "openai")
|
||||
}
|
||||
if governanceUserID, ok := ctx.UserValue(schemas.BifrostContextKeyUserID).(string); ok && governanceUserID != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyUserID, governanceUserID)
|
||||
}
|
||||
if userName, ok := ctx.UserValue(schemas.BifrostContextKeyUserName).(string); ok && userName != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyUserName, userName)
|
||||
}
|
||||
if bifrostErr := h.evaluateMintingGovernance(bifrostCtx, providerKey, model); bifrostErr != nil {
|
||||
SendBifrostError(ctx, bifrostErr)
|
||||
return
|
||||
}
|
||||
|
||||
provider := h.client.GetProviderByKey(providerKey)
|
||||
if provider == nil {
|
||||
SendBifrostError(ctx, newRealtimeClientSecretHandlerError(
|
||||
fasthttp.StatusBadRequest,
|
||||
"invalid_request_error",
|
||||
"provider not found: "+string(providerKey),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
key, keyErr := h.client.SelectKeyForProviderRequestType(bifrostCtx, schemas.RealtimeRequest, providerKey, model)
|
||||
if keyErr != nil {
|
||||
SendBifrostError(ctx, newRealtimeClientSecretHandlerError(
|
||||
fasthttp.StatusBadRequest,
|
||||
"invalid_request_error",
|
||||
keyErr.Error(),
|
||||
keyErr,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve model aliases now that the key is selected so the forwarded body
|
||||
// carries the provider's canonical model, matching wsrealtime/webrtc flows.
|
||||
if resolved := key.Aliases.Resolve(model); resolved != "" && resolved != model {
|
||||
model = resolved
|
||||
reparsed, parseErr := schemas.ParseRealtimeClientSecretBody(normalizedBody)
|
||||
if parseErr != nil {
|
||||
SendBifrostError(ctx, parseErr)
|
||||
return
|
||||
}
|
||||
rewritten, normalizeErr := normalizeRealtimeClientSecretBody(reparsed, model)
|
||||
if normalizeErr != nil {
|
||||
SendBifrostError(ctx, normalizeErr)
|
||||
return
|
||||
}
|
||||
normalizedBody = rewritten
|
||||
}
|
||||
|
||||
sessionProvider, ok := provider.(schemas.RealtimeSessionProvider)
|
||||
if !ok {
|
||||
SendBifrostError(ctx, realtimeSessionNotSupportedError(providerKey, provider))
|
||||
return
|
||||
}
|
||||
|
||||
resp, bifrostErr := sessionProvider.CreateRealtimeClientSecret(bifrostCtx, key, route.EndpointType, normalizedBody)
|
||||
if bifrostErr != nil {
|
||||
SendBifrostError(ctx, bifrostErr)
|
||||
return
|
||||
}
|
||||
cacheRealtimeEphemeralKeyMapping(
|
||||
h.handlerStore.GetKVStore(),
|
||||
resp.Body,
|
||||
key.ID,
|
||||
bifrost.GetStringFromContext(bifrostCtx, schemas.BifrostContextKeyVirtualKey),
|
||||
)
|
||||
|
||||
writeRealtimeClientSecretResponse(ctx, resp)
|
||||
}
|
||||
|
||||
func (h *RealtimeClientSecretsHandler) evaluateMintingGovernance(
|
||||
bifrostCtx *schemas.BifrostContext,
|
||||
providerKey schemas.ModelProvider,
|
||||
model string,
|
||||
) *schemas.BifrostError {
|
||||
governancePlugin := h.findGovernancePlugin()
|
||||
if governancePlugin == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, bifrostErr := governancePlugin.EvaluateGovernanceRequest(bifrostCtx, &governance.EvaluationRequest{
|
||||
VirtualKey: bifrost.GetStringFromContext(bifrostCtx, schemas.BifrostContextKeyVirtualKey),
|
||||
Provider: providerKey,
|
||||
Model: model,
|
||||
UserID: bifrost.GetStringFromContext(bifrostCtx, schemas.BifrostContextKeyUserID),
|
||||
}, schemas.RealtimeRequest)
|
||||
return bifrostErr
|
||||
}
|
||||
|
||||
func (h *RealtimeClientSecretsHandler) realtimeSessionRoutes() []schemas.RealtimeSessionRoute {
|
||||
routes := []schemas.RealtimeSessionRoute{
|
||||
{
|
||||
Path: "/v1/realtime/client_secrets",
|
||||
EndpointType: schemas.RealtimeSessionEndpointClientSecrets,
|
||||
},
|
||||
{
|
||||
Path: "/v1/realtime/sessions",
|
||||
EndpointType: schemas.RealtimeSessionEndpointSessions,
|
||||
},
|
||||
}
|
||||
|
||||
for _, path := range integrations.OpenAIRealtimeClientSecretPaths("/openai") {
|
||||
endpointType := schemas.RealtimeSessionEndpointClientSecrets
|
||||
if strings.HasSuffix(path, "/realtime/sessions") {
|
||||
endpointType = schemas.RealtimeSessionEndpointSessions
|
||||
}
|
||||
routes = append(routes, schemas.RealtimeSessionRoute{
|
||||
Path: path,
|
||||
EndpointType: endpointType,
|
||||
DefaultProvider: schemas.OpenAI,
|
||||
})
|
||||
}
|
||||
return routes
|
||||
}
|
||||
|
||||
func resolveRealtimeClientSecretTarget(route schemas.RealtimeSessionRoute, body []byte) (schemas.ModelProvider, string, []byte, *schemas.BifrostError) {
|
||||
root, err := schemas.ParseRealtimeClientSecretBody(body)
|
||||
if err != nil {
|
||||
return "", "", nil, err
|
||||
}
|
||||
|
||||
rawModel, err := schemas.ExtractRealtimeClientSecretModel(root)
|
||||
if err != nil {
|
||||
return "", "", nil, err
|
||||
}
|
||||
|
||||
defaultProvider := route.DefaultProvider
|
||||
providerKey, model := schemas.ParseModelString(rawModel, defaultProvider)
|
||||
if defaultProvider == "" && providerKey == "" {
|
||||
return "", "", nil, newRealtimeClientSecretHandlerError(
|
||||
fasthttp.StatusBadRequest,
|
||||
"invalid_request_error",
|
||||
"session.model must use provider/model on /v1 realtime client secret routes",
|
||||
nil,
|
||||
)
|
||||
}
|
||||
if providerKey == "" || model == "" {
|
||||
return "", "", nil, newRealtimeClientSecretHandlerError(
|
||||
fasthttp.StatusBadRequest,
|
||||
"invalid_request_error",
|
||||
"session.model is required",
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
// Normalize the forwarded body so the upstream provider sees the bare model
|
||||
// (strip provider prefix). Mirrors resolveRealtimeSDPTarget normalization.
|
||||
normalizedBody, normalizeErr := normalizeRealtimeClientSecretBody(root, model)
|
||||
if normalizeErr != nil {
|
||||
return "", "", nil, normalizeErr
|
||||
}
|
||||
|
||||
return providerKey, model, normalizedBody, nil
|
||||
}
|
||||
|
||||
func normalizeRealtimeClientSecretBody(root map[string]json.RawMessage, bareModel string) ([]byte, *schemas.BifrostError) {
|
||||
normalizedModel, marshalErr := json.Marshal(bareModel)
|
||||
if marshalErr != nil {
|
||||
return nil, newRealtimeClientSecretHandlerError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized model", marshalErr)
|
||||
}
|
||||
|
||||
// Normalize session.model if present
|
||||
if sessionJSON, ok := root["session"]; ok && len(sessionJSON) > 0 {
|
||||
var session map[string]json.RawMessage
|
||||
if err := json.Unmarshal(sessionJSON, &session); err == nil {
|
||||
if _, hasModel := session["model"]; hasModel {
|
||||
session["model"] = normalizedModel
|
||||
rewritten, err := json.Marshal(session)
|
||||
if err != nil {
|
||||
return nil, newRealtimeClientSecretHandlerError(fasthttp.StatusInternalServerError, "server_error", "failed to re-encode session", err)
|
||||
}
|
||||
root["session"] = rewritten
|
||||
}
|
||||
}
|
||||
}
|
||||
// Normalize top-level model if present
|
||||
if _, ok := root["model"]; ok {
|
||||
root["model"] = normalizedModel
|
||||
}
|
||||
|
||||
normalized, marshalErr := json.Marshal(root)
|
||||
if marshalErr != nil {
|
||||
return nil, newRealtimeClientSecretHandlerError(fasthttp.StatusInternalServerError, "server_error", "failed to re-encode body", marshalErr)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
const realtimeEphemeralKeyMappingPrefix = "realtime:ephemeral-key:"
|
||||
|
||||
type realtimeEphemeralKeyMapping struct {
|
||||
KeyID string `json:"key_id,omitempty"`
|
||||
VirtualKey string `json:"virtual_key,omitempty"`
|
||||
}
|
||||
|
||||
func cacheRealtimeEphemeralKeyMapping(kv schemas.KVStore, body []byte, keyID string, virtualKey string) {
|
||||
if kv == nil || len(body) == 0 || strings.TrimSpace(keyID) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
token, ttl, ok := parseRealtimeEphemeralKeyMapping(body)
|
||||
if !ok || strings.TrimSpace(token) == "" || ttl <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(realtimeEphemeralKeyMapping{
|
||||
KeyID: strings.TrimSpace(keyID),
|
||||
VirtualKey: strings.TrimSpace(virtualKey),
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn("failed to encode realtime ephemeral key mapping for key_id=%s: %v", keyID, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := kv.SetWithTTL(buildRealtimeEphemeralKeyMappingKey(token), payload, ttl); err != nil {
|
||||
logger.Warn("failed to cache realtime ephemeral key mapping for key_id=%s: %v", keyID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func parseRealtimeEphemeralKeyMapping(body []byte) (string, time.Duration, bool) {
|
||||
var root map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &root); err != nil {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
var clientSecret struct {
|
||||
Value string `json:"value"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
}
|
||||
|
||||
// OpenAI client_secrets responses expose the ephemeral token at the top level.
|
||||
// Keep accepting the nested shape too so the mapping logic stays compatible
|
||||
// with any provider/session endpoint variants that wrap the secret object.
|
||||
if err := json.Unmarshal(body, &clientSecret); err != nil || strings.TrimSpace(clientSecret.Value) == "" || clientSecret.ExpiresAt <= 0 {
|
||||
clientSecretRaw, ok := root["client_secret"]
|
||||
if !ok || len(clientSecretRaw) == 0 || string(clientSecretRaw) == "null" {
|
||||
return "", 0, false
|
||||
}
|
||||
if err := json.Unmarshal(clientSecretRaw, &clientSecret); err != nil {
|
||||
return "", 0, false
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(clientSecret.Value) == "" || clientSecret.ExpiresAt <= 0 {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
ttl := time.Until(time.Unix(clientSecret.ExpiresAt, 0))
|
||||
if ttl <= 0 {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
return clientSecret.Value, ttl, true
|
||||
}
|
||||
|
||||
func buildRealtimeEphemeralKeyMappingKey(token string) string {
|
||||
return realtimeEphemeralKeyMappingPrefix + strings.TrimSpace(token)
|
||||
}
|
||||
|
||||
func realtimeSessionNotSupportedError(providerKey schemas.ModelProvider, provider schemas.Provider) *schemas.BifrostError {
|
||||
if rtProvider, ok := provider.(schemas.RealtimeProvider); ok && rtProvider.SupportsRealtimeAPI() {
|
||||
return newRealtimeClientSecretHandlerError(
|
||||
fasthttp.StatusBadRequest,
|
||||
"invalid_request_error",
|
||||
fmt.Sprintf("provider %s supports realtime websocket connections but not realtime client secret creation", providerKey),
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
return newRealtimeClientSecretHandlerError(
|
||||
fasthttp.StatusBadRequest,
|
||||
"invalid_request_error",
|
||||
fmt.Sprintf("provider %s does not support realtime client secret creation", providerKey),
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func newRealtimeClientSecretHandlerError(status int, errorType, message string, err error) *schemas.BifrostError {
|
||||
return &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
StatusCode: schemas.Ptr(status),
|
||||
Error: &schemas.ErrorField{
|
||||
Type: schemas.Ptr(errorType),
|
||||
Message: message,
|
||||
Error: err,
|
||||
},
|
||||
ExtraFields: schemas.BifrostErrorExtraFields{
|
||||
RequestType: schemas.RealtimeRequest,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func writeRealtimeClientSecretResponse(ctx *fasthttp.RequestCtx, resp *schemas.BifrostPassthroughResponse) {
|
||||
if resp == nil {
|
||||
SendBifrostError(ctx, newRealtimeClientSecretHandlerError(
|
||||
fasthttp.StatusInternalServerError,
|
||||
"server_error",
|
||||
"provider returned an empty realtime client secret response",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
for key, value := range resp.Headers {
|
||||
ctx.Response.Header.Set(key, value)
|
||||
}
|
||||
if len(ctx.Response.Header.ContentType()) == 0 {
|
||||
ctx.SetContentType("application/json")
|
||||
}
|
||||
ctx.SetStatusCode(resp.StatusCode)
|
||||
ctx.SetBody(resp.Body)
|
||||
}
|
||||
|
||||
func isJSONContentType(contentType string) bool {
|
||||
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
mediaType = strings.ToLower(mediaType)
|
||||
return mediaType == "application/json" || strings.HasSuffix(mediaType, "+json")
|
||||
}
|
||||
414
transports/bifrost-http/handlers/realtime_client_secrets_test.go
Normal file
414
transports/bifrost-http/handlers/realtime_client_secrets_test.go
Normal file
@@ -0,0 +1,414 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/kvstore"
|
||||
"github.com/maximhq/bifrost/plugins/governance"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func TestResolveRealtimeClientSecretTarget(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
route schemas.RealtimeSessionRoute
|
||||
body []byte
|
||||
wantProvider schemas.ModelProvider
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "base route with session model",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets},
|
||||
body: []byte(`{"session":{"model":"openai/gpt-4o-realtime-preview"}}`),
|
||||
wantProvider: schemas.OpenAI,
|
||||
wantModel: "gpt-4o-realtime-preview",
|
||||
},
|
||||
{
|
||||
name: "base route with top level model",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/sessions", EndpointType: schemas.RealtimeSessionEndpointSessions},
|
||||
body: []byte(`{"model":"openai/gpt-4o-realtime-preview"}`),
|
||||
wantProvider: schemas.OpenAI,
|
||||
wantModel: "gpt-4o-realtime-preview",
|
||||
},
|
||||
{
|
||||
name: "openai alias uses bare model",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/openai/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets, DefaultProvider: schemas.OpenAI},
|
||||
body: []byte(`{"session":{"model":"gpt-4o-realtime-preview"}}`),
|
||||
wantProvider: schemas.OpenAI,
|
||||
wantModel: "gpt-4o-realtime-preview",
|
||||
},
|
||||
{
|
||||
name: "base route rejects bare model",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets},
|
||||
body: []byte(`{"session":{"model":"gpt-4o-realtime-preview"}}`),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing model",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/openai/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets, DefaultProvider: schemas.OpenAI},
|
||||
body: []byte(`{"session":{}}`),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gotProvider, gotModel, _, err := resolveRealtimeClientSecretTarget(tt.route, tt.body)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("resolveRealtimeClientSecretTarget() error = %v", err)
|
||||
}
|
||||
if gotProvider != tt.wantProvider {
|
||||
t.Fatalf("provider = %q, want %q", gotProvider, tt.wantProvider)
|
||||
}
|
||||
if gotModel != tt.wantModel {
|
||||
t.Fatalf("model = %q, want %q", gotModel, tt.wantModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRealtimeClientSecretTarget_NormalizesModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
route schemas.RealtimeSessionRoute
|
||||
body string
|
||||
wantModel string // bare model expected in normalized body
|
||||
}{
|
||||
{
|
||||
name: "session.model provider prefix stripped",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets},
|
||||
body: `{"session":{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}}`,
|
||||
wantModel: "gpt-4o-realtime-preview",
|
||||
},
|
||||
{
|
||||
name: "top-level model provider prefix stripped",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/sessions", EndpointType: schemas.RealtimeSessionEndpointSessions},
|
||||
body: `{"model":"openai/gpt-4o-realtime-preview"}`,
|
||||
wantModel: "gpt-4o-realtime-preview",
|
||||
},
|
||||
{
|
||||
name: "bare model unchanged on alias route",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/openai/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets, DefaultProvider: schemas.OpenAI},
|
||||
body: `{"session":{"model":"gpt-4o-realtime-preview"}}`,
|
||||
wantModel: "gpt-4o-realtime-preview",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, normalizedBody, err := resolveRealtimeClientSecretTarget(tt.route, []byte(tt.body))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var root map[string]json.RawMessage
|
||||
if unmarshalErr := json.Unmarshal(normalizedBody, &root); unmarshalErr != nil {
|
||||
t.Fatalf("failed to unmarshal normalized body: %v", unmarshalErr)
|
||||
}
|
||||
|
||||
// Check session.model if present
|
||||
if sessionJSON, ok := root["session"]; ok {
|
||||
var session map[string]json.RawMessage
|
||||
if unmarshalErr := json.Unmarshal(sessionJSON, &session); unmarshalErr != nil {
|
||||
t.Fatalf("failed to unmarshal session: %v", unmarshalErr)
|
||||
}
|
||||
if modelJSON, ok := session["model"]; ok {
|
||||
var model string
|
||||
if unmarshalErr := json.Unmarshal(modelJSON, &model); unmarshalErr != nil {
|
||||
t.Fatalf("failed to unmarshal session.model: %v", unmarshalErr)
|
||||
}
|
||||
if model != tt.wantModel {
|
||||
t.Fatalf("session.model = %q, want %q", model, tt.wantModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check top-level model if present
|
||||
if modelJSON, ok := root["model"]; ok {
|
||||
var model string
|
||||
if unmarshalErr := json.Unmarshal(modelJSON, &model); unmarshalErr != nil {
|
||||
t.Fatalf("failed to unmarshal model: %v", unmarshalErr)
|
||||
}
|
||||
if model != tt.wantModel {
|
||||
t.Fatalf("model = %q, want %q", model, tt.wantModel)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRealtimeEphemeralKeyMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
token, ttl, ok := parseRealtimeEphemeralKeyMapping([]byte(`{
|
||||
"value": "ek_test_123",
|
||||
"expires_at": 4102444800
|
||||
}`))
|
||||
if !ok {
|
||||
t.Fatal("expected ephemeral mapping to be parsed")
|
||||
}
|
||||
if token != "ek_test_123" {
|
||||
t.Fatalf("token = %q, want %q", token, "ek_test_123")
|
||||
}
|
||||
if ttl <= 0 {
|
||||
t.Fatalf("ttl = %v, want > 0", ttl)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRealtimeEphemeralKeyMapping_NestedFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
token, ttl, ok := parseRealtimeEphemeralKeyMapping([]byte(`{
|
||||
"client_secret": {
|
||||
"value": "ek_test_nested",
|
||||
"expires_at": 4102444800
|
||||
}
|
||||
}`))
|
||||
if !ok {
|
||||
t.Fatal("expected nested ephemeral mapping to be parsed")
|
||||
}
|
||||
if token != "ek_test_nested" {
|
||||
t.Fatalf("token = %q, want %q", token, "ek_test_nested")
|
||||
}
|
||||
if ttl <= 0 {
|
||||
t.Fatalf("ttl = %v, want > 0", ttl)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheRealtimeEphemeralKeyMappingStoresKeyID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := kvstore.New(kvstore.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("kvstore.New() error = %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
body := []byte(`{
|
||||
"value": "ek_test_456",
|
||||
"expires_at": ` + "4102444800" + `
|
||||
}`)
|
||||
cacheRealtimeEphemeralKeyMapping(store, body, "key_123", "sk-bf-test")
|
||||
|
||||
raw, err := store.Get(buildRealtimeEphemeralKeyMappingKey("ek_test_456"))
|
||||
if err != nil {
|
||||
t.Fatalf("store.Get() error = %v", err)
|
||||
}
|
||||
value, ok := raw.([]byte)
|
||||
if !ok {
|
||||
t.Fatalf("cached value type = %T, want []byte", raw)
|
||||
}
|
||||
var mapping realtimeEphemeralKeyMapping
|
||||
if err := json.Unmarshal(value, &mapping); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
if mapping.KeyID != "key_123" {
|
||||
t.Fatalf("mapping.KeyID = %q, want %q", mapping.KeyID, "key_123")
|
||||
}
|
||||
if mapping.VirtualKey != "sk-bf-test" {
|
||||
t.Fatalf("mapping.VirtualKey = %q, want %q", mapping.VirtualKey, "sk-bf-test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheRealtimeEphemeralKeyMappingSkipsExpiredSecrets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := kvstore.New(kvstore.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("kvstore.New() error = %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
expired := time.Now().Add(-time.Minute).Unix()
|
||||
body := fmt.Appendf(nil, `{
|
||||
"value": "ek_expired",
|
||||
"expires_at": %d
|
||||
}`, expired)
|
||||
cacheRealtimeEphemeralKeyMapping(store, body, "key_123", "")
|
||||
|
||||
if _, err := store.Get(buildRealtimeEphemeralKeyMappingKey("ek_expired")); err == nil {
|
||||
t.Fatal("expected no cached mapping for expired token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsJSONContentType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if !isJSONContentType("application/json; charset=utf-8") {
|
||||
t.Fatal("expected application/json content type to pass")
|
||||
}
|
||||
if !isJSONContentType("application/vnd.openai+json") {
|
||||
t.Fatal("expected +json content type to pass")
|
||||
}
|
||||
if isJSONContentType("text/plain") {
|
||||
t.Fatal("expected text/plain content type to fail")
|
||||
}
|
||||
}
|
||||
|
||||
type mockRealtimeMintingGovernancePlugin struct {
|
||||
err *schemas.BifrostError
|
||||
seenUserID string
|
||||
seenVirtualKey string
|
||||
seenProvider schemas.ModelProvider
|
||||
seenModel string
|
||||
evaluateCalls int
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) GetName() string {
|
||||
return governance.PluginName
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) EvaluateGovernanceRequest(ctx *schemas.BifrostContext, evaluationRequest *governance.EvaluationRequest, _ schemas.RequestType) (*governance.EvaluationResult, *schemas.BifrostError) {
|
||||
m.evaluateCalls++
|
||||
m.seenUserID = ""
|
||||
m.seenVirtualKey = ""
|
||||
m.seenProvider = ""
|
||||
m.seenModel = ""
|
||||
if evaluationRequest != nil {
|
||||
m.seenUserID = evaluationRequest.UserID
|
||||
m.seenVirtualKey = evaluationRequest.VirtualKey
|
||||
m.seenProvider = evaluationRequest.Provider
|
||||
m.seenModel = evaluationRequest.Model
|
||||
}
|
||||
if ctx != nil && m.seenVirtualKey == "" {
|
||||
m.seenVirtualKey = bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey)
|
||||
}
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return &governance.EvaluationResult{Decision: governance.DecisionAllow}, nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) HTTPTransportPreHook(_ *schemas.BifrostContext, _ *schemas.HTTPRequest) (*schemas.HTTPResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) HTTPTransportPostHook(_ *schemas.BifrostContext, _ *schemas.HTTPRequest, _ *schemas.HTTPResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) PreLLMHook(_ *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) {
|
||||
return req, nil, nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) PostLLMHook(_ *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) {
|
||||
return result, bifrostErr, nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) PreMCPHook(_ *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, error) {
|
||||
return req, nil, nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) PostMCPHook(_ *schemas.BifrostContext, resp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostMCPResponse, *schemas.BifrostError, error) {
|
||||
return resp, bifrostErr, nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) Cleanup() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) GetGovernanceStore() governance.GovernanceStore {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRealtimeClientSecretsEvaluateMintingGovernance_RequiresAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &lib.Config{}
|
||||
plugin := &mockRealtimeMintingGovernancePlugin{
|
||||
err: &schemas.BifrostError{
|
||||
Type: schemas.Ptr("virtual_key_required"),
|
||||
StatusCode: schemas.Ptr(401),
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "virtual key is required. Provide a virtual key via the x-bf-vk header.",
|
||||
},
|
||||
},
|
||||
}
|
||||
plugins := []schemas.BasePlugin{plugin}
|
||||
config.BasePlugins.Store(&plugins)
|
||||
|
||||
handler := NewRealtimeClientSecretsHandler(nil, config)
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
defer bifrostCtx.Done()
|
||||
|
||||
err := handler.evaluateMintingGovernance(bifrostCtx, schemas.OpenAI, "gpt-realtime")
|
||||
if err == nil {
|
||||
t.Fatal("expected governance error")
|
||||
}
|
||||
if err.StatusCode == nil {
|
||||
t.Fatal("expected status code")
|
||||
}
|
||||
if got, want := *err.StatusCode, fasthttp.StatusUnauthorized; got != want {
|
||||
t.Fatalf("status = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealtimeClientSecretsEvaluateMintingGovernance_PassesContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &lib.Config{}
|
||||
plugin := &mockRealtimeMintingGovernancePlugin{}
|
||||
plugins := []schemas.BasePlugin{
|
||||
plugin,
|
||||
}
|
||||
config.BasePlugins.Store(&plugins)
|
||||
|
||||
handler := NewRealtimeClientSecretsHandler(nil, config)
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
defer bifrostCtx.Done()
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyUserID, "user_123")
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, "sk-bf-123")
|
||||
|
||||
if err := handler.evaluateMintingGovernance(bifrostCtx, schemas.OpenAI, "gpt-realtime"); err != nil {
|
||||
t.Fatalf("unexpected governance error: %v", err)
|
||||
}
|
||||
if plugin.evaluateCalls != 1 {
|
||||
t.Fatalf("evaluate calls = %d, want 1", plugin.evaluateCalls)
|
||||
}
|
||||
if plugin.seenUserID != "user_123" {
|
||||
t.Fatalf("governance user id = %q, want %q", plugin.seenUserID, "user_123")
|
||||
}
|
||||
if plugin.seenVirtualKey != "sk-bf-123" {
|
||||
t.Fatalf("virtual key = %q, want %q", plugin.seenVirtualKey, "sk-bf-123")
|
||||
}
|
||||
if plugin.seenProvider != schemas.OpenAI {
|
||||
t.Fatalf("provider = %q, want %q", plugin.seenProvider, schemas.OpenAI)
|
||||
}
|
||||
if plugin.seenModel != "gpt-realtime" {
|
||||
t.Fatalf("model = %q, want %q", plugin.seenModel, "gpt-realtime")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealtimeClientSecretsEvaluateMintingGovernance_ContinuesWithoutGovernance(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewRealtimeClientSecretsHandler(nil, &lib.Config{})
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
defer bifrostCtx.Done()
|
||||
|
||||
if err := handler.evaluateMintingGovernance(bifrostCtx, schemas.OpenAI, "gpt-realtime"); err != nil {
|
||||
t.Fatalf("unexpected governance error without plugin: %v", err)
|
||||
}
|
||||
}
|
||||
441
transports/bifrost-http/handlers/realtime_logging.go
Normal file
441
transports/bifrost-http/handlers/realtime_logging.go
Normal file
@@ -0,0 +1,441 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
|
||||
)
|
||||
|
||||
type realtimeTurnSource string
|
||||
|
||||
const (
|
||||
realtimeTurnSourceEI realtimeTurnSource = "ei"
|
||||
realtimeTurnSourceLM realtimeTurnSource = "lm"
|
||||
)
|
||||
|
||||
const (
|
||||
realtimeMissingTranscriptText = "[Audio transcription unavailable]"
|
||||
)
|
||||
|
||||
func extractRealtimeTurnSummary(event *schemas.BifrostRealtimeEvent, contentOverride string) string {
|
||||
if strings.TrimSpace(contentOverride) != "" {
|
||||
return strings.TrimSpace(contentOverride)
|
||||
}
|
||||
if event == nil {
|
||||
return ""
|
||||
}
|
||||
if event.Error != nil && strings.TrimSpace(event.Error.Message) != "" {
|
||||
return strings.TrimSpace(event.Error.Message)
|
||||
}
|
||||
if event.Delta != nil {
|
||||
if text := strings.TrimSpace(event.Delta.Text); text != "" {
|
||||
return text
|
||||
}
|
||||
if transcript := strings.TrimSpace(event.Delta.Transcript); transcript != "" {
|
||||
return transcript
|
||||
}
|
||||
}
|
||||
if event.Item != nil {
|
||||
if summary := extractRealtimeItemSummary(event.Item); summary != "" {
|
||||
return summary
|
||||
}
|
||||
}
|
||||
if event.Session != nil && strings.TrimSpace(event.Session.Instructions) != "" {
|
||||
return strings.TrimSpace(event.Session.Instructions)
|
||||
}
|
||||
if len(event.RawData) > 0 {
|
||||
return strings.TrimSpace(string(event.RawData))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractRealtimeItemSummary(item *schemas.RealtimeItem) string {
|
||||
if item == nil {
|
||||
return ""
|
||||
}
|
||||
if summary := extractRealtimeContentSummary(item.Content); summary != "" {
|
||||
return summary
|
||||
}
|
||||
switch {
|
||||
case strings.TrimSpace(item.Output) != "":
|
||||
return strings.TrimSpace(item.Output)
|
||||
case strings.TrimSpace(item.Arguments) != "":
|
||||
return strings.TrimSpace(item.Arguments)
|
||||
case strings.TrimSpace(item.Name) != "":
|
||||
return strings.TrimSpace(item.Name)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func extractRealtimeContentSummary(raw []byte) string {
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var decoded any
|
||||
if err := sonic.Unmarshal(raw, &decoded); err != nil {
|
||||
return strings.TrimSpace(string(raw))
|
||||
}
|
||||
|
||||
var parts []string
|
||||
collectRealtimeTextFragments(decoded, &parts)
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
func collectRealtimeTextFragments(value any, parts *[]string) {
|
||||
switch v := value.(type) {
|
||||
case map[string]any:
|
||||
for key, field := range v {
|
||||
switch key {
|
||||
case "text", "transcript", "input_text", "output_text", "output", "arguments":
|
||||
if text, ok := field.(string); ok {
|
||||
text = strings.TrimSpace(text)
|
||||
if text != "" {
|
||||
*parts = append(*parts, text)
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
collectRealtimeTextFragments(field, parts)
|
||||
}
|
||||
case []any:
|
||||
for _, item := range v {
|
||||
collectRealtimeTextFragments(item, parts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func finalizedRealtimeInputSummary(event *schemas.BifrostRealtimeEvent) string {
|
||||
if event == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case schemas.RTEventInputAudioTransCompleted:
|
||||
if transcript := extractRealtimeExtraParamString(event, "transcript"); transcript != "" {
|
||||
return transcript
|
||||
}
|
||||
return realtimeMissingTranscriptText
|
||||
default:
|
||||
if event != nil && event.Type == schemas.RTEventConversationItemDone && schemas.IsRealtimeUserInputEvent(event) {
|
||||
if summary := extractRealtimeItemSummary(event.Item); summary != "" {
|
||||
return summary
|
||||
}
|
||||
if realtimeItemHasMissingAudioTranscript(event.Item) {
|
||||
return realtimeMissingTranscriptText
|
||||
}
|
||||
}
|
||||
if schemas.IsRealtimeUserInputEvent(event) {
|
||||
return extractRealtimeItemSummary(event.Item)
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func pendingRealtimeInputUpdate(event *schemas.BifrostRealtimeEvent) (string, string) {
|
||||
if event == nil {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case schemas.RTEventConversationItemRetrieved:
|
||||
return "", ""
|
||||
case schemas.RTEventInputAudioTransCompleted:
|
||||
return realtimeEventItemID(event), finalizedRealtimeInputSummary(event)
|
||||
default:
|
||||
if schemas.IsRealtimeUserInputEvent(event) {
|
||||
return realtimeEventItemID(event), finalizedRealtimeInputSummary(event)
|
||||
}
|
||||
}
|
||||
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func realtimeItemHasMissingAudioTranscript(item *schemas.RealtimeItem) bool {
|
||||
if item == nil || len(item.Content) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
var decoded []map[string]any
|
||||
if err := sonic.Unmarshal(item.Content, &decoded); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, part := range decoded {
|
||||
partType, _ := part["type"].(string)
|
||||
if partType != "input_audio" {
|
||||
continue
|
||||
}
|
||||
transcript, exists := part["transcript"]
|
||||
if !exists || transcript == nil {
|
||||
return true
|
||||
}
|
||||
if text, ok := transcript.(string); ok && strings.TrimSpace(text) == "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func finalizedRealtimeToolOutputSummary(event *schemas.BifrostRealtimeEvent) string {
|
||||
if !schemas.IsRealtimeToolOutputEvent(event) {
|
||||
return ""
|
||||
}
|
||||
return extractRealtimeItemSummary(event.Item)
|
||||
}
|
||||
|
||||
func pendingRealtimeToolOutputUpdate(event *schemas.BifrostRealtimeEvent) (string, string) {
|
||||
if event == nil || event.Type == schemas.RTEventConversationItemRetrieved || !schemas.IsRealtimeToolOutputEvent(event) {
|
||||
return "", ""
|
||||
}
|
||||
return realtimeEventItemID(event), finalizedRealtimeToolOutputSummary(event)
|
||||
}
|
||||
|
||||
func extractRealtimeExtraParamString(event *schemas.BifrostRealtimeEvent, key string) string {
|
||||
if event == nil || event.ExtraParams == nil {
|
||||
return ""
|
||||
}
|
||||
raw, ok := event.ExtraParams[key]
|
||||
if !ok || len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
var value string
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
func realtimeEventItemID(event *schemas.BifrostRealtimeEvent) string {
|
||||
if event == nil {
|
||||
return ""
|
||||
}
|
||||
if event.Item != nil && strings.TrimSpace(event.Item.ID) != "" {
|
||||
return strings.TrimSpace(event.Item.ID)
|
||||
}
|
||||
if event.Delta != nil && strings.TrimSpace(event.Delta.ItemID) != "" {
|
||||
return strings.TrimSpace(event.Delta.ItemID)
|
||||
}
|
||||
return extractRealtimeExtraParamString(event, "item_id")
|
||||
}
|
||||
|
||||
func combineRealtimeInputRaw(turnInputs []bfws.RealtimeTurnInput) string {
|
||||
var parts []string
|
||||
for _, turnInput := range turnInputs {
|
||||
if trimmed := strings.TrimSpace(turnInput.Raw); trimmed != "" {
|
||||
parts = append(parts, trimmed)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
type realtimeResponseDoneEnvelope struct {
|
||||
Response struct {
|
||||
Output []realtimeResponseDoneOutput `json:"output"`
|
||||
Usage *realtimeResponseDoneUsage `json:"usage"`
|
||||
} `json:"response"`
|
||||
}
|
||||
|
||||
type realtimeResponseDoneOutput struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
CallID string `json:"call_id"`
|
||||
Arguments string `json:"arguments"`
|
||||
Content []realtimeResponseDoneContent `json:"content"`
|
||||
}
|
||||
|
||||
type realtimeResponseDoneContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
Transcript string `json:"transcript"`
|
||||
Refusal string `json:"refusal"`
|
||||
}
|
||||
|
||||
type realtimeResponseDoneUsage struct {
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokenDetails *realtimeResponseDoneInputTokenUsage `json:"input_token_details"`
|
||||
OutputTokenDetails *realtimeResponseDoneOutputTokenUsage `json:"output_token_details"`
|
||||
}
|
||||
|
||||
type realtimeResponseDoneInputTokenUsage struct {
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
ImageTokens int `json:"image_tokens"`
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
}
|
||||
|
||||
type realtimeResponseDoneOutputTokenUsage struct {
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
ReasoningTokens int `json:"reasoning_tokens"`
|
||||
ImageTokens *int `json:"image_tokens"`
|
||||
CitationTokens *int `json:"citation_tokens"`
|
||||
NumSearchQueries *int `json:"num_search_queries"`
|
||||
AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
|
||||
RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
|
||||
}
|
||||
|
||||
func extractRealtimeTurnUsage(provider schemas.RealtimeProvider, rawMessage []byte) *schemas.BifrostLLMUsage {
|
||||
if extractor, ok := provider.(schemas.RealtimeUsageExtractor); ok {
|
||||
if usage := extractor.ExtractRealtimeTurnUsage(rawMessage); usage != nil {
|
||||
return usage
|
||||
}
|
||||
}
|
||||
return extractRealtimeResponseDoneUsage(rawMessage)
|
||||
}
|
||||
|
||||
func extractRealtimeTurnOutputMessage(provider schemas.RealtimeProvider, rawMessage []byte, contentSummary string) *schemas.ChatMessage {
|
||||
if extractor, ok := provider.(schemas.RealtimeUsageExtractor); ok {
|
||||
if message := extractor.ExtractRealtimeTurnOutput(rawMessage); message != nil {
|
||||
if strings.TrimSpace(contentSummary) != "" && (message.Content == nil || message.Content.ContentStr == nil || strings.TrimSpace(*message.Content.ContentStr) == "") {
|
||||
message.Content = &schemas.ChatMessageContent{ContentStr: schemas.Ptr(strings.TrimSpace(contentSummary))}
|
||||
}
|
||||
return message
|
||||
}
|
||||
}
|
||||
return buildRealtimeAssistantLogMessage(rawMessage, contentSummary)
|
||||
}
|
||||
|
||||
func buildRealtimeAssistantLogMessage(rawMessage []byte, contentSummary string) *schemas.ChatMessage {
|
||||
contentSummary = strings.TrimSpace(contentSummary)
|
||||
var parsed realtimeResponseDoneEnvelope
|
||||
if len(rawMessage) > 0 && sonic.Unmarshal(rawMessage, &parsed) == nil {
|
||||
message := &schemas.ChatMessage{Role: schemas.ChatMessageRoleAssistant}
|
||||
if contentSummary == "" {
|
||||
contentSummary = extractRealtimeResponseDoneAssistantText(parsed.Response.Output)
|
||||
}
|
||||
if contentSummary != "" {
|
||||
message.Content = &schemas.ChatMessageContent{ContentStr: schemas.Ptr(contentSummary)}
|
||||
}
|
||||
|
||||
toolCalls := extractRealtimeResponseDoneToolCalls(parsed.Response.Output)
|
||||
if len(toolCalls) > 0 {
|
||||
message.ChatAssistantMessage = &schemas.ChatAssistantMessage{
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
if message.Content != nil || message.ChatAssistantMessage != nil {
|
||||
return message
|
||||
}
|
||||
}
|
||||
|
||||
if contentSummary == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr(contentSummary)},
|
||||
}
|
||||
}
|
||||
|
||||
func extractRealtimeResponseDoneAssistantText(outputs []realtimeResponseDoneOutput) string {
|
||||
var parts []string
|
||||
for _, output := range outputs {
|
||||
if output.Type != "message" {
|
||||
continue
|
||||
}
|
||||
for _, block := range output.Content {
|
||||
switch {
|
||||
case strings.TrimSpace(block.Text) != "":
|
||||
parts = append(parts, strings.TrimSpace(block.Text))
|
||||
case strings.TrimSpace(block.Transcript) != "":
|
||||
parts = append(parts, strings.TrimSpace(block.Transcript))
|
||||
case strings.TrimSpace(block.Refusal) != "":
|
||||
parts = append(parts, strings.TrimSpace(block.Refusal))
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
func extractRealtimeResponseDoneToolCalls(outputs []realtimeResponseDoneOutput) []schemas.ChatAssistantMessageToolCall {
|
||||
toolCalls := make([]schemas.ChatAssistantMessageToolCall, 0)
|
||||
for _, output := range outputs {
|
||||
if output.Type != "function_call" {
|
||||
continue
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(output.Name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
toolType := "function"
|
||||
id := strings.TrimSpace(output.CallID)
|
||||
if id == "" {
|
||||
id = strings.TrimSpace(output.ID)
|
||||
}
|
||||
|
||||
toolCall := schemas.ChatAssistantMessageToolCall{
|
||||
Index: uint16(len(toolCalls)),
|
||||
Type: &toolType,
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: schemas.Ptr(name),
|
||||
Arguments: output.Arguments,
|
||||
},
|
||||
}
|
||||
if id != "" {
|
||||
toolCall.ID = schemas.Ptr(id)
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
}
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func extractRealtimeResponseDoneUsage(rawMessage []byte) *schemas.BifrostLLMUsage {
|
||||
if len(rawMessage) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var parsed realtimeResponseDoneEnvelope
|
||||
if err := sonic.Unmarshal(rawMessage, &parsed); err != nil || parsed.Response.Usage == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
totalTokens := parsed.Response.Usage.TotalTokens
|
||||
if totalTokens == 0 && (parsed.Response.Usage.InputTokens > 0 || parsed.Response.Usage.OutputTokens > 0) {
|
||||
totalTokens = parsed.Response.Usage.InputTokens + parsed.Response.Usage.OutputTokens
|
||||
}
|
||||
|
||||
usage := &schemas.BifrostLLMUsage{
|
||||
PromptTokens: parsed.Response.Usage.InputTokens,
|
||||
CompletionTokens: parsed.Response.Usage.OutputTokens,
|
||||
TotalTokens: totalTokens,
|
||||
}
|
||||
|
||||
if parsed.Response.Usage.InputTokenDetails != nil {
|
||||
usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{
|
||||
TextTokens: parsed.Response.Usage.InputTokenDetails.TextTokens,
|
||||
AudioTokens: parsed.Response.Usage.InputTokenDetails.AudioTokens,
|
||||
ImageTokens: parsed.Response.Usage.InputTokenDetails.ImageTokens,
|
||||
CachedReadTokens: parsed.Response.Usage.InputTokenDetails.CachedTokens,
|
||||
}
|
||||
}
|
||||
|
||||
if parsed.Response.Usage.OutputTokenDetails != nil {
|
||||
usage.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{
|
||||
TextTokens: parsed.Response.Usage.OutputTokenDetails.TextTokens,
|
||||
AudioTokens: parsed.Response.Usage.OutputTokenDetails.AudioTokens,
|
||||
ReasoningTokens: parsed.Response.Usage.OutputTokenDetails.ReasoningTokens,
|
||||
ImageTokens: parsed.Response.Usage.OutputTokenDetails.ImageTokens,
|
||||
CitationTokens: parsed.Response.Usage.OutputTokenDetails.CitationTokens,
|
||||
NumSearchQueries: parsed.Response.Usage.OutputTokenDetails.NumSearchQueries,
|
||||
AcceptedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.AcceptedPredictionTokens,
|
||||
RejectedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.RejectedPredictionTokens,
|
||||
}
|
||||
}
|
||||
|
||||
return usage
|
||||
}
|
||||
435
transports/bifrost-http/handlers/realtime_logging_test.go
Normal file
435
transports/bifrost-http/handlers/realtime_logging_test.go
Normal file
@@ -0,0 +1,435 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/openai"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
|
||||
)
|
||||
|
||||
func TestShouldAccumulateRealtimeOutput(t *testing.T) {
|
||||
provider := &openai.OpenAIProvider{}
|
||||
if !provider.ShouldAccumulateRealtimeOutput(schemas.RTEventResponseTextDelta) {
|
||||
t.Fatal("expected response.text.delta to accumulate output text")
|
||||
}
|
||||
if !provider.ShouldAccumulateRealtimeOutput(schemas.RTEventResponseAudioTransDelta) {
|
||||
t.Fatal("expected response.audio_transcript.delta to accumulate output transcript")
|
||||
}
|
||||
if provider.ShouldAccumulateRealtimeOutput(schemas.RTEventInputAudioTransDelta) {
|
||||
t.Fatal("did not expect input audio transcription delta to accumulate assistant output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRealtimeTurnSummary(t *testing.T) {
|
||||
event := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemCreate,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Content: []byte(`[{"type":"input_text","text":"hello from realtime"}]`),
|
||||
},
|
||||
}
|
||||
|
||||
got := extractRealtimeTurnSummary(event, "")
|
||||
if got != "hello from realtime" {
|
||||
t.Fatalf("extractRealtimeTurnSummary() = %q, want %q", got, "hello from realtime")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizedRealtimeInputSummary(t *testing.T) {
|
||||
userCreate := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemCreate,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Role: "user",
|
||||
Content: []byte(`[{"type":"input_text","text":"hello from browser"}]`),
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(userCreate); got != "hello from browser" {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(user create) = %q, want %q", got, "hello from browser")
|
||||
}
|
||||
|
||||
userRetrieved := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemRetrieved,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Role: "user",
|
||||
Content: []byte(`[{"type":"input_text","text":"hello from retrieved item"}]`),
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(userRetrieved); got != "hello from retrieved item" {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(user retrieved) = %q, want %q", got, "hello from retrieved item")
|
||||
}
|
||||
|
||||
userCreated := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemCreated,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Role: "user",
|
||||
Content: []byte(`[{"type":"input_text","text":"hello from provider created item"}]`),
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(userCreated); got != "hello from provider created item" {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(user created) = %q, want %q", got, "hello from provider created item")
|
||||
}
|
||||
|
||||
userAdded := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemAdded,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Role: "user",
|
||||
Content: []byte(`[{"type":"input_text","text":"hello from provider added item"}]`),
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(userAdded); got != "hello from provider added item" {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(user added) = %q, want %q", got, "hello from provider added item")
|
||||
}
|
||||
|
||||
userCreatedWithoutTranscript := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemCreated,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Role: "user",
|
||||
Type: "message",
|
||||
Content: []byte(`[{"type":"input_audio","audio":null,"transcript":null}]`),
|
||||
},
|
||||
RawData: []byte(`{"type":"conversation.item.created","item":{"type":"message","role":"user","content":[{"type":"input_audio","audio":null,"transcript":null}]}}`),
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(userCreatedWithoutTranscript); got != "" {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(user created without transcript) = %q, want empty", got)
|
||||
}
|
||||
|
||||
userDoneWithoutTranscript := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemDone,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Role: "user",
|
||||
Type: "message",
|
||||
Status: "completed",
|
||||
Content: []byte(`[{"type":"input_audio","audio":null,"transcript":null}]`),
|
||||
},
|
||||
RawData: []byte(`{"type":"conversation.item.done","item":{"type":"message","role":"user","status":"completed","content":[{"type":"input_audio","audio":null,"transcript":null}]}}`),
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(userDoneWithoutTranscript); got != realtimeMissingTranscriptText {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(user done without transcript) = %q, want %q", got, realtimeMissingTranscriptText)
|
||||
}
|
||||
|
||||
inputTranscript := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventInputAudioTransCompleted,
|
||||
ExtraParams: map[string]json.RawMessage{
|
||||
"transcript": json.RawMessage(`"spoken user turn"`),
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(inputTranscript); got != "spoken user turn" {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(input transcript) = %q, want %q", got, "spoken user turn")
|
||||
}
|
||||
|
||||
emptyInputTranscript := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventInputAudioTransCompleted,
|
||||
ExtraParams: map[string]json.RawMessage{
|
||||
"transcript": json.RawMessage(`""`),
|
||||
},
|
||||
RawData: []byte(`{"type":"conversation.item.input_audio_transcription.completed","transcript":"","usage":{"total_tokens":11}}`),
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(emptyInputTranscript); got != realtimeMissingTranscriptText {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(empty input transcript) = %q, want %q", got, realtimeMissingTranscriptText)
|
||||
}
|
||||
|
||||
missingInputTranscript := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventInputAudioTransCompleted,
|
||||
RawData: []byte(`{"type":"conversation.item.input_audio_transcription.completed","usage":{"total_tokens":11}}`),
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(missingInputTranscript); got != realtimeMissingTranscriptText {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(missing input transcript) = %q, want %q", got, realtimeMissingTranscriptText)
|
||||
}
|
||||
|
||||
assistantCreate := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemCreate,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Role: "assistant",
|
||||
Content: []byte(`[{"type":"text","text":"assistant text"}]`),
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(assistantCreate); got != "" {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(assistant create) = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizedRealtimeToolOutputSummary(t *testing.T) {
|
||||
event := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemCreate,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Type: "function_call_output",
|
||||
Output: `{"nextResponse":"tool result"}`,
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeToolOutputSummary(event); got != `{"nextResponse":"tool result"}` {
|
||||
t.Fatalf("finalizedRealtimeToolOutputSummary() = %q, want %q", got, `{"nextResponse":"tool result"}`)
|
||||
}
|
||||
|
||||
retrieved := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemRetrieved,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Type: "function_call_output",
|
||||
Output: `{"nextResponse":"tool result from retrieved"}`,
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeToolOutputSummary(retrieved); got != `{"nextResponse":"tool result from retrieved"}` {
|
||||
t.Fatalf("finalizedRealtimeToolOutputSummary(retrieved) = %q, want %q", got, `{"nextResponse":"tool result from retrieved"}`)
|
||||
}
|
||||
|
||||
created := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemCreated,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Type: "function_call_output",
|
||||
Output: `{"nextResponse":"tool result from created"}`,
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeToolOutputSummary(created); got != `{"nextResponse":"tool result from created"}` {
|
||||
t.Fatalf("finalizedRealtimeToolOutputSummary(created) = %q, want %q", got, `{"nextResponse":"tool result from created"}`)
|
||||
}
|
||||
|
||||
added := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemAdded,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Type: "function_call_output",
|
||||
Output: `{"nextResponse":"tool result from added"}`,
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeToolOutputSummary(added); got != `{"nextResponse":"tool result from added"}` {
|
||||
t.Fatalf("finalizedRealtimeToolOutputSummary(added) = %q, want %q", got, `{"nextResponse":"tool result from added"}`)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPendingRealtimeInputUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
transcriptEvent := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventInputAudioTransCompleted,
|
||||
ExtraParams: map[string]json.RawMessage{
|
||||
"item_id": json.RawMessage(`"item_123"`),
|
||||
"transcript": json.RawMessage(`"Hello."`),
|
||||
},
|
||||
}
|
||||
itemID, summary := pendingRealtimeInputUpdate(transcriptEvent)
|
||||
if itemID != "item_123" || summary != "Hello." {
|
||||
t.Fatalf("pendingRealtimeInputUpdate(transcript) = (%q, %q), want (%q, %q)", itemID, summary, "item_123", "Hello.")
|
||||
}
|
||||
|
||||
retrievedEvent := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemRetrieved,
|
||||
Item: &schemas.RealtimeItem{
|
||||
ID: "item_123",
|
||||
Role: "user",
|
||||
Content: []byte(`[{"type":"input_text","text":"historical hello"}]`),
|
||||
},
|
||||
}
|
||||
itemID, summary = pendingRealtimeInputUpdate(retrievedEvent)
|
||||
if itemID != "" || summary != "" {
|
||||
t.Fatalf("pendingRealtimeInputUpdate(retrieved) = (%q, %q), want empty", itemID, summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPendingRealtimeToolOutputUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
toolOutputEvent := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemDone,
|
||||
Item: &schemas.RealtimeItem{
|
||||
ID: "item_tool_123",
|
||||
Type: "function_call_output",
|
||||
Output: `{"nextResponse":"tool result"}`,
|
||||
},
|
||||
}
|
||||
itemID, summary := pendingRealtimeToolOutputUpdate(toolOutputEvent)
|
||||
if itemID != "item_tool_123" || summary != `{"nextResponse":"tool result"}` {
|
||||
t.Fatalf("pendingRealtimeToolOutputUpdate(done) = (%q, %q), want (%q, %q)", itemID, summary, "item_tool_123", `{"nextResponse":"tool result"}`)
|
||||
}
|
||||
|
||||
retrievedToolOutputEvent := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemRetrieved,
|
||||
Item: &schemas.RealtimeItem{
|
||||
ID: "item_tool_123",
|
||||
Type: "function_call_output",
|
||||
Output: `{"nextResponse":"historical tool result"}`,
|
||||
},
|
||||
}
|
||||
itemID, summary = pendingRealtimeToolOutputUpdate(retrievedToolOutputEvent)
|
||||
if itemID != "" || summary != "" {
|
||||
t.Fatalf("pendingRealtimeToolOutputUpdate(retrieved) = (%q, %q), want empty", itemID, summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRealtimeTurnPostResponseUsesFullResponseDonePayload(t *testing.T) {
|
||||
rawRequest := `{"type":"conversation.item.input_audio_transcription.completed","transcript":""}`
|
||||
rawResponse := []byte(`{
|
||||
"type":"response.done",
|
||||
"response":{
|
||||
"output":[
|
||||
{
|
||||
"id":"item_message_123",
|
||||
"type":"message",
|
||||
"content":[
|
||||
{
|
||||
"type":"audio",
|
||||
"transcript":"assistant turn text"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"usage":{
|
||||
"total_tokens":26,
|
||||
"input_tokens":17,
|
||||
"output_tokens":9,
|
||||
"input_token_details":{
|
||||
"text_tokens":12,
|
||||
"audio_tokens":5,
|
||||
"image_tokens":0,
|
||||
"cached_tokens":4
|
||||
},
|
||||
"output_token_details":{
|
||||
"text_tokens":7,
|
||||
"audio_tokens":2
|
||||
}
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
resp := buildRealtimeTurnPostResponse(&openai.OpenAIProvider{}, schemas.OpenAI, "gpt-4o-realtime-preview-2025-06-03", rawRequest, rawResponse, "", 4321)
|
||||
if resp == nil || resp.ResponsesResponse == nil {
|
||||
t.Fatal("expected realtime post response to be built")
|
||||
}
|
||||
if resp.ResponsesResponse.ExtraFields.Latency != 4321 {
|
||||
t.Fatalf("Latency = %d, want %d", resp.ResponsesResponse.ExtraFields.Latency, 4321)
|
||||
}
|
||||
if resp.ResponsesResponse.Usage == nil || resp.ResponsesResponse.Usage.InputTokens != 17 || resp.ResponsesResponse.Usage.OutputTokens != 9 || resp.ResponsesResponse.Usage.TotalTokens != 26 {
|
||||
t.Fatalf("Usage = %+v, want input=17 output=9 total=26", resp.ResponsesResponse.Usage)
|
||||
}
|
||||
if len(resp.ResponsesResponse.Output) != 1 {
|
||||
t.Fatalf("len(Output) = %d, want 1", len(resp.ResponsesResponse.Output))
|
||||
}
|
||||
if resp.ResponsesResponse.Output[0].Content == nil || resp.ResponsesResponse.Output[0].Content.ContentStr == nil || *resp.ResponsesResponse.Output[0].Content.ContentStr != "assistant turn text" {
|
||||
t.Fatalf("Output[0].Content = %+v, want assistant turn text", resp.ResponsesResponse.Output[0].Content)
|
||||
}
|
||||
if got, ok := resp.ResponsesResponse.ExtraFields.RawRequest.(string); !ok || got != rawRequest {
|
||||
t.Fatalf("RawRequest = %#v, want %q", resp.ResponsesResponse.ExtraFields.RawRequest, rawRequest)
|
||||
}
|
||||
if got, ok := resp.ResponsesResponse.ExtraFields.RawResponse.(string); !ok || got == "" {
|
||||
t.Fatalf("RawResponse = %#v, want raw response string", resp.ResponsesResponse.ExtraFields.RawResponse)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeRealtimeTurnHooksWithErrorCompletesActiveHooks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
session := bfws.NewSession(nil)
|
||||
session.SetProviderSessionID("sess_provider_123")
|
||||
session.AddRealtimeInput("hello from user", `{"type":"conversation.item.added"}`)
|
||||
session.AppendRealtimeOutputText("partial assistant output")
|
||||
|
||||
var (
|
||||
capturedResp *schemas.BifrostResponse
|
||||
capturedErr *schemas.BifrostError
|
||||
cleanedUp bool
|
||||
)
|
||||
session.SetRealtimeTurnHooks(&bfws.RealtimeTurnPluginState{
|
||||
RequestID: "req_realtime_123",
|
||||
StartedAt: time.Now().Add(-time.Second),
|
||||
PreHookValues: map[any]any{
|
||||
schemas.BifrostContextKeyGovernanceVirtualKeyID: "vk_123",
|
||||
},
|
||||
PostHookRunner: func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
|
||||
capturedResp = result
|
||||
capturedErr = err
|
||||
return result, nil
|
||||
},
|
||||
Cleanup: func() {
|
||||
cleanedUp = true
|
||||
},
|
||||
})
|
||||
|
||||
rawResponse := []byte(`{"type":"error","error":{"type":"server_error","message":"Virtual key is required."}}`)
|
||||
postErr := finalizeRealtimeTurnHooksWithError(
|
||||
nil,
|
||||
nil,
|
||||
session,
|
||||
schemas.OpenAI,
|
||||
"gpt-realtime",
|
||||
nil,
|
||||
schemas.RTEventError,
|
||||
rawResponse,
|
||||
newRealtimeWireBifrostError(401, "server_error", "Virtual key is required."),
|
||||
)
|
||||
if postErr != nil {
|
||||
t.Fatalf("finalizeRealtimeTurnHooksWithError() post error = %v, want nil", postErr)
|
||||
}
|
||||
if capturedResp != nil {
|
||||
t.Fatalf("captured response = %#v, want nil", capturedResp)
|
||||
}
|
||||
if capturedErr == nil {
|
||||
t.Fatal("expected captured error")
|
||||
}
|
||||
if capturedErr.ExtraFields.RequestType != schemas.RealtimeRequest {
|
||||
t.Fatalf("request type = %q, want %q", capturedErr.ExtraFields.RequestType, schemas.RealtimeRequest)
|
||||
}
|
||||
if capturedErr.ExtraFields.Provider != schemas.OpenAI {
|
||||
t.Fatalf("provider = %q, want %q", capturedErr.ExtraFields.Provider, schemas.OpenAI)
|
||||
}
|
||||
if capturedErr.ExtraFields.OriginalModelRequested != "gpt-realtime" {
|
||||
t.Fatalf("model requested = %q, want %q", capturedErr.ExtraFields.OriginalModelRequested, "gpt-realtime")
|
||||
}
|
||||
rawRequest, ok := capturedErr.ExtraFields.RawRequest.(string)
|
||||
if !ok || rawRequest == "" {
|
||||
t.Fatalf("raw request = %#v, want non-empty string", capturedErr.ExtraFields.RawRequest)
|
||||
}
|
||||
rawResp, ok := capturedErr.ExtraFields.RawResponse.(json.RawMessage)
|
||||
if !ok || string(rawResp) != string(rawResponse) {
|
||||
t.Fatalf("raw response = %#v, want %s", capturedErr.ExtraFields.RawResponse, string(rawResponse))
|
||||
}
|
||||
if session.PeekRealtimeTurnHooks() != nil {
|
||||
t.Fatal("expected active hooks to be cleared")
|
||||
}
|
||||
if got := session.ConsumeRealtimeTurnInputs(); len(got) != 0 {
|
||||
t.Fatalf("remaining turn inputs = %d, want 0", len(got))
|
||||
}
|
||||
if got := session.ConsumeRealtimeOutputText(); got != "" {
|
||||
t.Fatalf("remaining output text = %q, want empty", got)
|
||||
}
|
||||
if !cleanedUp {
|
||||
t.Fatal("expected realtime hook cleanup to run")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBifrostErrorFromRealtimeErrorCarriesRealtimeMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rawResponse := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request_error","message":"bad request","param":"session.type"}}`)
|
||||
bifrostErr := newBifrostErrorFromRealtimeError(
|
||||
schemas.OpenAI,
|
||||
"gpt-realtime",
|
||||
rawResponse,
|
||||
&schemas.RealtimeError{
|
||||
Type: "invalid_request_error",
|
||||
Code: "invalid_request_error",
|
||||
Message: "bad request",
|
||||
Param: "session.type",
|
||||
},
|
||||
)
|
||||
if bifrostErr == nil {
|
||||
t.Fatal("expected bifrost error")
|
||||
}
|
||||
if bifrostErr.StatusCode == nil || *bifrostErr.StatusCode != 400 {
|
||||
t.Fatalf("status code = %#v, want 400", bifrostErr.StatusCode)
|
||||
}
|
||||
if bifrostErr.ExtraFields.RequestType != schemas.RealtimeRequest {
|
||||
t.Fatalf("request type = %q, want %q", bifrostErr.ExtraFields.RequestType, schemas.RealtimeRequest)
|
||||
}
|
||||
if bifrostErr.ExtraFields.Provider != schemas.OpenAI {
|
||||
t.Fatalf("provider = %q, want %q", bifrostErr.ExtraFields.Provider, schemas.OpenAI)
|
||||
}
|
||||
if bifrostErr.ExtraFields.OriginalModelRequested != "gpt-realtime" {
|
||||
t.Fatalf("model requested = %q, want %q", bifrostErr.ExtraFields.OriginalModelRequested, "gpt-realtime")
|
||||
}
|
||||
rawResp, ok := bifrostErr.ExtraFields.RawResponse.(json.RawMessage)
|
||||
if !ok || string(rawResp) != string(rawResponse) {
|
||||
t.Fatalf("raw response = %#v, want %s", bifrostErr.ExtraFields.RawResponse, string(rawResponse))
|
||||
}
|
||||
if bifrostErr.Error == nil || bifrostErr.Error.Param != "session.type" {
|
||||
t.Fatalf("error param = %#v, want session.type", bifrostErr.Error)
|
||||
}
|
||||
}
|
||||
798
transports/bifrost-http/handlers/realtime_turn_pipeline.go
Normal file
798
transports/bifrost-http/handlers/realtime_turn_pipeline.go
Normal file
@@ -0,0 +1,798 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
|
||||
)
|
||||
|
||||
func newRealtimeTurnContext(
|
||||
baseCtx *schemas.BifrostContext,
|
||||
requestID string,
|
||||
sessionID string,
|
||||
providerSessionID string,
|
||||
source realtimeTurnSource,
|
||||
eventType schemas.RealtimeEventType,
|
||||
key *schemas.Key,
|
||||
) *schemas.BifrostContext {
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
if baseCtx != nil {
|
||||
// Realtime post-hook contexts must preserve plugin-private values written in
|
||||
// pre-hooks (for example telemetry start timestamps), not just public keys.
|
||||
for ctxKey, value := range baseCtx.GetUserValues() {
|
||||
if value != nil {
|
||||
ctx.SetValue(ctxKey, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest)
|
||||
if requestID == "" {
|
||||
requestID = uuid.NewString()
|
||||
}
|
||||
ctx.SetValue(schemas.BifrostContextKeyRequestID, requestID)
|
||||
resolvedSessionID := strings.TrimSpace(providerSessionID)
|
||||
if resolvedSessionID == "" {
|
||||
resolvedSessionID = strings.TrimSpace(sessionID)
|
||||
}
|
||||
if baseCtx != nil {
|
||||
if externalSessionID, ok := baseCtx.Value(schemas.BifrostContextKeyParentRequestID).(string); ok && strings.TrimSpace(externalSessionID) != "" {
|
||||
resolvedSessionID = strings.TrimSpace(externalSessionID)
|
||||
}
|
||||
}
|
||||
if resolvedSessionID != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyParentRequestID, resolvedSessionID)
|
||||
}
|
||||
if strings.TrimSpace(providerSessionID) != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyRealtimeSessionID, providerSessionID)
|
||||
ctx.SetValue(schemas.BifrostContextKeyRealtimeProviderSessionID, providerSessionID)
|
||||
}
|
||||
if source != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyRealtimeSource, string(source))
|
||||
}
|
||||
if eventType != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyRealtimeEventType, string(eventType))
|
||||
}
|
||||
if key != nil {
|
||||
if strings.TrimSpace(key.ID) != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeySelectedKeyID, key.ID)
|
||||
}
|
||||
if strings.TrimSpace(key.Name) != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeySelectedKeyName, key.Name)
|
||||
}
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func applyRealtimeTurnContextValues(ctx *schemas.BifrostContext, values map[any]any) {
|
||||
if ctx == nil || len(values) == 0 {
|
||||
return
|
||||
}
|
||||
for ctxKey, value := range values {
|
||||
switch ctxKey {
|
||||
case schemas.BifrostContextKeyRequestID,
|
||||
schemas.BifrostContextKeyParentRequestID,
|
||||
schemas.BifrostContextKeyRealtimeSessionID,
|
||||
schemas.BifrostContextKeyRealtimeProviderSessionID,
|
||||
schemas.BifrostContextKeyRealtimeSource,
|
||||
schemas.BifrostContextKeyRealtimeEventType,
|
||||
schemas.BifrostContextKeyStreamStartTime,
|
||||
schemas.BifrostContextKeyStreamEndIndicator:
|
||||
continue
|
||||
}
|
||||
if value != nil {
|
||||
ctx.SetValue(ctxKey, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setRealtimeTurnStreamContext(ctx *schemas.BifrostContext, startedAt time.Time, isFinal bool) {
|
||||
if ctx == nil {
|
||||
return
|
||||
}
|
||||
if startedAt.IsZero() {
|
||||
startedAt = time.Now()
|
||||
}
|
||||
ctx.SetValue(schemas.BifrostContextKeyStreamStartTime, startedAt)
|
||||
if isFinal {
|
||||
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
||||
}
|
||||
}
|
||||
|
||||
func buildRealtimeTurnPreRequest(provider schemas.ModelProvider, model string, turnInputs []bfws.RealtimeTurnInput) *schemas.BifrostRequest {
|
||||
input := make([]schemas.ResponsesMessage, 0, len(turnInputs))
|
||||
for _, turnInput := range turnInputs {
|
||||
summary := strings.TrimSpace(turnInput.Summary)
|
||||
if summary == "" {
|
||||
continue
|
||||
}
|
||||
switch turnInput.Role {
|
||||
case string(schemas.ChatMessageRoleTool):
|
||||
itemType := schemas.ResponsesMessageTypeFunctionCallOutput
|
||||
output := &schemas.ResponsesToolMessageOutputStruct{
|
||||
ResponsesToolCallOutputStr: schemas.Ptr(summary),
|
||||
}
|
||||
input = append(input, schemas.ResponsesMessage{
|
||||
Type: &itemType,
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{Output: output},
|
||||
})
|
||||
case string(schemas.ChatMessageRoleUser):
|
||||
itemType := schemas.ResponsesMessageTypeMessage
|
||||
role := schemas.ResponsesInputMessageRoleUser
|
||||
input = append(input, schemas.ResponsesMessage{
|
||||
Type: &itemType,
|
||||
Role: &role,
|
||||
Content: &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(summary)},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return &schemas.BifrostRequest{
|
||||
RequestType: schemas.RealtimeRequest,
|
||||
ResponsesRequest: &schemas.BifrostResponsesRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: input,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func buildRealtimeTurnPostResponse(
|
||||
rtProvider schemas.RealtimeProvider,
|
||||
provider schemas.ModelProvider,
|
||||
model string,
|
||||
rawRequest string,
|
||||
rawResponse []byte,
|
||||
contentOverride string,
|
||||
latency int64,
|
||||
) *schemas.BifrostResponse {
|
||||
output := buildRealtimeTurnOutputMessages(rtProvider, rawResponse, contentOverride)
|
||||
resp := &schemas.BifrostResponsesResponse{
|
||||
Object: "response",
|
||||
Model: model,
|
||||
Output: output,
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
RequestType: schemas.RealtimeRequest,
|
||||
Provider: provider,
|
||||
OriginalModelRequested: model,
|
||||
Latency: latency,
|
||||
},
|
||||
}
|
||||
if usage := extractRealtimeTurnUsage(rtProvider, rawResponse); usage != nil {
|
||||
resp.Usage = buildRealtimeResponsesUsage(usage)
|
||||
}
|
||||
if strings.TrimSpace(rawRequest) != "" {
|
||||
resp.ExtraFields.RawRequest = rawRequest
|
||||
}
|
||||
if len(rawResponse) > 0 {
|
||||
resp.ExtraFields.RawResponse = string(rawResponse)
|
||||
}
|
||||
|
||||
return &schemas.BifrostResponse{ResponsesResponse: resp}
|
||||
}
|
||||
|
||||
func buildRealtimeTurnOutputMessages(rtProvider schemas.RealtimeProvider, rawResponse []byte, contentOverride string) []schemas.ResponsesMessage {
|
||||
outputs := make([]schemas.ResponsesMessage, 0)
|
||||
if outputMessage := extractRealtimeTurnOutputMessage(rtProvider, rawResponse, contentOverride); outputMessage != nil {
|
||||
outputs = append(outputs, buildRealtimeResponsesMessagesFromChat(outputMessage, contentOverride)...)
|
||||
}
|
||||
|
||||
if len(outputs) > 0 {
|
||||
return outputs
|
||||
}
|
||||
|
||||
var parsed realtimeResponseDoneEnvelope
|
||||
if len(rawResponse) > 0 && schemas.Unmarshal(rawResponse, &parsed) == nil {
|
||||
for _, item := range parsed.Response.Output {
|
||||
switch item.Type {
|
||||
case "message":
|
||||
content := strings.TrimSpace(contentOverride)
|
||||
if content == "" {
|
||||
content = extractRealtimeResponseDoneContentText(item.Content)
|
||||
}
|
||||
itemType := schemas.ResponsesMessageTypeMessage
|
||||
role := schemas.ResponsesInputMessageRoleAssistant
|
||||
msg := schemas.ResponsesMessage{
|
||||
Type: &itemType,
|
||||
Role: &role,
|
||||
Status: schemas.Ptr("completed"),
|
||||
}
|
||||
if strings.TrimSpace(item.ID) != "" {
|
||||
msg.ID = schemas.Ptr(strings.TrimSpace(item.ID))
|
||||
}
|
||||
if content != "" {
|
||||
msg.Content = &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(content)}
|
||||
}
|
||||
outputs = append(outputs, msg)
|
||||
case "function_call":
|
||||
itemType := schemas.ResponsesMessageTypeFunctionCall
|
||||
msg := schemas.ResponsesMessage{
|
||||
Type: &itemType,
|
||||
Status: schemas.Ptr("completed"),
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
Name: schemas.Ptr(strings.TrimSpace(item.Name)),
|
||||
Arguments: schemas.Ptr(item.Arguments),
|
||||
},
|
||||
}
|
||||
if strings.TrimSpace(item.ID) != "" {
|
||||
msg.ID = schemas.Ptr(strings.TrimSpace(item.ID))
|
||||
}
|
||||
if strings.TrimSpace(item.CallID) != "" {
|
||||
msg.CallID = schemas.Ptr(strings.TrimSpace(item.CallID))
|
||||
}
|
||||
outputs = append(outputs, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(outputs) == 0 && strings.TrimSpace(contentOverride) != "" {
|
||||
itemType := schemas.ResponsesMessageTypeMessage
|
||||
role := schemas.ResponsesInputMessageRoleAssistant
|
||||
outputs = append(outputs, schemas.ResponsesMessage{
|
||||
Type: &itemType,
|
||||
Role: &role,
|
||||
Status: schemas.Ptr("completed"),
|
||||
Content: &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(strings.TrimSpace(contentOverride))},
|
||||
})
|
||||
}
|
||||
|
||||
return outputs
|
||||
}
|
||||
|
||||
func buildRealtimeResponsesMessagesFromChat(message *schemas.ChatMessage, contentOverride string) []schemas.ResponsesMessage {
|
||||
if message == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
outputs := make([]schemas.ResponsesMessage, 0, 1)
|
||||
content := strings.TrimSpace(contentOverride)
|
||||
if content == "" && message.Content != nil && message.Content.ContentStr != nil {
|
||||
content = strings.TrimSpace(*message.Content.ContentStr)
|
||||
}
|
||||
if content != "" {
|
||||
itemType := schemas.ResponsesMessageTypeMessage
|
||||
role := schemas.ResponsesInputMessageRoleAssistant
|
||||
outputs = append(outputs, schemas.ResponsesMessage{
|
||||
Type: &itemType,
|
||||
Role: &role,
|
||||
Status: schemas.Ptr("completed"),
|
||||
Content: &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(content)},
|
||||
})
|
||||
}
|
||||
|
||||
if message.ChatAssistantMessage == nil {
|
||||
return outputs
|
||||
}
|
||||
|
||||
for _, toolCall := range message.ChatAssistantMessage.ToolCalls {
|
||||
itemType := schemas.ResponsesMessageTypeFunctionCall
|
||||
msg := schemas.ResponsesMessage{
|
||||
Type: &itemType,
|
||||
Status: schemas.Ptr("completed"),
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
Arguments: schemas.Ptr(toolCall.Function.Arguments),
|
||||
},
|
||||
}
|
||||
if toolCall.Function.Name != nil {
|
||||
msg.ResponsesToolMessage.Name = schemas.Ptr(strings.TrimSpace(*toolCall.Function.Name))
|
||||
}
|
||||
if toolCall.ID != nil {
|
||||
msg.CallID = schemas.Ptr(strings.TrimSpace(*toolCall.ID))
|
||||
msg.ID = schemas.Ptr(strings.TrimSpace(*toolCall.ID))
|
||||
}
|
||||
outputs = append(outputs, msg)
|
||||
}
|
||||
|
||||
return outputs
|
||||
}
|
||||
|
||||
func extractRealtimeResponseDoneContentText(content []realtimeResponseDoneContent) string {
|
||||
for _, block := range content {
|
||||
switch {
|
||||
case strings.TrimSpace(block.Text) != "":
|
||||
return strings.TrimSpace(block.Text)
|
||||
case strings.TrimSpace(block.Transcript) != "":
|
||||
return strings.TrimSpace(block.Transcript)
|
||||
case strings.TrimSpace(block.Refusal) != "":
|
||||
return strings.TrimSpace(block.Refusal)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func buildRealtimeResponsesUsage(usage *schemas.BifrostLLMUsage) *schemas.ResponsesResponseUsage {
|
||||
if usage == nil {
|
||||
return nil
|
||||
}
|
||||
result := &schemas.ResponsesResponseUsage{
|
||||
InputTokens: usage.PromptTokens,
|
||||
OutputTokens: usage.CompletionTokens,
|
||||
TotalTokens: usage.TotalTokens,
|
||||
}
|
||||
if usage.PromptTokensDetails != nil {
|
||||
result.InputTokensDetails = &schemas.ResponsesResponseInputTokens{
|
||||
TextTokens: usage.PromptTokensDetails.TextTokens,
|
||||
AudioTokens: usage.PromptTokensDetails.AudioTokens,
|
||||
ImageTokens: usage.PromptTokensDetails.ImageTokens,
|
||||
CachedReadTokens: usage.PromptTokensDetails.CachedReadTokens,
|
||||
CachedWriteTokens: usage.PromptTokensDetails.CachedWriteTokens,
|
||||
}
|
||||
}
|
||||
if usage.CompletionTokensDetails != nil {
|
||||
result.OutputTokensDetails = &schemas.ResponsesResponseOutputTokens{
|
||||
TextTokens: usage.CompletionTokensDetails.TextTokens,
|
||||
AcceptedPredictionTokens: usage.CompletionTokensDetails.AcceptedPredictionTokens,
|
||||
AudioTokens: usage.CompletionTokensDetails.AudioTokens,
|
||||
ImageTokens: usage.CompletionTokensDetails.ImageTokens,
|
||||
ReasoningTokens: usage.CompletionTokensDetails.ReasoningTokens,
|
||||
RejectedPredictionTokens: usage.CompletionTokensDetails.RejectedPredictionTokens,
|
||||
CitationTokens: usage.CompletionTokensDetails.CitationTokens,
|
||||
NumSearchQueries: usage.CompletionTokensDetails.NumSearchQueries,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func newRealtimeTurnErrorEventPayload(bifrostErr *schemas.BifrostError) []byte {
|
||||
if bifrostErr == nil {
|
||||
return []byte(`{"type":"error","error":{"type":"server_error","message":"internal server error"}}`)
|
||||
}
|
||||
|
||||
errorType, errorCode, errorMessage, errorParam := mapRealtimeWireErrorFields(bifrostErr)
|
||||
payload := schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventError,
|
||||
Error: &schemas.RealtimeError{
|
||||
Type: errorType,
|
||||
Code: errorCode,
|
||||
Message: errorMessage,
|
||||
Param: errorParam,
|
||||
},
|
||||
}
|
||||
if data, err := schemas.Marshal(payload); err == nil {
|
||||
return data
|
||||
}
|
||||
return []byte(`{"type":"error","error":{"type":"server_error","message":"internal server error"}}`)
|
||||
}
|
||||
|
||||
// isBudgetOrBillingError returns true if the lowercased value indicates a budget or billing exhaustion error.
|
||||
// Quota/rate-limit patterns (quota_exceeded, quota exceeded, etc.) are already covered by bifrost.IsRateLimitErrorMessage.
|
||||
func isBudgetOrBillingError(lower string) bool {
|
||||
return strings.Contains(lower, "budget_exceeded") ||
|
||||
strings.Contains(lower, "budget exceeded") ||
|
||||
strings.Contains(lower, "insufficient_quota") ||
|
||||
strings.Contains(lower, "hard limit reached") ||
|
||||
strings.Contains(lower, "billing hard limit")
|
||||
}
|
||||
|
||||
func mapRealtimeWireErrorFields(bifrostErr *schemas.BifrostError) (string, string, string, string) {
|
||||
errorType := "server_error"
|
||||
errorCode := "server_error"
|
||||
errorMessage := "internal server error"
|
||||
errorParam := ""
|
||||
|
||||
if bifrostErr == nil {
|
||||
return errorType, errorCode, errorMessage, errorParam
|
||||
}
|
||||
|
||||
var values []string
|
||||
if bifrostErr.Type != nil {
|
||||
values = append(values, strings.TrimSpace(*bifrostErr.Type))
|
||||
}
|
||||
if bifrostErr.Error != nil {
|
||||
if bifrostErr.Error.Type != nil {
|
||||
values = append(values, strings.TrimSpace(*bifrostErr.Error.Type))
|
||||
}
|
||||
if bifrostErr.Error.Code != nil {
|
||||
values = append(values, strings.TrimSpace(*bifrostErr.Error.Code))
|
||||
}
|
||||
if strings.TrimSpace(bifrostErr.Error.Message) != "" {
|
||||
errorMessage = strings.TrimSpace(bifrostErr.Error.Message)
|
||||
values = append(values, errorMessage)
|
||||
}
|
||||
if bifrostErr.Error.Param != nil {
|
||||
errorParam = strings.TrimSpace(fmt.Sprint(bifrostErr.Error.Param))
|
||||
}
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
lower := strings.ToLower(value)
|
||||
switch {
|
||||
case lower == "":
|
||||
continue
|
||||
case strings.Contains(lower, "invalid_request_error"):
|
||||
return "invalid_request_error", "invalid_request_error", errorMessage, errorParam
|
||||
case isBudgetOrBillingError(lower):
|
||||
return "insufficient_quota", "insufficient_quota", errorMessage, errorParam
|
||||
case bifrost.IsRateLimitErrorMessage(lower):
|
||||
return "rate_limit_exceeded", "rate_limit_exceeded", errorMessage, errorParam
|
||||
}
|
||||
}
|
||||
|
||||
return errorType, errorCode, errorMessage, errorParam
|
||||
}
|
||||
|
||||
func shouldGracefullyDisconnectRealtime(bifrostErr *schemas.BifrostError) bool {
|
||||
if bifrostErr == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var values []string
|
||||
if bifrostErr.Type != nil {
|
||||
values = append(values, strings.TrimSpace(*bifrostErr.Type))
|
||||
}
|
||||
if bifrostErr.Error != nil {
|
||||
if bifrostErr.Error.Type != nil {
|
||||
values = append(values, strings.TrimSpace(*bifrostErr.Error.Type))
|
||||
}
|
||||
if bifrostErr.Error.Code != nil {
|
||||
values = append(values, strings.TrimSpace(*bifrostErr.Error.Code))
|
||||
}
|
||||
values = append(values, strings.TrimSpace(bifrostErr.Error.Message))
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
lower := strings.ToLower(value)
|
||||
if lower == "" {
|
||||
continue
|
||||
}
|
||||
if isBudgetOrBillingError(lower) || bifrost.IsRateLimitErrorMessage(lower) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func startRealtimeTurnHooks(
|
||||
client *bifrost.Bifrost,
|
||||
baseCtx *schemas.BifrostContext,
|
||||
session *bfws.Session,
|
||||
rtProvider schemas.RealtimeProvider,
|
||||
provider schemas.ModelProvider,
|
||||
model string,
|
||||
key *schemas.Key,
|
||||
startEventType schemas.RealtimeEventType,
|
||||
) *schemas.BifrostError {
|
||||
if client == nil || session == nil {
|
||||
return &schemas.BifrostError{
|
||||
Type: schemas.Ptr("server_error"),
|
||||
StatusCode: schemas.Ptr(500),
|
||||
Error: &schemas.ErrorField{
|
||||
Type: schemas.Ptr("server_error"),
|
||||
Message: "realtime turn pipeline is unavailable",
|
||||
},
|
||||
}
|
||||
}
|
||||
if !session.TryBeginRealtimeTurnHooks() {
|
||||
return &schemas.BifrostError{
|
||||
Type: schemas.Ptr("invalid_request_error"),
|
||||
StatusCode: schemas.Ptr(400),
|
||||
Error: &schemas.ErrorField{
|
||||
Type: schemas.Ptr("invalid_request_error"),
|
||||
Message: "Conversation already has an active response in progress.",
|
||||
},
|
||||
}
|
||||
}
|
||||
committed := false
|
||||
defer func() {
|
||||
if !committed {
|
||||
session.AbortRealtimeTurnHooks()
|
||||
}
|
||||
}()
|
||||
|
||||
startedAt := time.Now()
|
||||
turnCtx := newRealtimeTurnContext(baseCtx, "", session.ID(), session.ProviderSessionID(), realtimeTurnSourceEI, startEventType, key)
|
||||
setRealtimeTurnStreamContext(turnCtx, startedAt, false)
|
||||
req := buildRealtimeTurnPreRequest(provider, model, session.PeekRealtimeTurnInputs())
|
||||
hooks, bifrostErr := client.RunRealtimeTurnPreHooks(turnCtx, req)
|
||||
if bifrostErr != nil {
|
||||
// RunRealtimeTurnPreHooks already executed post-hooks and flushed the trace
|
||||
// for this turn-start failure. Clear buffered turn state so transport-close
|
||||
// fallback finalization does not emit the same error a second time.
|
||||
session.ConsumeRealtimeTurnInputs()
|
||||
session.ConsumeRealtimeOutputText()
|
||||
return bifrostErr
|
||||
}
|
||||
|
||||
requestID, _ := turnCtx.Value(schemas.BifrostContextKeyRequestID).(string)
|
||||
session.SetRealtimeTurnHooks(&bfws.RealtimeTurnPluginState{
|
||||
PostHookRunner: hooks.PostHookRunner,
|
||||
Cleanup: hooks.Cleanup,
|
||||
RequestID: requestID,
|
||||
StartedAt: startedAt,
|
||||
PreHookValues: turnCtx.GetUserValues(),
|
||||
})
|
||||
committed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func finalizeRealtimeTurnHooks(
|
||||
client *bifrost.Bifrost,
|
||||
baseCtx *schemas.BifrostContext,
|
||||
session *bfws.Session,
|
||||
rtProvider schemas.RealtimeProvider,
|
||||
provider schemas.ModelProvider,
|
||||
model string,
|
||||
key *schemas.Key,
|
||||
rawResponse []byte,
|
||||
contentOverride string,
|
||||
) *schemas.BifrostError {
|
||||
if client == nil || session == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
turnInputs := session.ConsumeRealtimeTurnInputs()
|
||||
rawRequest := combineRealtimeInputRaw(turnInputs)
|
||||
|
||||
if activeHooks := session.ConsumeRealtimeTurnHooks(); activeHooks != nil {
|
||||
defer func() {
|
||||
if activeHooks.Cleanup != nil {
|
||||
activeHooks.Cleanup()
|
||||
}
|
||||
}()
|
||||
postResponse := buildRealtimeTurnPostResponse(
|
||||
rtProvider,
|
||||
provider,
|
||||
model,
|
||||
rawRequest,
|
||||
rawResponse,
|
||||
contentOverride,
|
||||
time.Since(activeHooks.StartedAt).Milliseconds(),
|
||||
)
|
||||
postCtx := newRealtimeTurnContext(baseCtx, activeHooks.RequestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, rtProvider.RealtimeTurnFinalEvent(), key)
|
||||
applyRealtimeTurnContextValues(postCtx, activeHooks.PreHookValues)
|
||||
setRealtimeTurnStreamContext(postCtx, activeHooks.StartedAt, true)
|
||||
_, bifrostErr := activeHooks.PostHookRunner(postCtx, postResponse, nil)
|
||||
completeRealtimeTurnTrace(postCtx)
|
||||
return bifrostErr
|
||||
}
|
||||
|
||||
startedAt := time.Now()
|
||||
preCtx := newRealtimeTurnContext(baseCtx, "", session.ID(), session.ProviderSessionID(), realtimeTurnSourceEI, "", key)
|
||||
setRealtimeTurnStreamContext(preCtx, startedAt, false)
|
||||
preReq := buildRealtimeTurnPreRequest(provider, model, turnInputs)
|
||||
hooks, bifrostErr := client.RunRealtimeTurnPreHooks(preCtx, preReq)
|
||||
if bifrostErr != nil {
|
||||
return bifrostErr
|
||||
}
|
||||
if hooks.Cleanup != nil {
|
||||
defer hooks.Cleanup()
|
||||
}
|
||||
|
||||
requestID, _ := preCtx.Value(schemas.BifrostContextKeyRequestID).(string)
|
||||
postResponse := buildRealtimeTurnPostResponse(
|
||||
rtProvider,
|
||||
provider,
|
||||
model,
|
||||
rawRequest,
|
||||
rawResponse,
|
||||
contentOverride,
|
||||
time.Since(startedAt).Milliseconds(),
|
||||
)
|
||||
postCtx := newRealtimeTurnContext(baseCtx, requestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, rtProvider.RealtimeTurnFinalEvent(), key)
|
||||
applyRealtimeTurnContextValues(postCtx, preCtx.GetUserValues())
|
||||
setRealtimeTurnStreamContext(postCtx, startedAt, true)
|
||||
_, bifrostErr = hooks.PostHookRunner(postCtx, postResponse, nil)
|
||||
completeRealtimeTurnTrace(postCtx)
|
||||
return bifrostErr
|
||||
}
|
||||
|
||||
func finalizeRealtimeTurnHooksWithError(
|
||||
client *bifrost.Bifrost,
|
||||
baseCtx *schemas.BifrostContext,
|
||||
session *bfws.Session,
|
||||
provider schemas.ModelProvider,
|
||||
model string,
|
||||
key *schemas.Key,
|
||||
eventType schemas.RealtimeEventType,
|
||||
rawResponse []byte,
|
||||
bifrostErr *schemas.BifrostError,
|
||||
) *schemas.BifrostError {
|
||||
if session == nil || bifrostErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
turnInputs := session.ConsumeRealtimeTurnInputs()
|
||||
rawRequest := combineRealtimeInputRaw(turnInputs)
|
||||
session.ConsumeRealtimeOutputText()
|
||||
|
||||
if activeHooks := session.ConsumeRealtimeTurnHooks(); activeHooks != nil {
|
||||
defer func() {
|
||||
if activeHooks.Cleanup != nil {
|
||||
activeHooks.Cleanup()
|
||||
}
|
||||
}()
|
||||
postErr := buildRealtimeTurnPostError(
|
||||
provider,
|
||||
model,
|
||||
rawRequest,
|
||||
rawResponse,
|
||||
bifrostErr,
|
||||
)
|
||||
postCtx := newRealtimeTurnContext(baseCtx, activeHooks.RequestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, eventType, key)
|
||||
applyRealtimeTurnContextValues(postCtx, activeHooks.PreHookValues)
|
||||
setRealtimeTurnStreamContext(postCtx, activeHooks.StartedAt, true)
|
||||
_, hookErr := activeHooks.PostHookRunner(postCtx, nil, postErr)
|
||||
completeRealtimeTurnTrace(postCtx)
|
||||
return hookErr
|
||||
}
|
||||
|
||||
if len(turnInputs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
startedAt := time.Now()
|
||||
preCtx := newRealtimeTurnContext(baseCtx, "", session.ID(), session.ProviderSessionID(), realtimeTurnSourceEI, "", key)
|
||||
setRealtimeTurnStreamContext(preCtx, startedAt, false)
|
||||
preReq := buildRealtimeTurnPreRequest(provider, model, turnInputs)
|
||||
hooks, hookPreErr := client.RunRealtimeTurnPreHooks(preCtx, preReq)
|
||||
if hookPreErr != nil {
|
||||
return hookPreErr
|
||||
}
|
||||
if hooks.Cleanup != nil {
|
||||
defer hooks.Cleanup()
|
||||
}
|
||||
|
||||
requestID, _ := preCtx.Value(schemas.BifrostContextKeyRequestID).(string)
|
||||
postErr := buildRealtimeTurnPostError(
|
||||
provider,
|
||||
model,
|
||||
rawRequest,
|
||||
rawResponse,
|
||||
bifrostErr,
|
||||
)
|
||||
postCtx := newRealtimeTurnContext(baseCtx, requestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, eventType, key)
|
||||
applyRealtimeTurnContextValues(postCtx, preCtx.GetUserValues())
|
||||
setRealtimeTurnStreamContext(postCtx, startedAt, true)
|
||||
_, hookErr := hooks.PostHookRunner(postCtx, nil, postErr)
|
||||
completeRealtimeTurnTrace(postCtx)
|
||||
return hookErr
|
||||
}
|
||||
|
||||
func buildRealtimeTurnPostError(
|
||||
provider schemas.ModelProvider,
|
||||
model string,
|
||||
rawRequest string,
|
||||
rawResponse []byte,
|
||||
bifrostErr *schemas.BifrostError,
|
||||
) *schemas.BifrostError {
|
||||
if bifrostErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
copied := *bifrostErr
|
||||
copied.ExtraFields = bifrostErr.ExtraFields
|
||||
if bifrostErr.Error != nil {
|
||||
errorCopy := *bifrostErr.Error
|
||||
copied.Error = &errorCopy
|
||||
}
|
||||
copied.ExtraFields.RequestType = schemas.RealtimeRequest
|
||||
if copied.ExtraFields.Provider == "" {
|
||||
copied.ExtraFields.Provider = provider
|
||||
}
|
||||
if strings.TrimSpace(copied.ExtraFields.OriginalModelRequested) == "" {
|
||||
copied.ExtraFields.OriginalModelRequested = model
|
||||
}
|
||||
if strings.TrimSpace(rawRequest) != "" && copied.ExtraFields.RawRequest == nil {
|
||||
copied.ExtraFields.RawRequest = rawRequest
|
||||
}
|
||||
if len(rawResponse) > 0 && copied.ExtraFields.RawResponse == nil {
|
||||
copied.ExtraFields.RawResponse = json.RawMessage(append([]byte(nil), rawResponse...))
|
||||
}
|
||||
return &copied
|
||||
}
|
||||
|
||||
func newBifrostErrorFromRealtimeError(
|
||||
provider schemas.ModelProvider,
|
||||
model string,
|
||||
rawResponse []byte,
|
||||
realtimeErr *schemas.RealtimeError,
|
||||
) *schemas.BifrostError {
|
||||
if realtimeErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
statusCode := 500
|
||||
values := []string{
|
||||
strings.TrimSpace(realtimeErr.Type),
|
||||
strings.TrimSpace(realtimeErr.Code),
|
||||
strings.TrimSpace(realtimeErr.Message),
|
||||
}
|
||||
for _, value := range values {
|
||||
lower := strings.ToLower(value)
|
||||
switch {
|
||||
case lower == "":
|
||||
continue
|
||||
case strings.Contains(lower, "invalid_request_error"):
|
||||
statusCode = 400
|
||||
case isBudgetOrBillingError(lower), bifrost.IsRateLimitErrorMessage(lower):
|
||||
statusCode = 429
|
||||
}
|
||||
}
|
||||
|
||||
errType := strings.TrimSpace(realtimeErr.Type)
|
||||
if errType == "" {
|
||||
errType = "server_error"
|
||||
}
|
||||
errCode := strings.TrimSpace(realtimeErr.Code)
|
||||
if errCode == "" {
|
||||
errCode = errType
|
||||
}
|
||||
message := strings.TrimSpace(realtimeErr.Message)
|
||||
if message == "" {
|
||||
message = "realtime turn failed"
|
||||
}
|
||||
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
StatusCode: schemas.Ptr(statusCode),
|
||||
Type: schemas.Ptr(errType),
|
||||
Error: &schemas.ErrorField{
|
||||
Type: schemas.Ptr(errType),
|
||||
Code: schemas.Ptr(errCode),
|
||||
Message: message,
|
||||
},
|
||||
ExtraFields: schemas.BifrostErrorExtraFields{
|
||||
Provider: provider,
|
||||
OriginalModelRequested: model,
|
||||
RequestType: schemas.RealtimeRequest,
|
||||
},
|
||||
}
|
||||
if strings.TrimSpace(realtimeErr.Param) != "" {
|
||||
bifrostErr.Error.Param = realtimeErr.Param
|
||||
}
|
||||
if len(rawResponse) > 0 {
|
||||
bifrostErr.ExtraFields.RawResponse = json.RawMessage(append([]byte(nil), rawResponse...))
|
||||
}
|
||||
return bifrostErr
|
||||
}
|
||||
|
||||
func completeRealtimeTurnTrace(ctx *schemas.BifrostContext) {
|
||||
if ctx == nil {
|
||||
return
|
||||
}
|
||||
traceID, _ := ctx.Value(schemas.BifrostContextKeyTraceID).(string)
|
||||
if strings.TrimSpace(traceID) == "" {
|
||||
return
|
||||
}
|
||||
tracer, _ := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer)
|
||||
if tracer == nil {
|
||||
return
|
||||
}
|
||||
tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID))
|
||||
}
|
||||
|
||||
func finalizeRealtimeTurnHooksOnTransportError(
|
||||
client *bifrost.Bifrost,
|
||||
baseCtx *schemas.BifrostContext,
|
||||
session *bfws.Session,
|
||||
provider schemas.ModelProvider,
|
||||
model string,
|
||||
key *schemas.Key,
|
||||
status int,
|
||||
code string,
|
||||
message string,
|
||||
) *schemas.BifrostError {
|
||||
return finalizeRealtimeTurnHooksWithError(
|
||||
client,
|
||||
baseCtx,
|
||||
session,
|
||||
provider,
|
||||
model,
|
||||
key,
|
||||
schemas.RTEventError,
|
||||
nil,
|
||||
newRealtimeWireBifrostError(status, code, message),
|
||||
)
|
||||
}
|
||||
230
transports/bifrost-http/handlers/session.go
Normal file
230
transports/bifrost-http/handlers/session.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/google/uuid"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// SessionHandler manages HTTP requests for session operations
|
||||
type SessionHandler struct {
|
||||
configStore configstore.ConfigStore
|
||||
wsTicketStore *WSTicketStore
|
||||
}
|
||||
|
||||
// NewSessionHandler creates a new session handler instance
|
||||
func NewSessionHandler(configStore configstore.ConfigStore, wsTicketStore *WSTicketStore) *SessionHandler {
|
||||
return &SessionHandler{
|
||||
configStore: configStore,
|
||||
wsTicketStore: wsTicketStore,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the session-related routes
|
||||
func (h *SessionHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
r.POST("/api/session/login", lib.ChainMiddlewares(h.login, middlewares...))
|
||||
r.POST("/api/session/logout", lib.ChainMiddlewares(h.logout, middlewares...))
|
||||
r.GET("/api/session/is-auth-enabled", lib.ChainMiddlewares(h.isAuthEnabled, middlewares...))
|
||||
r.POST("/api/session/ws-ticket", lib.ChainMiddlewares(h.issueWSTicket, middlewares...))
|
||||
}
|
||||
|
||||
// isAuthEnabled handles GET /api/session/is-auth-enabled - Check if auth is enabled
|
||||
func (h *SessionHandler) isAuthEnabled(ctx *fasthttp.RequestCtx) {
|
||||
if h.configStore == nil {
|
||||
SendJSON(ctx, map[string]any{
|
||||
"is_auth_enabled": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
authConfig, err := h.configStore.GetAuthConfig(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get auth config: %v", err))
|
||||
return
|
||||
}
|
||||
if authConfig == nil {
|
||||
SendJSON(ctx, map[string]any{
|
||||
"is_auth_enabled": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
// Check if the header has a token and is valid (Authorization header or cookie)
|
||||
token := ""
|
||||
if authHeader := string(ctx.Request.Header.Peek("Authorization")); strings.HasPrefix(authHeader, "Bearer ") {
|
||||
token = strings.TrimPrefix(authHeader, "Bearer ")
|
||||
}
|
||||
if token == "" {
|
||||
token = string(ctx.Request.Header.Cookie("token"))
|
||||
}
|
||||
hasValidToken := false
|
||||
if token != "" {
|
||||
session, err := h.configStore.GetSession(ctx, token)
|
||||
if err == nil && session != nil && session.ExpiresAt.After(time.Now()) {
|
||||
hasValidToken = true
|
||||
}
|
||||
}
|
||||
SendJSON(ctx, map[string]any{
|
||||
"is_auth_enabled": authConfig.IsEnabled,
|
||||
"has_valid_token": hasValidToken,
|
||||
})
|
||||
}
|
||||
|
||||
// login handles POST /api/session/login - Login a user
|
||||
func (h *SessionHandler) login(ctx *fasthttp.RequestCtx) {
|
||||
if h.configStore == nil {
|
||||
SendError(ctx, fasthttp.StatusForbidden, "Authentication is not enabled")
|
||||
return
|
||||
}
|
||||
payload := struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}{}
|
||||
if err := json.Unmarshal(ctx.PostBody(), &payload); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Get auth config
|
||||
authConfig, err := h.configStore.GetAuthConfig(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get auth config: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Check if auth is enabled
|
||||
if authConfig == nil || !authConfig.IsEnabled {
|
||||
SendError(ctx, fasthttp.StatusForbidden, "Authentication is not enabled")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify credentials
|
||||
if payload.Username != authConfig.AdminUserName.GetValue() {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, "Invalid username or password")
|
||||
return
|
||||
}
|
||||
compare, err := encrypt.CompareHash(authConfig.AdminPassword.GetValue(), payload.Password)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
if !compare {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, "Invalid username or password")
|
||||
return
|
||||
}
|
||||
|
||||
// Creating a new session
|
||||
token := uuid.New().String()
|
||||
session := &tables.SessionsTable{
|
||||
Token: token,
|
||||
ExpiresAt: time.Now().Add(time.Hour * 24 * 30), // 30 days
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
err = h.configStore.CreateSession(ctx, session)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create session: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Setting cookies
|
||||
cookie := fasthttp.AcquireCookie()
|
||||
defer fasthttp.ReleaseCookie(cookie)
|
||||
cookie.SetKey("token")
|
||||
cookie.SetValue(token)
|
||||
cookie.SetExpire(time.Now().Add(time.Hour * 24 * 30))
|
||||
cookie.SetPath("/")
|
||||
cookie.SetHTTPOnly(true)
|
||||
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
|
||||
// Check if source is https then set secure
|
||||
if string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" {
|
||||
cookie.SetSecure(true)
|
||||
}
|
||||
ctx.Response.Header.SetCookie(cookie)
|
||||
|
||||
SendJSON(ctx, map[string]any{
|
||||
"message": "Login successful",
|
||||
})
|
||||
}
|
||||
|
||||
// logout handles POST /api/session/logout - Logout a user
|
||||
func (h *SessionHandler) logout(ctx *fasthttp.RequestCtx) {
|
||||
if h.configStore == nil {
|
||||
SendError(ctx, fasthttp.StatusForbidden, "Authentication is not enabled")
|
||||
return
|
||||
}
|
||||
// Get token from Authorization header
|
||||
token := string(ctx.Request.Header.Peek("Authorization"))
|
||||
token = strings.TrimPrefix(token, "Bearer ")
|
||||
|
||||
// If no token in header, try to get from cookie
|
||||
if token == "" {
|
||||
token = string(ctx.Request.Header.Cookie("token"))
|
||||
}
|
||||
|
||||
// clear token from cookies
|
||||
cookie := fasthttp.AcquireCookie()
|
||||
defer fasthttp.ReleaseCookie(cookie)
|
||||
cookie.SetKey("token")
|
||||
cookie.SetValue("")
|
||||
cookie.SetExpire(time.Now().Add(-time.Hour * 24 * 30))
|
||||
cookie.SetPath("/")
|
||||
cookie.SetHTTPOnly(true)
|
||||
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
|
||||
// Check if source is https then set secure
|
||||
if string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" {
|
||||
cookie.SetSecure(true)
|
||||
}
|
||||
ctx.Response.Header.SetCookie(cookie)
|
||||
|
||||
// delete session from database if token exists
|
||||
if token != "" {
|
||||
err := h.configStore.DeleteSession(ctx, token)
|
||||
if err != nil && !errors.Is(err, configstore.ErrNotFound) {
|
||||
logger.Error("failed to delete session during logout: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to invalidate session. Please try again.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
SendJSON(ctx, map[string]any{
|
||||
"message": "Logout successful",
|
||||
})
|
||||
}
|
||||
|
||||
// issueWSTicket handles POST /api/session/ws-ticket - Issue a short-lived ticket for WebSocket auth.
|
||||
// The caller must already be authenticated (via cookie or Authorization header).
|
||||
// Returns a one-time-use ticket that the frontend passes as ?ticket= when opening the WebSocket.
|
||||
func (h *SessionHandler) issueWSTicket(ctx *fasthttp.RequestCtx) {
|
||||
if h.wsTicketStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "WebSocket tickets are not available")
|
||||
return
|
||||
}
|
||||
sessionToken,ok := ctx.UserValue(schemas.BifrostContextKeySessionToken).(string)
|
||||
if !ok {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
if sessionToken == "" {
|
||||
// This is the case where auth is not configured or not enabled
|
||||
sessionToken = "dummy-session"
|
||||
}
|
||||
ticket, err := h.wsTicketStore.Issue(sessionToken)
|
||||
if err != nil {
|
||||
logger.Error("failed to issue WS ticket: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to issue WebSocket ticket")
|
||||
return
|
||||
}
|
||||
SendJSON(ctx, map[string]any{
|
||||
"ticket": ticket,
|
||||
})
|
||||
}
|
||||
127
transports/bifrost-http/handlers/ssestreaming_test.go
Normal file
127
transports/bifrost-http/handlers/ssestreaming_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// TestSSEStreamReaderNoEventBatching verifies that SSE events are delivered
|
||||
// individually through fasthttp's chunked transfer encoding, not batched
|
||||
// into larger TCP segments. This is the core regression test for the
|
||||
// fasthttputil.PipeConns batching fix.
|
||||
func TestSSEStreamReaderNoEventBatching(t *testing.T) {
|
||||
const numEvents = 20
|
||||
|
||||
// Build expected events
|
||||
events := make([]string, numEvents)
|
||||
for i := range events {
|
||||
events[i] = fmt.Sprintf("data: {\"index\":%d,\"content\":\"chunk-%d\"}\n\n", i, i)
|
||||
}
|
||||
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetContentType("text/event-stream")
|
||||
ctx.Response.Header.Set("Cache-Control", "no-cache")
|
||||
|
||||
reader := lib.NewSSEStreamReader()
|
||||
|
||||
go func() {
|
||||
defer reader.Done()
|
||||
for _, event := range events {
|
||||
if !reader.Send([]byte(event)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
ctx.Response.SetBodyStream(reader, -1)
|
||||
}
|
||||
|
||||
// Use net.Pipe for deterministic in-process testing
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer clientConn.Close()
|
||||
|
||||
// Run fasthttp server on one end of the pipe
|
||||
go func() {
|
||||
_ = fasthttp.ServeConn(serverConn, handler)
|
||||
}()
|
||||
|
||||
// Send HTTP request through the pipe
|
||||
_, err := clientConn.Write([]byte("GET /stream HTTP/1.1\r\nHost: test\r\n\r\n"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write request: %v", err)
|
||||
}
|
||||
|
||||
// Read response using bufio to parse chunked encoding
|
||||
br := bufio.NewReader(clientConn)
|
||||
|
||||
// Read and skip HTTP response headers
|
||||
for {
|
||||
line, err := br.ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response header: %v", err)
|
||||
}
|
||||
if strings.TrimSpace(line) == "" {
|
||||
break // End of headers
|
||||
}
|
||||
}
|
||||
|
||||
// Read chunked transfer-encoded body.
|
||||
// Each HTTP chunk should contain exactly one SSE event.
|
||||
var receivedEvents []string
|
||||
for {
|
||||
// Read chunk size line (hex size + CRLF)
|
||||
sizeLine, err := br.ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read chunk size: %v", err)
|
||||
}
|
||||
sizeLine = strings.TrimSpace(sizeLine)
|
||||
|
||||
var chunkSize int
|
||||
_, err = fmt.Sscanf(sizeLine, "%x", &chunkSize)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse chunk size %q: %v", sizeLine, err)
|
||||
}
|
||||
|
||||
if chunkSize == 0 {
|
||||
break // Terminal chunk
|
||||
}
|
||||
|
||||
// Read exactly chunkSize bytes + trailing CRLF
|
||||
chunkData := make([]byte, chunkSize+2) // +2 for CRLF
|
||||
n := 0
|
||||
for n < len(chunkData) {
|
||||
nn, err := br.Read(chunkData[n:])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read chunk data: %v", err)
|
||||
}
|
||||
n += nn
|
||||
}
|
||||
|
||||
chunk := string(chunkData[:chunkSize])
|
||||
receivedEvents = append(receivedEvents, chunk)
|
||||
}
|
||||
|
||||
// Verify each chunk contains exactly one SSE event
|
||||
if len(receivedEvents) != numEvents {
|
||||
t.Errorf("expected %d individual chunks, got %d (events were batched)", numEvents, len(receivedEvents))
|
||||
for i, chunk := range receivedEvents {
|
||||
eventCount := strings.Count(chunk, "\n\n")
|
||||
t.Logf(" chunk %d: %d SSE events, %d bytes", i, eventCount, len(chunk))
|
||||
}
|
||||
}
|
||||
|
||||
for i, chunk := range receivedEvents {
|
||||
if i >= len(events) {
|
||||
break
|
||||
}
|
||||
if chunk != events[i] {
|
||||
t.Errorf("chunk %d: got %q, want %q", i, chunk, events[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
136
transports/bifrost-http/handlers/ui.go
Normal file
136
transports/bifrost-http/handlers/ui.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"mime"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// UIHandler handles UI routes.
|
||||
type UIHandler struct {
|
||||
uiContent embed.FS
|
||||
}
|
||||
|
||||
// NewUIHandler creates a new UIHandler instance.
|
||||
func NewUIHandler(uiContent embed.FS) *UIHandler {
|
||||
return &UIHandler{
|
||||
uiContent: uiContent,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the UI routes with the provided router.
|
||||
func (h *UIHandler) RegisterRoutes(router *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
router.GET("/", lib.ChainMiddlewares(h.serveDashboard, middlewares...))
|
||||
router.GET("/{filepath:*}", lib.ChainMiddlewares(h.serveDashboard, middlewares...))
|
||||
}
|
||||
|
||||
// ServeDashboard serves the dashboard UI.
|
||||
func (h *UIHandler) serveDashboard(ctx *fasthttp.RequestCtx) {
|
||||
// Get the request path
|
||||
requestPath := string(ctx.Path())
|
||||
|
||||
// Clean the path to prevent directory traversal
|
||||
cleanPath := path.Clean(requestPath)
|
||||
|
||||
// Handle .txt files - map from /{page}.txt to /{page}/index.txt
|
||||
if strings.HasSuffix(cleanPath, ".txt") {
|
||||
// Remove .txt extension and add /index.txt
|
||||
basePath := strings.TrimSuffix(cleanPath, ".txt")
|
||||
if basePath == "/" || basePath == "" {
|
||||
basePath = "/index"
|
||||
}
|
||||
cleanPath = basePath + "/index.txt"
|
||||
}
|
||||
|
||||
// Remove leading slash and add ui prefix
|
||||
if cleanPath == "/" {
|
||||
cleanPath = "ui/index.html"
|
||||
} else {
|
||||
cleanPath = "ui" + cleanPath
|
||||
}
|
||||
|
||||
// Block hidden directories and files (any path segment starting with .)
|
||||
segments := strings.Split(cleanPath, "/")
|
||||
for _, segment := range segments {
|
||||
if strings.HasPrefix(segment, ".") {
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetBodyString("404 - Not found")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Block sensitive files
|
||||
baseName := filepath.Base(cleanPath)
|
||||
sensitiveFiles := []string{"package.json", "package-lock.json"}
|
||||
for _, sensitive := range sensitiveFiles {
|
||||
if baseName == sensitive {
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetBodyString("404 - Not found")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this is a static asset request (has file extension)
|
||||
hasExtension := strings.Contains(filepath.Base(cleanPath), ".")
|
||||
|
||||
// Try to read the file from embedded filesystem
|
||||
data, err := h.uiContent.ReadFile(cleanPath)
|
||||
if err != nil {
|
||||
|
||||
// If it's a static asset (has extension) and not found, return 404
|
||||
if hasExtension {
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetBodyString("404 - Static asset not found: " + requestPath)
|
||||
return
|
||||
}
|
||||
|
||||
// For routes without extensions (SPA routing), try {path}/index.html first
|
||||
if !hasExtension {
|
||||
indexPath := cleanPath + "/index.html"
|
||||
data, err = h.uiContent.ReadFile(indexPath)
|
||||
if err == nil {
|
||||
cleanPath = indexPath
|
||||
} else {
|
||||
// If that fails, serve root index.html as fallback
|
||||
data, err = h.uiContent.ReadFile("ui/index.html")
|
||||
if err != nil {
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetBodyString("404 - File not found")
|
||||
return
|
||||
}
|
||||
cleanPath = "ui/index.html"
|
||||
}
|
||||
} else {
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetBodyString("404 - File not found")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Set content type based on file extension
|
||||
ext := filepath.Ext(cleanPath)
|
||||
contentType := mime.TypeByExtension(ext)
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
ctx.SetContentType(contentType)
|
||||
|
||||
// Set cache headers for static assets
|
||||
if strings.HasPrefix(cleanPath, "ui/assets/") {
|
||||
ctx.Response.Header.Set("Cache-Control", "public, max-age=31536000, immutable")
|
||||
} else if ext == ".html" {
|
||||
ctx.Response.Header.Set("Cache-Control", "no-cache")
|
||||
} else {
|
||||
ctx.Response.Header.Set("Cache-Control", "public, max-age=3600")
|
||||
}
|
||||
|
||||
// Send the file content
|
||||
ctx.SetBody(data)
|
||||
}
|
||||
225
transports/bifrost-http/handlers/utils.go
Normal file
225
transports/bifrost-http/handlers/utils.go
Normal file
@@ -0,0 +1,225 @@
|
||||
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
|
||||
// This file contains common utility functions used across all handlers.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// pluginDisabledKey is a dedicated context key type for marking a plugin as disabled
|
||||
// rather than removed. Using a named type instead of a raw string follows Go best practices.
|
||||
type pluginDisabledKey struct{}
|
||||
|
||||
// PluginDisabledKey is the context key used to indicate a plugin is being disabled.
|
||||
var PluginDisabledKey pluginDisabledKey
|
||||
|
||||
// badRequestError wraps a client input validation error so that outer handlers
|
||||
// can distinguish it from internal server errors and return HTTP 400.
|
||||
type badRequestError struct{ err error }
|
||||
|
||||
func (e *badRequestError) Error() string { return e.err.Error() }
|
||||
func (e *badRequestError) Unwrap() error { return e.err }
|
||||
|
||||
// SendJSON sends a JSON response with 200 OK status
|
||||
func SendJSON(ctx *fasthttp.RequestCtx, data interface{}) {
|
||||
ctx.SetContentType("application/json")
|
||||
if err := json.NewEncoder(ctx).Encode(data); err != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to encode JSON response: %v", err))
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to encode response: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// SendJSONWithStatus sends a JSON response with a custom status code
|
||||
func SendJSONWithStatus(ctx *fasthttp.RequestCtx, data interface{}, statusCode int) {
|
||||
ctx.SetContentType("application/json")
|
||||
ctx.SetStatusCode(statusCode)
|
||||
if err := json.NewEncoder(ctx).Encode(data); err != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to encode JSON response: %v", err))
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to encode response: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// SendError sends a BifrostError response
|
||||
func SendError(ctx *fasthttp.RequestCtx, statusCode int, message string) {
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
StatusCode: &statusCode,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
SendBifrostError(ctx, bifrostErr)
|
||||
}
|
||||
|
||||
// SendBifrostError sends a BifrostError response
|
||||
func SendBifrostError(ctx *fasthttp.RequestCtx, bifrostErr *schemas.BifrostError) {
|
||||
if bifrostErr.StatusCode != nil {
|
||||
ctx.SetStatusCode(*bifrostErr.StatusCode)
|
||||
} else if !bifrostErr.IsBifrostError {
|
||||
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||
} else {
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
}
|
||||
|
||||
ctx.SetContentType("application/json")
|
||||
if encodeErr := json.NewEncoder(ctx).Encode(bifrostErr); encodeErr != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to encode error response: %v", encodeErr))
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
ctx.SetBodyString(fmt.Sprintf("Failed to encode error response: %v", encodeErr))
|
||||
}
|
||||
}
|
||||
|
||||
// streamLargeResponseIfActive checks if large response mode was activated by the provider
|
||||
// and streams the response directly to the client. Returns true if the response was handled
|
||||
// (caller should return), false if normal response handling should continue.
|
||||
func streamLargeResponseIfActive(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext) bool {
|
||||
isLargeResponse, ok := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool)
|
||||
if !ok || !isLargeResponse {
|
||||
return false
|
||||
}
|
||||
if !lib.StreamLargeResponseBody(ctx, bifrostCtx) {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Large response reader not available")
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// SendSSEError sends an error in Server-Sent Events format
|
||||
func SendSSEError(ctx *fasthttp.RequestCtx, bifrostErr *schemas.BifrostError) {
|
||||
errorJSON, err := json.Marshal(map[string]interface{}{
|
||||
"error": bifrostErr,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("failed to marshal error for SSE: %v", err)
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(ctx, "data: %s\n\n", errorJSON); err != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to write SSE error: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// IsOriginAllowed checks if the given origin is allowed based on localhost rules and configured allowed origins.
|
||||
// Localhost origins are always allowed. Additional origins can be configured in allowedOrigins.
|
||||
// Supports wildcard patterns like *.example.com to match any subdomain.
|
||||
func IsOriginAllowed(origin string, allowedOrigins []string) bool {
|
||||
// Always allow localhost origins
|
||||
if isLocalhostOrigin(origin) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check configured allowed origins
|
||||
for _, allowedOrigin := range allowedOrigins {
|
||||
// Check for exact match first
|
||||
if allowedOrigin == origin {
|
||||
return true
|
||||
}
|
||||
|
||||
if allowedOrigin == "*" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for wildcard pattern
|
||||
if strings.Contains(allowedOrigin, "*") {
|
||||
if matchesWildcardPattern(origin, allowedOrigin) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isLocalhostOrigin checks if the given origin is a localhost origin
|
||||
func isLocalhostOrigin(origin string) bool {
|
||||
return strings.HasPrefix(origin, "http://localhost:") ||
|
||||
strings.HasPrefix(origin, "https://localhost:") ||
|
||||
strings.HasPrefix(origin, "http://127.0.0.1:") ||
|
||||
strings.HasPrefix(origin, "http://0.0.0.0:") ||
|
||||
strings.HasPrefix(origin, "https://127.0.0.1:")
|
||||
}
|
||||
|
||||
// matchesWildcardPattern checks if an origin matches a wildcard pattern.
|
||||
// Supports patterns like *.example.com, https://*.example.com, or http://*.example.com
|
||||
func matchesWildcardPattern(origin string, pattern string) bool {
|
||||
// Convert wildcard pattern to regex pattern
|
||||
// Escape special regex characters except *
|
||||
regexPattern := regexp.QuoteMeta(pattern)
|
||||
// Replace escaped \* with regex pattern for subdomain matching
|
||||
// \* should match one or more characters that are not dots (to match a subdomain)
|
||||
regexPattern = strings.ReplaceAll(regexPattern, `\*`, `[^/.]+`)
|
||||
// Anchor the pattern to match the entire origin
|
||||
regexPattern = "^" + regexPattern + "$"
|
||||
|
||||
// Compile and test the regex
|
||||
re, err := regexp.Compile(regexPattern)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return re.MatchString(origin)
|
||||
}
|
||||
|
||||
// ParseModel parses a model string in the format "provider/model" or "provider/nested/model"
|
||||
// Returns the provider and full model name after the first slash
|
||||
func ParseModel(model string) (string, string, error) {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return "", "", fmt.Errorf("model cannot be empty")
|
||||
}
|
||||
|
||||
parts := strings.SplitN(model, "/", 2)
|
||||
if len(parts) < 2 {
|
||||
return "", "", fmt.Errorf("model must be in the format 'provider/model'")
|
||||
}
|
||||
|
||||
provider := strings.TrimSpace(parts[0])
|
||||
name := strings.TrimSpace(parts[1])
|
||||
if provider == "" || name == "" {
|
||||
return "", "", fmt.Errorf("model must be in the format 'provider/model' with non-empty provider and model")
|
||||
}
|
||||
return provider, name, nil
|
||||
}
|
||||
|
||||
// ClampPaginationParams applies default/max bounds to limit and offset so that
|
||||
// the handler response matches the values the store actually uses.
|
||||
func ClampPaginationParams(limit, offset int) (int, int) {
|
||||
if limit <= 0 {
|
||||
limit = 25
|
||||
} else if limit > 100 {
|
||||
limit = 100
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
return limit, offset
|
||||
}
|
||||
|
||||
// fuzzyMatch checks if all characters in query appear in text in order (case-insensitive)
|
||||
// Example: "gpt4" matches "gpt-4", "gpt-4-turbo", etc.
|
||||
func fuzzyMatch(text, query string) bool {
|
||||
if query == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
text = strings.ToLower(text)
|
||||
query = strings.ToLower(query)
|
||||
|
||||
queryIndex := 0
|
||||
queryRunes := []rune(query)
|
||||
|
||||
for _, textChar := range text {
|
||||
if queryIndex < len(queryRunes) && textChar == queryRunes[queryIndex] {
|
||||
queryIndex++
|
||||
}
|
||||
}
|
||||
|
||||
return queryIndex == len(queryRunes)
|
||||
}
|
||||
1216
transports/bifrost-http/handlers/webrtc_realtime.go
Normal file
1216
transports/bifrost-http/handlers/webrtc_realtime.go
Normal file
File diff suppressed because it is too large
Load Diff
346
transports/bifrost-http/handlers/webrtc_realtime_test.go
Normal file
346
transports/bifrost-http/handlers/webrtc_realtime_test.go
Normal file
@@ -0,0 +1,346 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/kvstore"
|
||||
"github.com/maximhq/bifrost/framework/logstore"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type testHandlerStore struct {
|
||||
kv *kvstore.Store
|
||||
}
|
||||
|
||||
func (s testHandlerStore) ShouldAllowDirectKeys() bool { return true }
|
||||
func (s testHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher { return nil }
|
||||
func (s testHandlerStore) GetAvailableProviders() []schemas.ModelProvider { return nil }
|
||||
func (s testHandlerStore) GetStreamChunkInterceptor() lib.StreamChunkInterceptor {
|
||||
return nil
|
||||
}
|
||||
func (s testHandlerStore) GetAsyncJobExecutor() *logstore.AsyncJobExecutor { return nil }
|
||||
func (s testHandlerStore) GetAsyncJobResultTTL() int { return 0 }
|
||||
func (s testHandlerStore) GetKVStore() *kvstore.Store { return s.kv }
|
||||
func (s testHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList { return nil }
|
||||
|
||||
func TestResolveRealtimeSDPTarget_BaseRouteRequiresProviderPrefix(t *testing.T) {
|
||||
_, _, _, err := resolveRealtimeSDPTarget("/v1/realtime", []byte(`{"model":"gpt-4o-realtime-preview"}`))
|
||||
if err == nil {
|
||||
t.Fatal("expected provider/model validation error")
|
||||
}
|
||||
if err.Error == nil || err.Error.Message != "session.model must use provider/model on /v1 realtime routes" {
|
||||
t.Fatalf("unexpected error: %#v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRealtimeSDPTarget_BaseRouteNormalizesModel(t *testing.T) {
|
||||
provider, model, normalized, err := resolveRealtimeSDPTarget("/v1/realtime", []byte(`{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if provider != schemas.OpenAI {
|
||||
t.Fatalf("expected provider %s, got %s", schemas.OpenAI, provider)
|
||||
}
|
||||
if model != "gpt-4o-realtime-preview" {
|
||||
t.Fatalf("unexpected normalized model: %s", model)
|
||||
}
|
||||
|
||||
var root map[string]json.RawMessage
|
||||
if unmarshalErr := json.Unmarshal(normalized, &root); unmarshalErr != nil {
|
||||
t.Fatalf("failed to unmarshal normalized session: %v", unmarshalErr)
|
||||
}
|
||||
var sessionModel string
|
||||
if unmarshalErr := json.Unmarshal(root["model"], &sessionModel); unmarshalErr != nil {
|
||||
t.Fatalf("failed to unmarshal model: %v", unmarshalErr)
|
||||
}
|
||||
if sessionModel != "gpt-4o-realtime-preview" {
|
||||
t.Fatalf("unexpected marshaled model: %s", sessionModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRealtimeSDPTarget_OpenAIRouteDefaultsProvider(t *testing.T) {
|
||||
provider, model, _, err := resolveRealtimeSDPTarget("/openai/v1/realtime", []byte(`{"model":"gpt-4o-realtime-preview"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if provider != schemas.OpenAI {
|
||||
t.Fatalf("expected provider %s, got %s", schemas.OpenAI, provider)
|
||||
}
|
||||
if model != "gpt-4o-realtime-preview" {
|
||||
t.Fatalf("unexpected model: %s", model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCallsWebRTCRequest_RawSDPKeepsGARoute(t *testing.T) {
|
||||
var ctx fasthttp.RequestCtx
|
||||
ctx.Request.Header.SetMethod(fasthttp.MethodPost)
|
||||
ctx.Request.SetRequestURI("/openai/v1/realtime/calls?model=gpt-realtime")
|
||||
ctx.Request.Header.SetContentType("application/sdp")
|
||||
ctx.Request.SetBodyString("v=0\r\n")
|
||||
|
||||
sdpOffer, provider, model, session, err := parseCallsWebRTCRequest(&ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if sdpOffer != "v=0\r\n" {
|
||||
t.Fatalf("unexpected sdp offer: %q", sdpOffer)
|
||||
}
|
||||
if provider != schemas.OpenAI {
|
||||
t.Fatalf("expected provider %s, got %s", schemas.OpenAI, provider)
|
||||
}
|
||||
if model != "gpt-realtime" {
|
||||
t.Fatalf("unexpected model: %s", model)
|
||||
}
|
||||
if session != nil {
|
||||
t.Fatalf("expected nil session for raw SDP /calls request, got %s", string(session))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRealtimeRelayContextCopiesValuesWithoutRequestCancellation(t *testing.T) {
|
||||
requestCtx, requestCancel := schemas.NewBifrostContextWithCancel(context.Background())
|
||||
requestCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest)
|
||||
requestCtx.SetValue(schemas.BifrostContextKeyIntegrationType, "openai")
|
||||
requestCtx.SetValue(schemas.BifrostContextKeyGovernanceVirtualKeyID, "vk_test")
|
||||
|
||||
relayCtx, relayCancel := newRealtimeRelayContext(requestCtx)
|
||||
defer relayCancel()
|
||||
|
||||
requestCancel()
|
||||
|
||||
select {
|
||||
case <-requestCtx.Done():
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected request context to be cancelled")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-relayCtx.Done():
|
||||
t.Fatal("relay context should outlive cancelled request context")
|
||||
default:
|
||||
}
|
||||
|
||||
if got := relayCtx.Value(schemas.BifrostContextKeyHTTPRequestType); got != schemas.RealtimeRequest {
|
||||
t.Fatalf("request type = %v, want %v", got, schemas.RealtimeRequest)
|
||||
}
|
||||
if got := relayCtx.Value(schemas.BifrostContextKeyIntegrationType); got != "openai" {
|
||||
t.Fatalf("integration type = %v, want %q", got, "openai")
|
||||
}
|
||||
if got := relayCtx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID); got != "vk_test" {
|
||||
t.Fatalf("virtual key id = %v, want %q", got, "vk_test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRealtimeEventPreservesExtraParams(t *testing.T) {
|
||||
event, err := schemas.ParseRealtimeEvent([]byte(`{"type":"conversation.item.truncate","item_id":"item_123","content_index":0,"audio_end_ms":640}`))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseRealtimeEvent() error = %v", err)
|
||||
}
|
||||
|
||||
var itemID string
|
||||
if err := json.Unmarshal(event.ExtraParams["item_id"], &itemID); err != nil {
|
||||
t.Fatalf("json.Unmarshal(item_id) error = %v", err)
|
||||
}
|
||||
if itemID != "item_123" {
|
||||
t.Fatalf("item_id = %q, want %q", itemID, "item_123")
|
||||
}
|
||||
|
||||
var contentIndex int
|
||||
if err := json.Unmarshal(event.ExtraParams["content_index"], &contentIndex); err != nil {
|
||||
t.Fatalf("json.Unmarshal(content_index) error = %v", err)
|
||||
}
|
||||
if contentIndex != 0 {
|
||||
t.Fatalf("content_index = %d, want 0", contentIndex)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRealtimeBearerToken(t *testing.T) {
|
||||
var ctx fasthttp.RequestCtx
|
||||
ctx.Request.Header.Set("Authorization", "Bearer ek_test_123")
|
||||
|
||||
if got := extractRealtimeBearerToken(&ctx); got != "ek_test_123" {
|
||||
t.Fatalf("extractRealtimeBearerToken() = %q, want %q", got, "ek_test_123")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupRealtimeEphemeralKeyMappingKeepsEntryUntilTTLExpiry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := kvstore.New(kvstore.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("kvstore.New() error = %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
payload, err := json.Marshal(realtimeEphemeralKeyMapping{KeyID: "key_123", VirtualKey: "sk-bf-test"})
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
if err := store.SetWithTTL(buildRealtimeEphemeralKeyMappingKey("ek_test_123"), payload, time.Minute); err != nil {
|
||||
t.Fatalf("store.SetWithTTL() error = %v", err)
|
||||
}
|
||||
|
||||
mapping, ok := lookupRealtimeEphemeralKeyMapping(store, "ek_test_123")
|
||||
if !ok {
|
||||
t.Fatal("expected mapping to be consumed")
|
||||
}
|
||||
if mapping.KeyID != "key_123" {
|
||||
t.Fatalf("mapping.KeyID = %q, want %q", mapping.KeyID, "key_123")
|
||||
}
|
||||
if mapping.VirtualKey != "sk-bf-test" {
|
||||
t.Fatalf("mapping.VirtualKey = %q, want %q", mapping.VirtualKey, "sk-bf-test")
|
||||
}
|
||||
|
||||
raw, err := store.Get(buildRealtimeEphemeralKeyMappingKey("ek_test_123"))
|
||||
if err != nil {
|
||||
t.Fatalf("expected mapping to remain until TTL expiry: %v", err)
|
||||
}
|
||||
if raw == nil {
|
||||
t.Fatal("expected mapping to remain in KV store")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupRealtimeEphemeralKeyMapping_BackwardsCompatibleStringValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := kvstore.New(kvstore.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("kvstore.New() error = %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
if err := store.SetWithTTL(buildRealtimeEphemeralKeyMappingKey("ek_test_legacy"), "key_legacy", time.Minute); err != nil {
|
||||
t.Fatalf("store.SetWithTTL() error = %v", err)
|
||||
}
|
||||
|
||||
mapping, ok := lookupRealtimeEphemeralKeyMapping(store, "ek_test_legacy")
|
||||
if !ok {
|
||||
t.Fatal("expected legacy mapping to be consumed")
|
||||
}
|
||||
if mapping.KeyID != "key_legacy" {
|
||||
t.Fatalf("mapping.KeyID = %q, want %q", mapping.KeyID, "key_legacy")
|
||||
}
|
||||
if mapping.VirtualKey != "" {
|
||||
t.Fatalf("mapping.VirtualKey = %q, want empty", mapping.VirtualKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebRTCRealtimeRelayCloseFinalizesActiveTurnHooks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
session := bfws.NewSession(nil)
|
||||
session.SetProviderSessionID("sess_provider_123")
|
||||
session.AddRealtimeInput("hello from user", `{"type":"conversation.item.added"}`)
|
||||
|
||||
var (
|
||||
capturedErr *schemas.BifrostError
|
||||
cleanedUp bool
|
||||
)
|
||||
session.SetRealtimeTurnHooks(&bfws.RealtimeTurnPluginState{
|
||||
RequestID: "req_realtime_123",
|
||||
StartedAt: time.Now().Add(-time.Second),
|
||||
PostHookRunner: func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
|
||||
capturedErr = err
|
||||
return result, nil
|
||||
},
|
||||
Cleanup: func() {
|
||||
cleanedUp = true
|
||||
},
|
||||
})
|
||||
|
||||
relay := &webrtcRealtimeRelay{
|
||||
session: session,
|
||||
providerKey: schemas.OpenAI,
|
||||
model: "gpt-realtime",
|
||||
}
|
||||
|
||||
relay.close()
|
||||
|
||||
if capturedErr == nil {
|
||||
t.Fatal("expected active turn to be finalized with an error on close")
|
||||
}
|
||||
if capturedErr.ExtraFields.RequestType != schemas.RealtimeRequest {
|
||||
t.Fatalf("request type = %q, want %q", capturedErr.ExtraFields.RequestType, schemas.RealtimeRequest)
|
||||
}
|
||||
if capturedErr.Error == nil || capturedErr.Error.Message != "realtime WebRTC session closed before turn completed" {
|
||||
t.Fatalf("error message = %#v, want realtime close message", capturedErr.Error)
|
||||
}
|
||||
if session.PeekRealtimeTurnHooks() != nil {
|
||||
t.Fatal("expected active realtime turn hooks to be cleared")
|
||||
}
|
||||
if !cleanedUp {
|
||||
t.Fatal("expected realtime hook cleanup to run")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRealtimeWebRTCKeys_UnmappedEphemeralTokenStaysAnonymous(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := kvstore.New(kvstore.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("kvstore.New() error = %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
handler := &WebRTCRealtimeHandler{
|
||||
handlerStore: testHandlerStore{kv: store},
|
||||
}
|
||||
|
||||
var ctx fasthttp.RequestCtx
|
||||
ctx.Request.Header.Set("Authorization", "Bearer ek_test_unmapped")
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, schemas.Key{ID: "header-provided"})
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeySelectedKeyID, "selected")
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeySelectedKeyName, "selected-name")
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyAPIKeyID, "mapped-id")
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyAPIKeyName, "mapped-name")
|
||||
|
||||
authKey, selectedKey, err := handler.resolveRealtimeWebRTCKeys(&ctx, bifrostCtx, schemas.OpenAI, "gpt-realtime")
|
||||
if err != nil {
|
||||
t.Fatalf("resolveRealtimeWebRTCKeys() error = %v", err)
|
||||
}
|
||||
if got := authKey.Value.GetValue(); got != "ek_test_unmapped" {
|
||||
t.Fatalf("auth key value = %q, want %q", got, "ek_test_unmapped")
|
||||
}
|
||||
if selectedKey != nil {
|
||||
t.Fatalf("selectedKey = %#v, want nil", selectedKey)
|
||||
}
|
||||
if got := bifrostCtx.Value(schemas.BifrostContextKeyDirectKey); got != nil {
|
||||
t.Fatalf("direct key context = %#v, want nil", got)
|
||||
}
|
||||
if got := bifrostCtx.Value(schemas.BifrostContextKeySelectedKeyID); got != nil {
|
||||
t.Fatalf("selected key id context = %#v, want nil", got)
|
||||
}
|
||||
if got := bifrostCtx.Value(schemas.BifrostContextKeySelectedKeyName); got != nil {
|
||||
t.Fatalf("selected key name context = %#v, want nil", got)
|
||||
}
|
||||
if got := bifrostCtx.Value(schemas.BifrostContextKeyAPIKeyID); got != nil {
|
||||
t.Fatalf("api key id context = %#v, want nil", got)
|
||||
}
|
||||
if got := bifrostCtx.Value(schemas.BifrostContextKeyAPIKeyName); got != nil {
|
||||
t.Fatalf("api key name context = %#v, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyRealtimeEphemeralKeyMapping_RestoresVirtualKeyAndKeyID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
applyRealtimeEphemeralKeyMapping(bifrostCtx, realtimeEphemeralKeyMapping{
|
||||
KeyID: "key_123",
|
||||
VirtualKey: "sk-bf-test",
|
||||
})
|
||||
|
||||
if got := bifrostCtx.Value(schemas.BifrostContextKeyVirtualKey); got != "sk-bf-test" {
|
||||
t.Fatalf("virtual key context = %#v, want %q", got, "sk-bf-test")
|
||||
}
|
||||
if got := bifrostCtx.Value(schemas.BifrostContextKeyAPIKeyID); got != "key_123" {
|
||||
t.Fatalf("api key id context = %#v, want %q", got, "key_123")
|
||||
}
|
||||
}
|
||||
268
transports/bifrost-http/handlers/websocket.go
Normal file
268
transports/bifrost-http/handlers/websocket.go
Normal file
@@ -0,0 +1,268 @@
|
||||
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
|
||||
// This file contains WebSocket handlers for real-time log streaming.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/fasthttp/websocket"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// WebSocketClient represents a connected WebSocket client with its own mutex
|
||||
type WebSocketClient struct {
|
||||
conn *websocket.Conn
|
||||
mu sync.Mutex // Per-connection mutex for thread-safe writes
|
||||
}
|
||||
|
||||
// WebSocketHandler manages WebSocket connections for real-time updates
|
||||
type WebSocketHandler struct {
|
||||
ctx context.Context
|
||||
allowedOrigins []string
|
||||
clients map[*websocket.Conn]*WebSocketClient
|
||||
mu sync.RWMutex
|
||||
stopChan chan struct{} // Channel to signal heartbeat goroutine to stop
|
||||
done chan struct{} // Channel to signal when heartbeat goroutine has stopped
|
||||
}
|
||||
|
||||
// NewWebSocketHandler creates a new WebSocket handler instance
|
||||
func NewWebSocketHandler(ctx context.Context, allowedOrigins []string) *WebSocketHandler {
|
||||
return &WebSocketHandler{
|
||||
ctx: ctx,
|
||||
allowedOrigins: allowedOrigins,
|
||||
clients: make(map[*websocket.Conn]*WebSocketClient),
|
||||
stopChan: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers all WebSocket-related routes
|
||||
func (h *WebSocketHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
r.GET("/ws", lib.ChainMiddlewares(h.connectStream, middlewares...))
|
||||
}
|
||||
|
||||
// getUpgrader returns a WebSocket upgrader configured with the current allowed origins
|
||||
func (h *WebSocketHandler) getUpgrader() websocket.FastHTTPUpgrader {
|
||||
return websocket.FastHTTPUpgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(ctx *fasthttp.RequestCtx) bool {
|
||||
origin := string(ctx.Request.Header.Peek("Origin"))
|
||||
if origin == "" {
|
||||
// If no Origin header, check the Host header for direct connections
|
||||
host := string(ctx.Request.Header.Peek("Host"))
|
||||
return isLocalhost(host)
|
||||
}
|
||||
// Check if origin is allowed (localhost always allowed + configured origins)
|
||||
return IsOriginAllowed(origin, h.allowedOrigins)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// isLocalhost checks if the given host is localhost
|
||||
func isLocalhost(host string) bool {
|
||||
// Remove port if present
|
||||
if idx := strings.LastIndex(host, ":"); idx != -1 {
|
||||
host = host[:idx]
|
||||
}
|
||||
|
||||
// Check for localhost variations
|
||||
return host == "localhost" ||
|
||||
host == "127.0.0.1" ||
|
||||
host == "::1" ||
|
||||
host == ""
|
||||
}
|
||||
|
||||
// connectStream handles WebSocket connections for real-time streaming
|
||||
func (h *WebSocketHandler) connectStream(ctx *fasthttp.RequestCtx) {
|
||||
upgrader := h.getUpgrader()
|
||||
err := upgrader.Upgrade(ctx, func(ws *websocket.Conn) {
|
||||
// Read safety & liveness
|
||||
ws.SetReadLimit(50 << 20) // 50 MiB
|
||||
ws.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
ws.SetPongHandler(func(string) error {
|
||||
ws.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
return nil
|
||||
})
|
||||
// Create a new client with its own mutex
|
||||
client := &WebSocketClient{
|
||||
conn: ws,
|
||||
}
|
||||
|
||||
// Register new client
|
||||
h.mu.Lock()
|
||||
h.clients[ws] = client
|
||||
h.mu.Unlock()
|
||||
|
||||
// Clean up on disconnect
|
||||
defer func() {
|
||||
h.mu.Lock()
|
||||
delete(h.clients, ws)
|
||||
h.mu.Unlock()
|
||||
ws.Close()
|
||||
}()
|
||||
|
||||
// Keep connection alive and handle client messages
|
||||
// This loop continuously reads and discards incoming WebSocket messages to:
|
||||
// 1. Keep the connection alive by processing client pings and control frames
|
||||
// 2. Detect when the client disconnects by watching for close frames or errors
|
||||
// 3. Maintain proper WebSocket protocol handling without accumulating messages
|
||||
for {
|
||||
_, _, err := ws.ReadMessage()
|
||||
if err != nil {
|
||||
// Only log unexpected close errors
|
||||
if websocket.IsUnexpectedCloseError(err,
|
||||
websocket.CloseNormalClosure,
|
||||
websocket.CloseGoingAway,
|
||||
websocket.CloseAbnormalClosure,
|
||||
websocket.CloseNoStatusReceived) {
|
||||
logger.Error("websocket read error: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logger.Error("websocket upgrade error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// sendMessageSafely sends a message to a client with proper locking and error handling
|
||||
func (h *WebSocketHandler) sendMessageSafely(client *WebSocketClient, messageType int, data []byte) error {
|
||||
client.mu.Lock()
|
||||
defer client.mu.Unlock()
|
||||
|
||||
// Set a write deadline to prevent hanging connections
|
||||
client.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
defer client.conn.SetWriteDeadline(time.Time{}) // Clear the deadline
|
||||
|
||||
err := client.conn.WriteMessage(messageType, data)
|
||||
if err != nil {
|
||||
// Remove the client from the map if write fails
|
||||
go func() {
|
||||
h.mu.Lock()
|
||||
delete(h.clients, client.conn)
|
||||
h.mu.Unlock()
|
||||
client.conn.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// BroadcastUpdatesToClients sends a store update notification to all connected WebSocket clients
|
||||
// The tags parameter should match RTK Query tagTypes (e.g., "Providers", "VirtualKeys", "MCPClients")
|
||||
func (h *WebSocketHandler) BroadcastUpdatesToClients(tags []string) {
|
||||
message := struct {
|
||||
Type string `json:"type"`
|
||||
Tags []string `json:"tags"`
|
||||
}{
|
||||
Type: "store_update",
|
||||
Tags: tags,
|
||||
}
|
||||
|
||||
data, err := sonic.Marshal(message)
|
||||
if err != nil {
|
||||
logger.Error("failed to marshal store update: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
h.BroadcastMarshaledMessage(data)
|
||||
}
|
||||
|
||||
// BroadcastEvent sends a typed event to all connected WebSocket clients.
|
||||
// Any subsystem can use this to push real-time updates to the frontend.
|
||||
func (h *WebSocketHandler) BroadcastEvent(eventType string, data interface{}) {
|
||||
message := struct {
|
||||
Type string `json:"type"`
|
||||
Data interface{} `json:"data"`
|
||||
}{
|
||||
Type: eventType,
|
||||
Data: data,
|
||||
}
|
||||
|
||||
bytes, err := sonic.Marshal(message)
|
||||
if err != nil {
|
||||
logger.Error("failed to marshal event %s: %v", eventType, err)
|
||||
return
|
||||
}
|
||||
|
||||
h.BroadcastMarshaledMessage(bytes)
|
||||
}
|
||||
|
||||
// BroadcastMarshaledMessage sends an adaptive routing update to all connected WebSocket clients
|
||||
func (h *WebSocketHandler) BroadcastMarshaledMessage(data []byte) {
|
||||
// Get a snapshot of clients to avoid holding the lock during writes
|
||||
h.mu.RLock()
|
||||
clients := make([]*WebSocketClient, 0, len(h.clients))
|
||||
for _, client := range h.clients {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
|
||||
// Send message to each client safely
|
||||
for _, client := range clients {
|
||||
if err := h.sendMessageSafely(client, websocket.TextMessage, data); err != nil {
|
||||
logger.Error("failed to send message to client: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StartHeartbeat starts sending periodic heartbeat messages to keep connections alive
|
||||
func (h *WebSocketHandler) StartHeartbeat() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
go func() {
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
close(h.done)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
logger.Info("got context cancel(), stopping webserver")
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Get a snapshot of clients to avoid holding the lock during writes
|
||||
h.mu.RLock()
|
||||
clients := make([]*WebSocketClient, 0, len(h.clients))
|
||||
for _, client := range h.clients {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
|
||||
// Send heartbeat to each client safely
|
||||
for _, client := range clients {
|
||||
if err := h.sendMessageSafely(client, websocket.PingMessage, nil); err != nil {
|
||||
logger.Error("failed to send heartbeat: %v", err)
|
||||
}
|
||||
}
|
||||
case <-h.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the WebSocket handler
|
||||
func (h *WebSocketHandler) Stop() {
|
||||
close(h.stopChan) // Signal heartbeat goroutine to stop
|
||||
<-h.done // Wait for heartbeat goroutine to finish
|
||||
|
||||
// Close all client connections
|
||||
h.mu.Lock()
|
||||
for _, client := range h.clients {
|
||||
client.conn.Close()
|
||||
}
|
||||
h.clients = make(map[*websocket.Conn]*WebSocketClient)
|
||||
h.mu.Unlock()
|
||||
}
|
||||
102
transports/bifrost-http/handlers/ws_ticket.go
Normal file
102
transports/bifrost-http/handlers/ws_ticket.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
wsTicketTTL = 30 * time.Second
|
||||
wsTicketCleanupHz = 60 * time.Second
|
||||
)
|
||||
|
||||
type wsTicketEntry struct {
|
||||
sessionToken string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// WSTicketStore provides short-lived, single-use tickets for WebSocket authentication.
|
||||
// Instead of putting the long-lived session token in the WS URL (visible in logs/history),
|
||||
// clients exchange their session for a 30-second one-time ticket via an authenticated endpoint.
|
||||
type WSTicketStore struct {
|
||||
mu sync.Mutex
|
||||
tickets map[string]wsTicketEntry
|
||||
done chan struct{}
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
// NewWSTicketStore creates a new ticket store and starts a background goroutine
|
||||
// that periodically purges expired tickets.
|
||||
func NewWSTicketStore() *WSTicketStore {
|
||||
s := &WSTicketStore{
|
||||
tickets: make(map[string]wsTicketEntry),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go s.cleanup()
|
||||
return s
|
||||
}
|
||||
|
||||
// Issue generates a cryptographically random ticket bound to the given session token.
|
||||
// The ticket expires after wsTicketTTL (30 seconds).
|
||||
func (s *WSTicketStore) Issue(sessionToken string) (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
ticket := hex.EncodeToString(b)
|
||||
|
||||
s.mu.Lock()
|
||||
s.tickets[ticket] = wsTicketEntry{
|
||||
sessionToken: sessionToken,
|
||||
expiresAt: time.Now().Add(wsTicketTTL),
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return ticket, nil
|
||||
}
|
||||
|
||||
// Consume validates and deletes a ticket, returning the underlying session token.
|
||||
// Returns empty string if the ticket doesn't exist or has expired (single-use).
|
||||
func (s *WSTicketStore) Consume(ticket string) string {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
entry, ok := s.tickets[ticket]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
delete(s.tickets, ticket)
|
||||
if time.Now().After(entry.expiresAt) {
|
||||
return ""
|
||||
}
|
||||
return entry.sessionToken
|
||||
}
|
||||
|
||||
// Stop terminates the background cleanup goroutine.
|
||||
func (s *WSTicketStore) Stop() {
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.done)
|
||||
})
|
||||
}
|
||||
|
||||
// cleanup periodically removes expired tickets to prevent unbounded memory growth.
|
||||
func (s *WSTicketStore) cleanup() {
|
||||
ticker := time.NewTicker(wsTicketCleanupHz)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
for k, v := range s.tickets {
|
||||
if now.After(v.expiresAt) {
|
||||
delete(s.tickets, k)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
666
transports/bifrost-http/handlers/wsrealtime.go
Normal file
666
transports/bifrost-http/handlers/wsrealtime.go
Normal file
@@ -0,0 +1,666 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
ws "github.com/fasthttp/websocket"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/integrations"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
const (
|
||||
realtimeWSPingInterval = 15 * time.Second
|
||||
realtimeWSPongTimeout = 45 * time.Second
|
||||
realtimeWSPingWriteTimeout = 10 * time.Second
|
||||
realtimeWSWriteTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// WSRealtimeHandler handles bidirectional WebSocket proxying for the Realtime API.
|
||||
type WSRealtimeHandler struct {
|
||||
client *bifrost.Bifrost
|
||||
config *lib.Config
|
||||
handlerStore lib.HandlerStore
|
||||
pool *bfws.Pool
|
||||
sessions *bfws.SessionManager
|
||||
}
|
||||
|
||||
// NewWSRealtimeHandler creates a new Realtime WebSocket handler.
|
||||
func NewWSRealtimeHandler(client *bifrost.Bifrost, config *lib.Config, pool *bfws.Pool) *WSRealtimeHandler {
|
||||
maxConns := config.WebSocketConfig.MaxConnections
|
||||
|
||||
return &WSRealtimeHandler{
|
||||
client: client,
|
||||
config: config,
|
||||
handlerStore: config,
|
||||
pool: pool,
|
||||
sessions: bfws.NewSessionManager(maxConns),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the Realtime WebSocket endpoint at the base path and OpenAI integration paths.
|
||||
func (h *WSRealtimeHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
handler := lib.ChainMiddlewares(h.handleUpgrade, middlewares...)
|
||||
r.GET("/v1/realtime", handler)
|
||||
for _, path := range integrations.OpenAIRealtimePaths("/openai") {
|
||||
r.GET(path, handler)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WSRealtimeHandler) Close() {
|
||||
if h == nil || h.sessions == nil {
|
||||
return
|
||||
}
|
||||
h.sessions.CloseAll()
|
||||
}
|
||||
|
||||
func (h *WSRealtimeHandler) handleUpgrade(ctx *fasthttp.RequestCtx) {
|
||||
path := string(ctx.Path())
|
||||
modelParam := string(ctx.QueryArgs().Peek("model"))
|
||||
deploymentParam := string(ctx.QueryArgs().Peek("deployment"))
|
||||
auth := captureAuthHeaders(ctx)
|
||||
// OpenAI's SDK sends the API key via WebSocket subprotocol: "openai-insecure-api-key.<key>".
|
||||
// Extract it into the auth headers so downstream processing recognizes it.
|
||||
if auth.authorization == "" {
|
||||
if token := extractRealtimeSubprotocolAPIKey(ctx); token != "" {
|
||||
auth.authorization = "Bearer " + token
|
||||
}
|
||||
}
|
||||
|
||||
providerKey, model, err := resolveRealtimeTarget(path, modelParam, deploymentParam)
|
||||
if err != nil {
|
||||
upgrader := h.websocketUpgrader("")
|
||||
upgradeErr := upgrader.Upgrade(ctx, func(conn *ws.Conn) {
|
||||
defer conn.Close()
|
||||
clientConn := newRealtimeClientConn(conn)
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()))
|
||||
})
|
||||
if upgradeErr != nil {
|
||||
logger.Warn("websocket upgrade failed for %s: %v", path, upgradeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
provider := h.client.GetProviderByKey(providerKey)
|
||||
rtProvider, ok := provider.(schemas.RealtimeProvider)
|
||||
if provider == nil || !ok || !rtProvider.SupportsRealtimeAPI() {
|
||||
upgrader := h.websocketUpgrader("")
|
||||
upgradeErr := upgrader.Upgrade(ctx, func(conn *ws.Conn) {
|
||||
defer conn.Close()
|
||||
clientConn := newRealtimeClientConn(conn)
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "provider does not support realtime: "+string(providerKey)))
|
||||
})
|
||||
if upgradeErr != nil {
|
||||
logger.Warn("websocket upgrade failed for %s: %v", path, upgradeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
upgrader := h.websocketUpgrader(rtProvider.RealtimeWebSocketSubprotocol())
|
||||
err = upgrader.Upgrade(ctx, func(conn *ws.Conn) {
|
||||
defer conn.Close()
|
||||
clientConn := newRealtimeClientConn(conn)
|
||||
|
||||
session, sessionErr := h.sessions.Create(conn)
|
||||
if sessionErr != nil {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(429, "rate_limit_exceeded", sessionErr.Error()))
|
||||
return
|
||||
}
|
||||
defer h.sessions.Remove(conn)
|
||||
|
||||
h.runRealtimeSession(clientConn, session, auth, path, providerKey, model)
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn("websocket upgrade failed for %s: %v", path, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WSRealtimeHandler) websocketUpgrader(subprotocol string) ws.FastHTTPUpgrader {
|
||||
upgrader := ws.FastHTTPUpgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(ctx *fasthttp.RequestCtx) bool {
|
||||
origin := string(ctx.Request.Header.Peek("Origin"))
|
||||
if origin == "" {
|
||||
return true
|
||||
}
|
||||
return IsOriginAllowed(origin, h.config.ClientConfig.AllowedOrigins)
|
||||
},
|
||||
}
|
||||
if strings.TrimSpace(subprotocol) != "" {
|
||||
upgrader.Subprotocols = []string{subprotocol}
|
||||
}
|
||||
return upgrader
|
||||
}
|
||||
|
||||
func (h *WSRealtimeHandler) runRealtimeSession(
|
||||
clientConn *realtimeClientConn,
|
||||
session *bfws.Session,
|
||||
auth *authHeaders,
|
||||
path string,
|
||||
providerKey schemas.ModelProvider,
|
||||
model string,
|
||||
) {
|
||||
clientConn.startHeartbeat()
|
||||
defer clientConn.stopHeartbeat()
|
||||
|
||||
bifrostCtx, cancel := createBifrostContextFromAuth(h.handlerStore, auth)
|
||||
if bifrostCtx == nil {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(500, "server_error", "failed to create request context"))
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
// Resolve ephemeral key mapping to restore virtual key context.
|
||||
token := extractRealtimeBearerTokenFromHeader(auth.authorization)
|
||||
if isRealtimeEphemeralToken(token) {
|
||||
mapping, ok := lookupRealtimeEphemeralKeyMapping(h.handlerStore.GetKVStore(), token)
|
||||
if ok {
|
||||
applyRealtimeEphemeralKeyMapping(bifrostCtx, mapping)
|
||||
}
|
||||
}
|
||||
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest)
|
||||
if strings.HasPrefix(path, "/openai") {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyIntegrationType, "openai")
|
||||
}
|
||||
|
||||
provider := h.client.GetProviderByKey(providerKey)
|
||||
if provider == nil {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "provider not found: "+string(providerKey)))
|
||||
return
|
||||
}
|
||||
|
||||
rtProvider, ok := provider.(schemas.RealtimeProvider)
|
||||
if !ok || !rtProvider.SupportsRealtimeAPI() {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "provider does not support realtime: "+string(providerKey)))
|
||||
return
|
||||
}
|
||||
|
||||
key, err := h.client.SelectKeyForProviderRequestType(bifrostCtx, schemas.RealtimeRequest, providerKey, model)
|
||||
if err != nil {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve model alias so the provider receives the actual model identifier.
|
||||
model = key.Aliases.Resolve(model)
|
||||
|
||||
wsURL := rtProvider.RealtimeWebSocketURL(key, model)
|
||||
upstream, err := h.pool.Get(bfws.PoolKey{
|
||||
Provider: providerKey,
|
||||
KeyID: key.ID,
|
||||
Endpoint: wsURL,
|
||||
}, rtProvider.RealtimeHeaders(key))
|
||||
if err != nil {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", err.Error()))
|
||||
return
|
||||
}
|
||||
defer h.pool.Discard(upstream)
|
||||
|
||||
errCh := make(chan error, 2)
|
||||
go func() {
|
||||
errCh <- h.relayClientToRealtimeProvider(clientConn, session, upstream, rtProvider, bifrostCtx, providerKey, model, key)
|
||||
}()
|
||||
go func() {
|
||||
errCh <- h.relayRealtimeProviderToClient(clientConn, session, upstream, rtProvider, bifrostCtx, providerKey, model, key)
|
||||
}()
|
||||
|
||||
firstErr := <-errCh
|
||||
_ = upstream.Close()
|
||||
_ = clientConn.Close()
|
||||
secondErr := <-errCh
|
||||
|
||||
if logErr := selectRealtimeRelayError(firstErr, secondErr); logErr != nil {
|
||||
logger.Warn("realtime websocket relay ended for %s/%s on %s: %v", providerKey, model, path, logErr)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WSRealtimeHandler) relayClientToRealtimeProvider(
|
||||
clientConn *realtimeClientConn,
|
||||
session *bfws.Session,
|
||||
upstream *bfws.UpstreamConn,
|
||||
provider schemas.RealtimeProvider,
|
||||
bifrostCtx *schemas.BifrostContext,
|
||||
providerKey schemas.ModelProvider,
|
||||
model string,
|
||||
key schemas.Key,
|
||||
) error {
|
||||
for {
|
||||
messageType, message, err := clientConn.ReadMessage()
|
||||
if err != nil {
|
||||
finalizeRealtimeTurnHooksOnTransportError(
|
||||
h.client,
|
||||
bifrostCtx,
|
||||
session,
|
||||
providerKey,
|
||||
model,
|
||||
&key,
|
||||
499,
|
||||
"client_closed_request",
|
||||
"client realtime websocket disconnected before turn completed",
|
||||
)
|
||||
if isNormalWebSocketClosure(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if messageType != ws.TextMessage {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "realtime websocket only accepts text messages"))
|
||||
return nil
|
||||
}
|
||||
|
||||
event, err := schemas.ParseRealtimeEvent(message)
|
||||
if err != nil {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "failed to parse realtime event JSON"))
|
||||
continue
|
||||
}
|
||||
// Extract pending tool/input summaries but defer recording until the event
|
||||
// passes validation — rejected events must not pollute session state.
|
||||
toolItemID, toolSummary := pendingRealtimeToolOutputUpdate(event)
|
||||
inputItemID, inputSummary := pendingRealtimeInputUpdate(event)
|
||||
|
||||
startsTurn := provider.ShouldStartRealtimeTurn(event)
|
||||
if startsTurn {
|
||||
if session.PeekRealtimeTurnHooks() != nil {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "Conversation already has an active response in progress."))
|
||||
continue
|
||||
}
|
||||
if toolSummary != "" {
|
||||
session.RecordRealtimeToolOutput(toolItemID, toolSummary, string(message))
|
||||
}
|
||||
if inputSummary != "" {
|
||||
session.RecordRealtimeInput(inputItemID, inputSummary, string(message))
|
||||
}
|
||||
if bifrostErr := startRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, event.Type); bifrostErr != nil {
|
||||
clientConn.writeRealtimeError(bifrostErr)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
providerEvent, err := provider.ToProviderRealtimeEvent(event)
|
||||
if err != nil {
|
||||
if startsTurn {
|
||||
if finalizeErr := finalizeRealtimeTurnHooksWithError(
|
||||
h.client,
|
||||
bifrostCtx,
|
||||
session,
|
||||
providerKey,
|
||||
model,
|
||||
&key,
|
||||
schemas.RTEventError,
|
||||
nil,
|
||||
newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()),
|
||||
); finalizeErr != nil {
|
||||
clientConn.writeRealtimeError(finalizeErr)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()))
|
||||
continue
|
||||
}
|
||||
|
||||
// Record tool output / input only after the event passed validation.
|
||||
if !startsTurn {
|
||||
if toolSummary != "" {
|
||||
session.RecordRealtimeToolOutput(toolItemID, toolSummary, string(message))
|
||||
}
|
||||
if inputSummary != "" {
|
||||
session.RecordRealtimeInput(inputItemID, inputSummary, string(message))
|
||||
}
|
||||
}
|
||||
|
||||
if err := upstream.WriteMessage(ws.TextMessage, providerEvent); err != nil {
|
||||
finalizeRealtimeTurnHooksWithError(
|
||||
h.client,
|
||||
bifrostCtx,
|
||||
session,
|
||||
providerKey,
|
||||
model,
|
||||
&key,
|
||||
schemas.RTEventError,
|
||||
nil,
|
||||
newRealtimeWireBifrostError(502, "server_error", "failed to write realtime event upstream"),
|
||||
)
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to write realtime event upstream"))
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WSRealtimeHandler) relayRealtimeProviderToClient(
|
||||
clientConn *realtimeClientConn,
|
||||
session *bfws.Session,
|
||||
upstream *bfws.UpstreamConn,
|
||||
provider schemas.RealtimeProvider,
|
||||
bifrostCtx *schemas.BifrostContext,
|
||||
providerKey schemas.ModelProvider,
|
||||
model string,
|
||||
key schemas.Key,
|
||||
) error {
|
||||
for {
|
||||
disconnectAfterWrite := false
|
||||
messageType, message, err := upstream.ReadMessage()
|
||||
if err != nil {
|
||||
finalizeRealtimeTurnHooksOnTransportError(
|
||||
h.client,
|
||||
bifrostCtx,
|
||||
session,
|
||||
providerKey,
|
||||
model,
|
||||
&key,
|
||||
502,
|
||||
"upstream_connection_error",
|
||||
"upstream realtime websocket closed before turn completed",
|
||||
)
|
||||
if isNormalWebSocketClosure(err) {
|
||||
return nil
|
||||
}
|
||||
finalizeRealtimeTurnHooksWithError(
|
||||
h.client,
|
||||
bifrostCtx,
|
||||
session,
|
||||
providerKey,
|
||||
model,
|
||||
&key,
|
||||
schemas.RTEventError,
|
||||
nil,
|
||||
newRealtimeWireBifrostError(502, "server_error", "upstream realtime websocket stream interrupted"),
|
||||
)
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "upstream realtime websocket stream interrupted"))
|
||||
return err
|
||||
}
|
||||
|
||||
if messageType == ws.TextMessage {
|
||||
event, err := provider.ToBifrostRealtimeEvent(message)
|
||||
if err != nil {
|
||||
finalizeRealtimeTurnHooksWithError(
|
||||
h.client,
|
||||
bifrostCtx,
|
||||
session,
|
||||
providerKey,
|
||||
model,
|
||||
&key,
|
||||
schemas.RTEventError,
|
||||
message,
|
||||
newRealtimeWireBifrostError(502, "server_error", "failed to translate upstream realtime event"),
|
||||
)
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to translate upstream realtime event"))
|
||||
return err
|
||||
}
|
||||
if event != nil {
|
||||
if event.Session != nil && event.Session.ID != "" {
|
||||
session.SetProviderSessionID(event.Session.ID)
|
||||
}
|
||||
if event.Delta != nil && provider.ShouldAccumulateRealtimeOutput(event.Type) {
|
||||
session.AppendRealtimeOutputText(event.Delta.Text)
|
||||
session.AppendRealtimeOutputText(event.Delta.Transcript)
|
||||
}
|
||||
if provider.ShouldStartRealtimeTurn(event) && session.PeekRealtimeTurnHooks() == nil {
|
||||
if bifrostErr := startRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, event.Type); bifrostErr != nil {
|
||||
clientConn.writeRealtimeError(bifrostErr)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if event != nil {
|
||||
inputItemID, inputSummary := pendingRealtimeInputUpdate(event)
|
||||
if !provider.ShouldForwardRealtimeEvent(event) {
|
||||
continue
|
||||
}
|
||||
if event.Type == provider.RealtimeTurnFinalEvent() {
|
||||
contentOverride := session.ConsumeRealtimeOutputText()
|
||||
if bifrostErr := finalizeRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, message, contentOverride); bifrostErr != nil {
|
||||
clientConn.writeRealtimeError(bifrostErr)
|
||||
return nil
|
||||
}
|
||||
} else if event.Error != nil {
|
||||
turnErr := newBifrostErrorFromRealtimeError(providerKey, model, message, event.Error)
|
||||
finalizeErr := finalizeRealtimeTurnHooksWithError(
|
||||
h.client,
|
||||
bifrostCtx,
|
||||
session,
|
||||
providerKey,
|
||||
model,
|
||||
&key,
|
||||
event.Type,
|
||||
message,
|
||||
turnErr,
|
||||
)
|
||||
if finalizeErr != nil {
|
||||
clientConn.writeRealtimeError(finalizeErr)
|
||||
return nil
|
||||
}
|
||||
// Defer the disconnect so the normal translated-write path
|
||||
// below still runs — otherwise terminal errors from translated
|
||||
// providers would reach the client in provider-native format.
|
||||
disconnectAfterWrite = shouldGracefullyDisconnectRealtime(turnErr)
|
||||
} else if inputSummary != "" {
|
||||
session.RecordRealtimeInput(inputItemID, inputSummary, string(message))
|
||||
}
|
||||
if len(event.RawData) == 0 {
|
||||
message, err = provider.ToProviderRealtimeEvent(event)
|
||||
if err != nil {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to encode translated realtime event"))
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := clientConn.WriteMessage(messageType, message); err != nil {
|
||||
finalizeRealtimeTurnHooksOnTransportError(
|
||||
h.client,
|
||||
bifrostCtx,
|
||||
session,
|
||||
providerKey,
|
||||
model,
|
||||
&key,
|
||||
499,
|
||||
"client_closed_request",
|
||||
"client realtime websocket disconnected before turn completed",
|
||||
)
|
||||
if isNormalWebSocketClosure(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if disconnectAfterWrite {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func resolveRealtimeTarget(path, modelParam, deploymentParam string) (schemas.ModelProvider, string, error) {
|
||||
defaultProvider := realtimeDefaultProviderForPath(path)
|
||||
|
||||
switch {
|
||||
case strings.TrimSpace(modelParam) != "":
|
||||
provider, model := schemas.ParseModelString(strings.TrimSpace(modelParam), defaultProvider)
|
||||
if provider == "" || strings.TrimSpace(model) == "" {
|
||||
return "", "", errRealtimeModelFormat
|
||||
}
|
||||
return provider, strings.TrimSpace(model), nil
|
||||
case strings.TrimSpace(deploymentParam) != "":
|
||||
provider, model := schemas.ParseModelString(strings.TrimSpace(deploymentParam), defaultProvider)
|
||||
if provider == "" || strings.TrimSpace(model) == "" {
|
||||
return "", "", errRealtimeDeploymentFormat
|
||||
}
|
||||
return provider, strings.TrimSpace(model), nil
|
||||
default:
|
||||
return "", "", errRealtimeModelRequired
|
||||
}
|
||||
}
|
||||
|
||||
func realtimeDefaultProviderForPath(path string) schemas.ModelProvider {
|
||||
if strings.HasPrefix(path, "/openai/") {
|
||||
return schemas.OpenAI
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func isNormalWebSocketClosure(err error) bool {
|
||||
return ws.IsCloseError(err, ws.CloseNormalClosure, ws.CloseGoingAway, ws.CloseNoStatusReceived)
|
||||
}
|
||||
|
||||
func isExpectedRealtimeRelayShutdown(err error) bool {
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
if isNormalWebSocketClosure(err) || errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
|
||||
return true
|
||||
}
|
||||
// Relay teardown closes the opposite socket after the first side exits, which can
|
||||
// surface as a plain network-close read error instead of a websocket close frame.
|
||||
return strings.Contains(err.Error(), "use of closed network connection")
|
||||
}
|
||||
|
||||
func selectRealtimeRelayError(errs ...error) error {
|
||||
for _, err := range errs {
|
||||
if err != nil && !isExpectedRealtimeRelayShutdown(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
errRealtimeModelRequired = errorf("model or deployment query parameter is required for realtime websocket")
|
||||
errRealtimeModelFormat = errorf("model query parameter must resolve to provider/model for realtime websocket")
|
||||
errRealtimeDeploymentFormat = errorf("deployment query parameter must resolve to provider/model for realtime websocket")
|
||||
)
|
||||
|
||||
type realtimeClientConn struct {
|
||||
conn *ws.Conn
|
||||
writeMu sync.Mutex
|
||||
closeOnce sync.Once
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newRealtimeClientConn(conn *ws.Conn) *realtimeClientConn {
|
||||
return &realtimeClientConn{
|
||||
conn: conn,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) ReadMessage() (messageType int, p []byte, err error) {
|
||||
messageType, p, err = c.conn.ReadMessage()
|
||||
if err == nil {
|
||||
c.refreshReadDeadline()
|
||||
}
|
||||
return messageType, p, err
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) WriteMessage(messageType int, data []byte) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
if err := c.conn.SetWriteDeadline(time.Now().Add(realtimeWSWriteTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.conn.WriteMessage(messageType, data); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.conn.SetWriteDeadline(time.Time{})
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) startHeartbeat() {
|
||||
c.installPongHandler()
|
||||
c.refreshReadDeadline()
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(realtimeWSPingInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := c.writePing(); err != nil {
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) stopHeartbeat() {
|
||||
c.closeDone()
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) installPongHandler() {
|
||||
c.conn.SetPongHandler(func(string) error {
|
||||
return c.refreshReadDeadline()
|
||||
})
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) refreshReadDeadline() error {
|
||||
return c.conn.SetReadDeadline(time.Now().Add(realtimeWSPongTimeout))
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) writePing() error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
if err := c.conn.SetWriteDeadline(time.Now().Add(realtimeWSPingWriteTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.conn.WriteMessage(ws.PingMessage, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.conn.SetWriteDeadline(time.Time{})
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) closeDone() {
|
||||
c.closeOnce.Do(func() {
|
||||
close(c.done)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) writeRealtimeError(bifrostErr *schemas.BifrostError) {
|
||||
payload := newRealtimeTurnErrorEventPayload(bifrostErr)
|
||||
_ = c.WriteMessage(ws.TextMessage, payload)
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) Close() error {
|
||||
c.closeDone()
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
const realtimeSubprotocolAPIKeyPrefix = "openai-insecure-api-key."
|
||||
|
||||
// extractRealtimeSubprotocolAPIKey extracts an API key from the Sec-WebSocket-Protocol
|
||||
// header. The OpenAI SDK sends: "realtime, openai-insecure-api-key.<key>".
|
||||
func extractRealtimeSubprotocolAPIKey(ctx *fasthttp.RequestCtx) string {
|
||||
header := string(ctx.Request.Header.Peek("Sec-WebSocket-Protocol"))
|
||||
for _, proto := range strings.Split(header, ",") {
|
||||
proto = strings.TrimSpace(proto)
|
||||
if strings.HasPrefix(proto, realtimeSubprotocolAPIKeyPrefix) {
|
||||
return strings.TrimPrefix(proto, realtimeSubprotocolAPIKeyPrefix)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func newRealtimeWireBifrostError(status int, code, message string) *schemas.BifrostError {
|
||||
errType := code
|
||||
return &schemas.BifrostError{
|
||||
StatusCode: &status,
|
||||
Type: &errType,
|
||||
Error: &schemas.ErrorField{
|
||||
Type: &errType,
|
||||
Code: &errType,
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
}
|
||||
702
transports/bifrost-http/handlers/wsresponses.go
Normal file
702
transports/bifrost-http/handlers/wsresponses.go
Normal file
@@ -0,0 +1,702 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/fasthttp/router"
|
||||
ws "github.com/fasthttp/websocket"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/integrations"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// wsWriter abstracts a WebSocket write target. Both *ws.Conn (pre-session)
|
||||
// and *bfws.Session (post-session, mutex-protected) satisfy this interface.
|
||||
type wsWriter interface {
|
||||
WriteMessage(messageType int, data []byte) error
|
||||
}
|
||||
|
||||
// WSResponsesHandler handles WebSocket connections for the Responses API WebSocket Mode.
|
||||
// Clients connect via `GET /v1/responses` with a WS upgrade and send `response.create` events.
|
||||
// Each event is routed through the standard Bifrost inference pipeline (PreLLMHook, key selection,
|
||||
// provider call, PostLLMHook) via the HTTP bridge, with native WS upstream as an optimization.
|
||||
type WSResponsesHandler struct {
|
||||
client *bifrost.Bifrost
|
||||
config *lib.Config
|
||||
handlerStore lib.HandlerStore
|
||||
pool *bfws.Pool
|
||||
sessions *bfws.SessionManager
|
||||
upgrader ws.FastHTTPUpgrader
|
||||
}
|
||||
|
||||
// NewWSResponsesHandler creates a new WebSocket Responses handler.
|
||||
func NewWSResponsesHandler(client *bifrost.Bifrost, config *lib.Config, pool *bfws.Pool) *WSResponsesHandler {
|
||||
maxConns := config.WebSocketConfig.MaxConnections
|
||||
|
||||
return &WSResponsesHandler{
|
||||
client: client,
|
||||
config: config,
|
||||
handlerStore: config,
|
||||
pool: pool,
|
||||
sessions: bfws.NewSessionManager(maxConns),
|
||||
upgrader: ws.FastHTTPUpgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(ctx *fasthttp.RequestCtx) bool {
|
||||
origin := string(ctx.Request.Header.Peek("Origin"))
|
||||
if origin == "" {
|
||||
return true
|
||||
}
|
||||
return IsOriginAllowed(origin, config.ClientConfig.AllowedOrigins)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Close gracefully shuts down all active WebSocket responses sessions.
|
||||
func (h *WSResponsesHandler) Close() {
|
||||
if h == nil || h.sessions == nil {
|
||||
return
|
||||
}
|
||||
h.sessions.CloseAll()
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the WebSocket Responses endpoint at the base path
|
||||
// and all OpenAI integration paths.
|
||||
func (h *WSResponsesHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
handler := lib.ChainMiddlewares(h.handleUpgrade, middlewares...)
|
||||
// Base path (outside integration prefix)
|
||||
r.GET("/v1/responses", handler)
|
||||
// OpenAI integration paths (/openai/v1/responses, /openai/responses, /openai/openai/responses)
|
||||
for _, path := range integrations.OpenAIWSResponsesPaths("/openai") {
|
||||
r.GET(path, handler)
|
||||
}
|
||||
}
|
||||
|
||||
// handleUpgrade upgrades the HTTP connection to WebSocket and starts the event loop.
|
||||
func (h *WSResponsesHandler) handleUpgrade(ctx *fasthttp.RequestCtx) {
|
||||
err := h.upgrader.Upgrade(ctx, func(conn *ws.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
session, sessionErr := h.sessions.Create(conn)
|
||||
if sessionErr != nil {
|
||||
writeWSError(conn, 429, "websocket_connection_limit_reached", sessionErr.Error())
|
||||
return
|
||||
}
|
||||
defer h.sessions.Remove(conn)
|
||||
|
||||
// Capture auth headers from the upgrade request for per-event context creation
|
||||
authHeaders := captureAuthHeaders(ctx)
|
||||
|
||||
h.eventLoop(conn, session, authHeaders)
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn("websocket upgrade failed for /v1/responses: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// authHeaders holds auth-related headers captured during the WS upgrade.
|
||||
type authHeaders struct {
|
||||
authorization string
|
||||
virtualKey string
|
||||
apiKey string
|
||||
googAPIKey string
|
||||
baggage string
|
||||
extraHeaders map[string]string
|
||||
}
|
||||
|
||||
// captureAuthHeaders captures the auth headers from the request.
|
||||
func captureAuthHeaders(ctx *fasthttp.RequestCtx) *authHeaders {
|
||||
ah := &authHeaders{
|
||||
authorization: string(ctx.Request.Header.Peek("Authorization")),
|
||||
virtualKey: string(ctx.Request.Header.Peek("x-bf-vk")),
|
||||
apiKey: string(ctx.Request.Header.Peek("x-api-key")),
|
||||
googAPIKey: string(ctx.Request.Header.Peek("x-goog-api-key")),
|
||||
baggage: string(ctx.Request.Header.Peek("baggage")),
|
||||
extraHeaders: make(map[string]string),
|
||||
}
|
||||
|
||||
for key, value := range ctx.Request.Header.All() {
|
||||
k := string(key)
|
||||
lk := strings.ToLower(k)
|
||||
if strings.HasPrefix(lk, "x-bf-") {
|
||||
ah.extraHeaders[k] = string(value)
|
||||
}
|
||||
}
|
||||
return ah
|
||||
}
|
||||
|
||||
// eventLoop reads events from the client WebSocket and processes them.
|
||||
func (h *WSResponsesHandler) eventLoop(conn *ws.Conn, session *bfws.Session, auth *authHeaders) {
|
||||
for {
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if ws.IsUnexpectedCloseError(err, ws.CloseGoingAway, ws.CloseNormalClosure) {
|
||||
logger.Warn("websocket read error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the event type
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := sonic.Unmarshal(message, &envelope); err != nil {
|
||||
writeWSError(session, 400, "invalid_request_error", "failed to parse event JSON")
|
||||
continue
|
||||
}
|
||||
|
||||
switch schemas.WebSocketEventType(envelope.Type) {
|
||||
case schemas.WSEventResponseCreate:
|
||||
h.handleResponseCreate(session, auth, message)
|
||||
default:
|
||||
writeWSError(session, 400, "invalid_request_error", "unsupported event type: "+envelope.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleResponseCreate processes a response.create event.
|
||||
// Strategy: try native WS upstream for providers that support it, otherwise use HTTP bridge.
|
||||
// If native WS upstream fails mid-stream, falls back to HTTP bridge.
|
||||
func (h *WSResponsesHandler) handleResponseCreate(session *bfws.Session, auth *authHeaders, message []byte) {
|
||||
var event schemas.WebSocketResponsesEvent
|
||||
|
||||
if err := sonic.Unmarshal(message, &event); err != nil {
|
||||
writeWSError(session, 400, "invalid_request_error", "failed to parse response.create event")
|
||||
return
|
||||
}
|
||||
|
||||
// Store override: default to store=true (Codex sends false by default but expects true).
|
||||
// If DisableStore is set in provider config, force store=false.
|
||||
// If client explicitly sets store, respect that value unless DisableStore overrides it.
|
||||
provider, modelName := schemas.ParseModelString(event.Model, "")
|
||||
if provider == "" || modelName == "" {
|
||||
writeWSError(session, 400, "invalid_request_error", "failed to parse model string")
|
||||
return
|
||||
}
|
||||
|
||||
if providerCfg, cfgErr := h.config.GetProviderConfigRaw(provider); cfgErr == nil &&
|
||||
providerCfg.OpenAIConfig != nil && providerCfg.OpenAIConfig.DisableStore {
|
||||
event.Store = schemas.Ptr(false)
|
||||
} else {
|
||||
event.Store = schemas.Ptr(true)
|
||||
}
|
||||
|
||||
bifrostReq, err := h.convertEventToRequest(&event)
|
||||
if err != nil {
|
||||
writeWSError(session, 400, "invalid_request_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Extract extra params (unknown fields) and forward them, matching the HTTP path behavior
|
||||
extraParams, extractErr := extractExtraParams(message, wsResponsesKnownFields)
|
||||
if extractErr == nil && len(extraParams) > 0 {
|
||||
if bifrostReq.Params == nil {
|
||||
bifrostReq.Params = &schemas.ResponsesParameters{}
|
||||
}
|
||||
bifrostReq.Params.ExtraParams = extraParams
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := createBifrostContextFromAuth(h.handlerStore, auth)
|
||||
if bifrostCtx == nil {
|
||||
writeWSError(session, 500, "server_error", "failed to create request context")
|
||||
return
|
||||
}
|
||||
|
||||
// Try native WS upstream first
|
||||
if h.tryNativeWSUpstream(session, bifrostCtx, bifrostReq, message) {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
|
||||
// Fall back to HTTP bridge
|
||||
h.executeHTTPBridge(session, bifrostCtx, cancel, bifrostReq)
|
||||
}
|
||||
|
||||
// tryNativeWSUpstream attempts to forward the event to a native WS upstream connection.
|
||||
// Returns true if the event was handled (successfully or with error sent to client).
|
||||
// Returns false if the provider doesn't support WS and we should fall back to HTTP bridge.
|
||||
func (h *WSResponsesHandler) tryNativeWSUpstream(
|
||||
session *bfws.Session,
|
||||
ctx *schemas.BifrostContext,
|
||||
req *schemas.BifrostResponsesRequest,
|
||||
rawEvent []byte,
|
||||
) bool {
|
||||
provider := h.client.GetProviderByKey(req.Provider)
|
||||
if provider == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
wsProvider, ok := provider.(schemas.WebSocketCapableProvider)
|
||||
if !ok || !wsProvider.SupportsWebSocketMode() {
|
||||
return false
|
||||
}
|
||||
|
||||
key, err := h.client.SelectKeyForProviderRequestType(ctx, schemas.WebSocketResponsesRequest, req.Provider, req.Model)
|
||||
if err != nil {
|
||||
writeWSError(session, 400, "invalid_request_error", err.Error())
|
||||
return true
|
||||
}
|
||||
|
||||
wsURL := wsProvider.WebSocketResponsesURL(key)
|
||||
upstream := session.Upstream()
|
||||
|
||||
// Validate the pinned upstream matches the current request's provider/key
|
||||
if upstream != nil && !upstream.IsClosed() &&
|
||||
(upstream.Provider() != req.Provider || upstream.KeyID() != key.ID) {
|
||||
h.pool.Discard(upstream)
|
||||
session.SetUpstream(nil)
|
||||
upstream = nil
|
||||
}
|
||||
|
||||
// If no upstream connection pinned, get one from the pool or dial
|
||||
if upstream == nil || upstream.IsClosed() {
|
||||
headers := wsProvider.WebSocketHeaders(key)
|
||||
poolKey := bfws.PoolKey{
|
||||
Provider: req.Provider,
|
||||
KeyID: key.ID,
|
||||
Endpoint: wsURL,
|
||||
}
|
||||
|
||||
upstream, err = h.pool.Get(poolKey, headers)
|
||||
if err != nil {
|
||||
logger.Warn("failed to get upstream WS connection for %s: %v, falling back to HTTP bridge", req.Provider, err)
|
||||
return false
|
||||
}
|
||||
session.SetUpstream(upstream)
|
||||
}
|
||||
|
||||
// Run plugin pre-hooks before forwarding to upstream
|
||||
bifrostReq := &schemas.BifrostRequest{
|
||||
RequestType: schemas.WebSocketResponsesRequest,
|
||||
ResponsesRequest: req,
|
||||
}
|
||||
|
||||
hooks, preErr := h.client.RunStreamPreHooks(ctx, bifrostReq)
|
||||
if preErr != nil {
|
||||
writeWSBifrostError(session, preErr)
|
||||
return true
|
||||
}
|
||||
defer hooks.Cleanup()
|
||||
|
||||
// If a plugin short-circuited with a cached response, write it and skip upstream
|
||||
if hooks.ShortCircuitResponse != nil {
|
||||
writeWSShortCircuitResponse(session, hooks.ShortCircuitResponse)
|
||||
return true
|
||||
}
|
||||
|
||||
// Forward the raw event to upstream
|
||||
if err := upstream.WriteMessage(ws.TextMessage, rawEvent); err != nil {
|
||||
logger.Warn("upstream WS write failed for %s: %v, falling back to HTTP bridge", req.Provider, err)
|
||||
h.pool.Discard(upstream)
|
||||
session.SetUpstream(nil)
|
||||
return false
|
||||
}
|
||||
|
||||
// Retrieve tracer and traceID for chunk accumulation
|
||||
tracer, _ := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer)
|
||||
traceID, _ := ctx.Value(schemas.BifrostContextKeyTraceID).(string)
|
||||
|
||||
// Read response events from upstream and relay to client, running post-hooks per chunk
|
||||
forwardedAny := false
|
||||
for {
|
||||
msgType, data, readErr := upstream.ReadMessage()
|
||||
if readErr != nil {
|
||||
logger.Warn("upstream WS read failed for %s: %v, falling back to HTTP bridge", req.Provider, readErr)
|
||||
h.pool.Discard(upstream)
|
||||
session.SetUpstream(nil)
|
||||
if !forwardedAny {
|
||||
return false
|
||||
}
|
||||
writeWSError(session, 502, "upstream_connection_error", "upstream websocket stream interrupted")
|
||||
return true
|
||||
}
|
||||
|
||||
streamResp := parseUpstreamWSEvent(data, req.Provider, req.Model)
|
||||
isTerminal := streamResp != nil && isTerminalStreamType(streamResp.Type)
|
||||
|
||||
if isTerminal {
|
||||
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
||||
}
|
||||
|
||||
if streamResp != nil {
|
||||
resp := &schemas.BifrostResponse{ResponsesStreamResponse: streamResp}
|
||||
|
||||
if tracer != nil && traceID != "" {
|
||||
tracer.AddStreamingChunk(traceID, resp)
|
||||
}
|
||||
|
||||
_, postErr := hooks.PostHookRunner(ctx, resp, nil)
|
||||
if postErr != nil {
|
||||
h.pool.Discard(upstream)
|
||||
session.SetUpstream(nil)
|
||||
writeWSBifrostError(session, postErr)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if writeErr := session.WriteMessage(msgType, data); writeErr != nil {
|
||||
h.pool.Discard(upstream)
|
||||
session.SetUpstream(nil)
|
||||
return true
|
||||
}
|
||||
forwardedAny = true
|
||||
|
||||
if isTerminal {
|
||||
h.trackResponseID(session, data)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writeWSShortCircuitResponse writes a short-circuited plugin response as WS events.
|
||||
func writeWSShortCircuitResponse(session *bfws.Session, resp *schemas.BifrostResponse) {
|
||||
if resp.ResponsesResponse != nil {
|
||||
data, err := sonic.Marshal(resp.ResponsesResponse)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err := session.WriteMessage(ws.TextMessage, data); err != nil {
|
||||
return
|
||||
}
|
||||
if resp.ResponsesResponse.ID != nil && *resp.ResponsesResponse.ID != "" {
|
||||
session.SetLastResponseID(*resp.ResponsesResponse.ID)
|
||||
}
|
||||
} else if resp.ResponsesStreamResponse != nil {
|
||||
data, err := sonic.Marshal(resp.ResponsesStreamResponse)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
session.WriteMessage(ws.TextMessage, data)
|
||||
}
|
||||
}
|
||||
|
||||
// parseUpstreamWSEvent attempts to parse a raw upstream WS event into a BifrostResponsesStreamResponse.
|
||||
// It populates ExtraFields so downstream plugins (logging, tracing) can identify the request type.
|
||||
// Returns nil if the data cannot be parsed (non-fatal, the raw bytes are still relayed).
|
||||
func parseUpstreamWSEvent(data []byte, provider schemas.ModelProvider, model string) *schemas.BifrostResponsesStreamResponse {
|
||||
var streamResp schemas.BifrostResponsesStreamResponse
|
||||
if err := sonic.Unmarshal(data, &streamResp); err != nil {
|
||||
return nil
|
||||
}
|
||||
if streamResp.Type == "" {
|
||||
return nil
|
||||
}
|
||||
streamResp.ExtraFields.RequestType = schemas.ResponsesStreamRequest
|
||||
streamResp.ExtraFields.Provider = provider
|
||||
streamResp.ExtraFields.OriginalModelRequested = model
|
||||
return &streamResp
|
||||
}
|
||||
|
||||
// isTerminalStreamType returns true if the event type signals the end of a response stream.
|
||||
func isTerminalStreamType(t schemas.ResponsesStreamResponseType) bool {
|
||||
switch t {
|
||||
case schemas.ResponsesStreamResponseTypeCompleted,
|
||||
schemas.ResponsesStreamResponseTypeFailed,
|
||||
schemas.ResponsesStreamResponseTypeIncomplete,
|
||||
schemas.ResponsesStreamResponseTypeError:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// trackResponseID extracts and stores the response ID from terminal events.
|
||||
func (h *WSResponsesHandler) trackResponseID(session *bfws.Session, data []byte) {
|
||||
var envelope struct {
|
||||
Response struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"response"`
|
||||
}
|
||||
if err := sonic.Unmarshal(data, &envelope); err == nil && envelope.Response.ID != "" {
|
||||
session.SetLastResponseID(envelope.Response.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// convertEventToRequest converts a WebSocket response.create event to a BifrostResponsesRequest.
|
||||
func (h *WSResponsesHandler) convertEventToRequest(event *schemas.WebSocketResponsesEvent) (*schemas.BifrostResponsesRequest, error) {
|
||||
provider, modelName := schemas.ParseModelString(event.Model, "")
|
||||
if provider == "" || modelName == "" {
|
||||
return nil, errModelFormat
|
||||
}
|
||||
|
||||
var input []schemas.ResponsesMessage
|
||||
if event.Input != nil {
|
||||
// Try parsing as array first
|
||||
if err := sonic.Unmarshal(event.Input, &input); err != nil {
|
||||
// Try as string
|
||||
var inputStr string
|
||||
if strErr := sonic.Unmarshal(event.Input, &inputStr); strErr != nil {
|
||||
return nil, errInputRequired
|
||||
}
|
||||
input = []schemas.ResponsesMessage{
|
||||
{
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
||||
Content: &schemas.ResponsesMessageContent{ContentStr: &inputStr},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(input) == 0 {
|
||||
return nil, errInputRequired
|
||||
}
|
||||
|
||||
params := &schemas.ResponsesParameters{}
|
||||
if event.Temperature != nil {
|
||||
params.Temperature = event.Temperature
|
||||
}
|
||||
if event.TopP != nil {
|
||||
params.TopP = event.TopP
|
||||
}
|
||||
if event.MaxOutputTokens != nil {
|
||||
params.MaxOutputTokens = event.MaxOutputTokens
|
||||
}
|
||||
if event.Instructions != "" {
|
||||
params.Instructions = &event.Instructions
|
||||
}
|
||||
if event.PreviousResponseID != "" {
|
||||
params.PreviousResponseID = &event.PreviousResponseID
|
||||
}
|
||||
if event.Store != nil {
|
||||
params.Store = event.Store
|
||||
}
|
||||
if event.Tools != nil {
|
||||
var tools []schemas.ResponsesTool
|
||||
if err := sonic.Unmarshal(event.Tools, &tools); err == nil {
|
||||
params.Tools = tools
|
||||
}
|
||||
}
|
||||
if event.ToolChoice != nil {
|
||||
var tc schemas.ResponsesToolChoice
|
||||
if err := sonic.Unmarshal(event.ToolChoice, &tc); err == nil {
|
||||
params.ToolChoice = &tc
|
||||
}
|
||||
}
|
||||
if event.Reasoning != nil {
|
||||
var reasoning schemas.ResponsesParametersReasoning
|
||||
if err := sonic.Unmarshal(event.Reasoning, &reasoning); err == nil {
|
||||
params.Reasoning = &reasoning
|
||||
}
|
||||
}
|
||||
if event.Text != nil {
|
||||
var text schemas.ResponsesTextConfig
|
||||
if err := sonic.Unmarshal(event.Text, &text); err == nil {
|
||||
params.Text = &text
|
||||
}
|
||||
}
|
||||
if event.Metadata != nil {
|
||||
var metadata map[string]any
|
||||
if err := sonic.Unmarshal(event.Metadata, &metadata); err == nil {
|
||||
params.Metadata = &metadata
|
||||
}
|
||||
}
|
||||
if event.Truncation != "" {
|
||||
params.Truncation = &event.Truncation
|
||||
}
|
||||
|
||||
return &schemas.BifrostResponsesRequest{
|
||||
Provider: schemas.ModelProvider(provider),
|
||||
Model: modelName,
|
||||
Input: input,
|
||||
Params: params,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createBifrostContextFromAuth builds a BifrostContext from the auth headers captured during upgrade.
|
||||
func createBifrostContextFromAuth(handlerStore lib.HandlerStore, auth *authHeaders) (*schemas.BifrostContext, context.CancelFunc) {
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(context.Background())
|
||||
|
||||
if sessionID := lib.ParseSessionIDFromBaggage(auth.baggage); sessionID != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyParentRequestID, sessionID)
|
||||
}
|
||||
|
||||
if auth.virtualKey != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyVirtualKey, auth.virtualKey)
|
||||
}
|
||||
|
||||
// Handle Bearer token with sk-bf- prefix (virtual key via Authorization header)
|
||||
if auth.authorization != "" {
|
||||
if strings.HasPrefix(auth.authorization, "Bearer ") {
|
||||
token := strings.TrimPrefix(auth.authorization, "Bearer ")
|
||||
if strings.HasPrefix(token, "sk-bf-") {
|
||||
ctx.SetValue(schemas.BifrostContextKeyVirtualKey, strings.TrimPrefix(token, "sk-bf-"))
|
||||
} else if handlerStore.ShouldAllowDirectKeys() {
|
||||
key := schemas.Key{
|
||||
ID: "header-provided",
|
||||
Value: *schemas.NewEnvVar(token),
|
||||
Models: schemas.WhiteList{"*"},
|
||||
Weight: 1.0,
|
||||
}
|
||||
ctx.SetValue(schemas.BifrostContextKeyDirectKey, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
if auth.apiKey != "" {
|
||||
if strings.HasPrefix(auth.apiKey, "sk-bf-") {
|
||||
ctx.SetValue(schemas.BifrostContextKeyVirtualKey, strings.TrimPrefix(auth.apiKey, "sk-bf-"))
|
||||
} else if handlerStore.ShouldAllowDirectKeys() {
|
||||
key := schemas.Key{
|
||||
ID: "header-provided",
|
||||
Value: *schemas.NewEnvVar(auth.apiKey),
|
||||
Models: schemas.WhiteList{"*"},
|
||||
Weight: 1.0,
|
||||
}
|
||||
ctx.SetValue(schemas.BifrostContextKeyDirectKey, key)
|
||||
}
|
||||
}
|
||||
if auth.googAPIKey != "" {
|
||||
if strings.HasPrefix(auth.googAPIKey, "sk-bf-") {
|
||||
ctx.SetValue(schemas.BifrostContextKeyVirtualKey, strings.TrimPrefix(auth.googAPIKey, "sk-bf-"))
|
||||
} else if handlerStore.ShouldAllowDirectKeys() {
|
||||
key := schemas.Key{
|
||||
ID: "header-provided",
|
||||
Value: *schemas.NewEnvVar(auth.googAPIKey),
|
||||
Models: schemas.WhiteList{"*"},
|
||||
Weight: 1.0,
|
||||
}
|
||||
ctx.SetValue(schemas.BifrostContextKeyDirectKey, key)
|
||||
}
|
||||
}
|
||||
|
||||
// Forward x-bf-* headers
|
||||
for k, v := range auth.extraHeaders {
|
||||
lk := strings.ToLower(k)
|
||||
switch {
|
||||
case lk == "x-bf-vk":
|
||||
// Already handled above
|
||||
case lk == "x-bf-api-key":
|
||||
ctx.SetValue(schemas.BifrostContextKeyAPIKeyName, v)
|
||||
case strings.HasPrefix(lk, "x-bf-eh-"):
|
||||
suffix := strings.TrimPrefix(lk, "x-bf-eh-")
|
||||
existing, _ := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string]string)
|
||||
if existing == nil {
|
||||
existing = make(map[string]string)
|
||||
}
|
||||
existing[suffix] = v
|
||||
ctx.SetValue(schemas.BifrostContextKeyExtraHeaders, existing)
|
||||
}
|
||||
}
|
||||
|
||||
return ctx, cancel
|
||||
}
|
||||
|
||||
// executeHTTPBridge runs the response through the existing streaming inference pipeline.
|
||||
func (h *WSResponsesHandler) executeHTTPBridge(
|
||||
session *bfws.Session,
|
||||
ctx *schemas.BifrostContext,
|
||||
cancel context.CancelFunc,
|
||||
req *schemas.BifrostResponsesRequest,
|
||||
) {
|
||||
defer cancel()
|
||||
|
||||
stream, bifrostErr := h.client.ResponsesStreamRequest(ctx, req)
|
||||
if bifrostErr != nil {
|
||||
writeWSBifrostError(session, bifrostErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Relay streaming chunks as WS messages
|
||||
for chunk := range stream {
|
||||
if chunk == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
chunkJSON, err := sonic.Marshal(chunk)
|
||||
if err != nil {
|
||||
logger.Warn("failed to marshal stream chunk: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if writeErr := session.WriteMessage(ws.TextMessage, chunkJSON); writeErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Track last response ID for session chaining
|
||||
if chunk.BifrostResponsesStreamResponse != nil &&
|
||||
chunk.BifrostResponsesStreamResponse.Response != nil &&
|
||||
chunk.BifrostResponsesStreamResponse.Response.ID != nil &&
|
||||
*chunk.BifrostResponsesStreamResponse.Response.ID != "" {
|
||||
session.SetLastResponseID(*chunk.BifrostResponsesStreamResponse.Response.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writeWSError sends a JSON error event to a WebSocket write target.
|
||||
// Accepts either a raw *ws.Conn (pre-session) or a *bfws.Session (mutex-protected).
|
||||
func writeWSError(w wsWriter, status int, code, message string) {
|
||||
event := schemas.WebSocketErrorEvent{
|
||||
Type: schemas.WSEventError,
|
||||
Status: status,
|
||||
Error: &schemas.WebSocketErrorBody{
|
||||
Code: code,
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
data, err := sonic.Marshal(event)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.WriteMessage(ws.TextMessage, data)
|
||||
}
|
||||
|
||||
// writeWSBifrostError converts a BifrostError to a WS error event.
|
||||
func writeWSBifrostError(w wsWriter, bifrostErr *schemas.BifrostError) {
|
||||
status := 500
|
||||
if bifrostErr.StatusCode != nil && *bifrostErr.StatusCode > 0 {
|
||||
status = *bifrostErr.StatusCode
|
||||
}
|
||||
code := "server_error"
|
||||
msg := "internal server error"
|
||||
if bifrostErr.Error != nil {
|
||||
if bifrostErr.Error.Code != nil && *bifrostErr.Error.Code != "" {
|
||||
code = *bifrostErr.Error.Code
|
||||
} else if bifrostErr.Error.Type != nil && *bifrostErr.Error.Type != "" {
|
||||
code = *bifrostErr.Error.Type
|
||||
}
|
||||
if bifrostErr.Error.Message != "" {
|
||||
msg = bifrostErr.Error.Message
|
||||
}
|
||||
}
|
||||
writeWSError(w, status, code, msg)
|
||||
}
|
||||
|
||||
// wsResponsesKnownFields lists the fields explicitly handled by WebSocketResponsesEvent.
|
||||
// Anything not in this set is treated as an extra param and forwarded as-is to the provider.
|
||||
var wsResponsesKnownFields = map[string]bool{
|
||||
"type": true,
|
||||
"model": true,
|
||||
"store": true,
|
||||
"input": true,
|
||||
"instructions": true,
|
||||
"previous_response_id": true,
|
||||
"tools": true,
|
||||
"tool_choice": true,
|
||||
"temperature": true,
|
||||
"top_p": true,
|
||||
"max_output_tokens": true,
|
||||
"reasoning": true,
|
||||
"metadata": true,
|
||||
"text": true,
|
||||
"truncation": true,
|
||||
}
|
||||
|
||||
var (
|
||||
errModelFormat = errorf("model should be in provider/model format")
|
||||
errInputRequired = errorf("input is required for responses")
|
||||
)
|
||||
|
||||
func errorf(msg string) error {
|
||||
return &simpleError{msg: msg}
|
||||
}
|
||||
|
||||
type simpleError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *simpleError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
68
transports/bifrost-http/handlers/wsresponses_test.go
Normal file
68
transports/bifrost-http/handlers/wsresponses_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/kvstore"
|
||||
"github.com/maximhq/bifrost/framework/logstore"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
)
|
||||
|
||||
type testWSHandlerStore struct {
|
||||
allowDirectKeys bool
|
||||
}
|
||||
|
||||
func (s testWSHandlerStore) ShouldAllowDirectKeys() bool {
|
||||
return s.allowDirectKeys
|
||||
}
|
||||
|
||||
func (s testWSHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testWSHandlerStore) GetAvailableProviders() []schemas.ModelProvider {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testWSHandlerStore) GetStreamChunkInterceptor() lib.StreamChunkInterceptor {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testWSHandlerStore) GetAsyncJobExecutor() *logstore.AsyncJobExecutor {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testWSHandlerStore) GetAsyncJobResultTTL() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s testWSHandlerStore) GetKVStore() *kvstore.Store {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testWSHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestCreateBifrostContextFromAuth_BaggageSessionIDSetsGrouping(t *testing.T) {
|
||||
ctx, cancel := createBifrostContextFromAuth(testWSHandlerStore{}, &authHeaders{
|
||||
baggage: "foo=bar, session-id=rt-ws-123, baz=qux",
|
||||
})
|
||||
defer cancel()
|
||||
|
||||
if got, _ := ctx.Value(schemas.BifrostContextKeyParentRequestID).(string); got != "rt-ws-123" {
|
||||
t.Fatalf("parent request id = %q, want %q", got, "rt-ws-123")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateBifrostContextFromAuth_EmptyBaggageSessionIDIgnored(t *testing.T) {
|
||||
ctx, cancel := createBifrostContextFromAuth(testWSHandlerStore{}, &authHeaders{
|
||||
baggage: "session-id= ",
|
||||
})
|
||||
defer cancel()
|
||||
|
||||
if got := ctx.Value(schemas.BifrostContextKeyParentRequestID); got != nil {
|
||||
t.Fatalf("parent request id should be unset, got %#v", got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user