first commit
This commit is contained in:
414
transports/bifrost-http/handlers/realtime_client_secrets_test.go
Normal file
414
transports/bifrost-http/handlers/realtime_client_secrets_test.go
Normal file
@@ -0,0 +1,414 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/kvstore"
|
||||
"github.com/maximhq/bifrost/plugins/governance"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func TestResolveRealtimeClientSecretTarget(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
route schemas.RealtimeSessionRoute
|
||||
body []byte
|
||||
wantProvider schemas.ModelProvider
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "base route with session model",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets},
|
||||
body: []byte(`{"session":{"model":"openai/gpt-4o-realtime-preview"}}`),
|
||||
wantProvider: schemas.OpenAI,
|
||||
wantModel: "gpt-4o-realtime-preview",
|
||||
},
|
||||
{
|
||||
name: "base route with top level model",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/sessions", EndpointType: schemas.RealtimeSessionEndpointSessions},
|
||||
body: []byte(`{"model":"openai/gpt-4o-realtime-preview"}`),
|
||||
wantProvider: schemas.OpenAI,
|
||||
wantModel: "gpt-4o-realtime-preview",
|
||||
},
|
||||
{
|
||||
name: "openai alias uses bare model",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/openai/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets, DefaultProvider: schemas.OpenAI},
|
||||
body: []byte(`{"session":{"model":"gpt-4o-realtime-preview"}}`),
|
||||
wantProvider: schemas.OpenAI,
|
||||
wantModel: "gpt-4o-realtime-preview",
|
||||
},
|
||||
{
|
||||
name: "base route rejects bare model",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets},
|
||||
body: []byte(`{"session":{"model":"gpt-4o-realtime-preview"}}`),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing model",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/openai/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets, DefaultProvider: schemas.OpenAI},
|
||||
body: []byte(`{"session":{}}`),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gotProvider, gotModel, _, err := resolveRealtimeClientSecretTarget(tt.route, tt.body)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("resolveRealtimeClientSecretTarget() error = %v", err)
|
||||
}
|
||||
if gotProvider != tt.wantProvider {
|
||||
t.Fatalf("provider = %q, want %q", gotProvider, tt.wantProvider)
|
||||
}
|
||||
if gotModel != tt.wantModel {
|
||||
t.Fatalf("model = %q, want %q", gotModel, tt.wantModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRealtimeClientSecretTarget_NormalizesModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
route schemas.RealtimeSessionRoute
|
||||
body string
|
||||
wantModel string // bare model expected in normalized body
|
||||
}{
|
||||
{
|
||||
name: "session.model provider prefix stripped",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets},
|
||||
body: `{"session":{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}}`,
|
||||
wantModel: "gpt-4o-realtime-preview",
|
||||
},
|
||||
{
|
||||
name: "top-level model provider prefix stripped",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/sessions", EndpointType: schemas.RealtimeSessionEndpointSessions},
|
||||
body: `{"model":"openai/gpt-4o-realtime-preview"}`,
|
||||
wantModel: "gpt-4o-realtime-preview",
|
||||
},
|
||||
{
|
||||
name: "bare model unchanged on alias route",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/openai/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets, DefaultProvider: schemas.OpenAI},
|
||||
body: `{"session":{"model":"gpt-4o-realtime-preview"}}`,
|
||||
wantModel: "gpt-4o-realtime-preview",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, normalizedBody, err := resolveRealtimeClientSecretTarget(tt.route, []byte(tt.body))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var root map[string]json.RawMessage
|
||||
if unmarshalErr := json.Unmarshal(normalizedBody, &root); unmarshalErr != nil {
|
||||
t.Fatalf("failed to unmarshal normalized body: %v", unmarshalErr)
|
||||
}
|
||||
|
||||
// Check session.model if present
|
||||
if sessionJSON, ok := root["session"]; ok {
|
||||
var session map[string]json.RawMessage
|
||||
if unmarshalErr := json.Unmarshal(sessionJSON, &session); unmarshalErr != nil {
|
||||
t.Fatalf("failed to unmarshal session: %v", unmarshalErr)
|
||||
}
|
||||
if modelJSON, ok := session["model"]; ok {
|
||||
var model string
|
||||
if unmarshalErr := json.Unmarshal(modelJSON, &model); unmarshalErr != nil {
|
||||
t.Fatalf("failed to unmarshal session.model: %v", unmarshalErr)
|
||||
}
|
||||
if model != tt.wantModel {
|
||||
t.Fatalf("session.model = %q, want %q", model, tt.wantModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check top-level model if present
|
||||
if modelJSON, ok := root["model"]; ok {
|
||||
var model string
|
||||
if unmarshalErr := json.Unmarshal(modelJSON, &model); unmarshalErr != nil {
|
||||
t.Fatalf("failed to unmarshal model: %v", unmarshalErr)
|
||||
}
|
||||
if model != tt.wantModel {
|
||||
t.Fatalf("model = %q, want %q", model, tt.wantModel)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRealtimeEphemeralKeyMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
token, ttl, ok := parseRealtimeEphemeralKeyMapping([]byte(`{
|
||||
"value": "ek_test_123",
|
||||
"expires_at": 4102444800
|
||||
}`))
|
||||
if !ok {
|
||||
t.Fatal("expected ephemeral mapping to be parsed")
|
||||
}
|
||||
if token != "ek_test_123" {
|
||||
t.Fatalf("token = %q, want %q", token, "ek_test_123")
|
||||
}
|
||||
if ttl <= 0 {
|
||||
t.Fatalf("ttl = %v, want > 0", ttl)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRealtimeEphemeralKeyMapping_NestedFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
token, ttl, ok := parseRealtimeEphemeralKeyMapping([]byte(`{
|
||||
"client_secret": {
|
||||
"value": "ek_test_nested",
|
||||
"expires_at": 4102444800
|
||||
}
|
||||
}`))
|
||||
if !ok {
|
||||
t.Fatal("expected nested ephemeral mapping to be parsed")
|
||||
}
|
||||
if token != "ek_test_nested" {
|
||||
t.Fatalf("token = %q, want %q", token, "ek_test_nested")
|
||||
}
|
||||
if ttl <= 0 {
|
||||
t.Fatalf("ttl = %v, want > 0", ttl)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheRealtimeEphemeralKeyMappingStoresKeyID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := kvstore.New(kvstore.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("kvstore.New() error = %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
body := []byte(`{
|
||||
"value": "ek_test_456",
|
||||
"expires_at": ` + "4102444800" + `
|
||||
}`)
|
||||
cacheRealtimeEphemeralKeyMapping(store, body, "key_123", "sk-bf-test")
|
||||
|
||||
raw, err := store.Get(buildRealtimeEphemeralKeyMappingKey("ek_test_456"))
|
||||
if err != nil {
|
||||
t.Fatalf("store.Get() error = %v", err)
|
||||
}
|
||||
value, ok := raw.([]byte)
|
||||
if !ok {
|
||||
t.Fatalf("cached value type = %T, want []byte", raw)
|
||||
}
|
||||
var mapping realtimeEphemeralKeyMapping
|
||||
if err := json.Unmarshal(value, &mapping); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
if mapping.KeyID != "key_123" {
|
||||
t.Fatalf("mapping.KeyID = %q, want %q", mapping.KeyID, "key_123")
|
||||
}
|
||||
if mapping.VirtualKey != "sk-bf-test" {
|
||||
t.Fatalf("mapping.VirtualKey = %q, want %q", mapping.VirtualKey, "sk-bf-test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheRealtimeEphemeralKeyMappingSkipsExpiredSecrets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := kvstore.New(kvstore.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("kvstore.New() error = %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
expired := time.Now().Add(-time.Minute).Unix()
|
||||
body := fmt.Appendf(nil, `{
|
||||
"value": "ek_expired",
|
||||
"expires_at": %d
|
||||
}`, expired)
|
||||
cacheRealtimeEphemeralKeyMapping(store, body, "key_123", "")
|
||||
|
||||
if _, err := store.Get(buildRealtimeEphemeralKeyMappingKey("ek_expired")); err == nil {
|
||||
t.Fatal("expected no cached mapping for expired token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsJSONContentType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if !isJSONContentType("application/json; charset=utf-8") {
|
||||
t.Fatal("expected application/json content type to pass")
|
||||
}
|
||||
if !isJSONContentType("application/vnd.openai+json") {
|
||||
t.Fatal("expected +json content type to pass")
|
||||
}
|
||||
if isJSONContentType("text/plain") {
|
||||
t.Fatal("expected text/plain content type to fail")
|
||||
}
|
||||
}
|
||||
|
||||
type mockRealtimeMintingGovernancePlugin struct {
|
||||
err *schemas.BifrostError
|
||||
seenUserID string
|
||||
seenVirtualKey string
|
||||
seenProvider schemas.ModelProvider
|
||||
seenModel string
|
||||
evaluateCalls int
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) GetName() string {
|
||||
return governance.PluginName
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) EvaluateGovernanceRequest(ctx *schemas.BifrostContext, evaluationRequest *governance.EvaluationRequest, _ schemas.RequestType) (*governance.EvaluationResult, *schemas.BifrostError) {
|
||||
m.evaluateCalls++
|
||||
m.seenUserID = ""
|
||||
m.seenVirtualKey = ""
|
||||
m.seenProvider = ""
|
||||
m.seenModel = ""
|
||||
if evaluationRequest != nil {
|
||||
m.seenUserID = evaluationRequest.UserID
|
||||
m.seenVirtualKey = evaluationRequest.VirtualKey
|
||||
m.seenProvider = evaluationRequest.Provider
|
||||
m.seenModel = evaluationRequest.Model
|
||||
}
|
||||
if ctx != nil && m.seenVirtualKey == "" {
|
||||
m.seenVirtualKey = bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey)
|
||||
}
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return &governance.EvaluationResult{Decision: governance.DecisionAllow}, nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) HTTPTransportPreHook(_ *schemas.BifrostContext, _ *schemas.HTTPRequest) (*schemas.HTTPResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) HTTPTransportPostHook(_ *schemas.BifrostContext, _ *schemas.HTTPRequest, _ *schemas.HTTPResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) PreLLMHook(_ *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) {
|
||||
return req, nil, nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) PostLLMHook(_ *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) {
|
||||
return result, bifrostErr, nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) PreMCPHook(_ *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, error) {
|
||||
return req, nil, nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) PostMCPHook(_ *schemas.BifrostContext, resp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostMCPResponse, *schemas.BifrostError, error) {
|
||||
return resp, bifrostErr, nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) Cleanup() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) GetGovernanceStore() governance.GovernanceStore {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRealtimeClientSecretsEvaluateMintingGovernance_RequiresAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &lib.Config{}
|
||||
plugin := &mockRealtimeMintingGovernancePlugin{
|
||||
err: &schemas.BifrostError{
|
||||
Type: schemas.Ptr("virtual_key_required"),
|
||||
StatusCode: schemas.Ptr(401),
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "virtual key is required. Provide a virtual key via the x-bf-vk header.",
|
||||
},
|
||||
},
|
||||
}
|
||||
plugins := []schemas.BasePlugin{plugin}
|
||||
config.BasePlugins.Store(&plugins)
|
||||
|
||||
handler := NewRealtimeClientSecretsHandler(nil, config)
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
defer bifrostCtx.Done()
|
||||
|
||||
err := handler.evaluateMintingGovernance(bifrostCtx, schemas.OpenAI, "gpt-realtime")
|
||||
if err == nil {
|
||||
t.Fatal("expected governance error")
|
||||
}
|
||||
if err.StatusCode == nil {
|
||||
t.Fatal("expected status code")
|
||||
}
|
||||
if got, want := *err.StatusCode, fasthttp.StatusUnauthorized; got != want {
|
||||
t.Fatalf("status = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealtimeClientSecretsEvaluateMintingGovernance_PassesContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &lib.Config{}
|
||||
plugin := &mockRealtimeMintingGovernancePlugin{}
|
||||
plugins := []schemas.BasePlugin{
|
||||
plugin,
|
||||
}
|
||||
config.BasePlugins.Store(&plugins)
|
||||
|
||||
handler := NewRealtimeClientSecretsHandler(nil, config)
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
defer bifrostCtx.Done()
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyUserID, "user_123")
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, "sk-bf-123")
|
||||
|
||||
if err := handler.evaluateMintingGovernance(bifrostCtx, schemas.OpenAI, "gpt-realtime"); err != nil {
|
||||
t.Fatalf("unexpected governance error: %v", err)
|
||||
}
|
||||
if plugin.evaluateCalls != 1 {
|
||||
t.Fatalf("evaluate calls = %d, want 1", plugin.evaluateCalls)
|
||||
}
|
||||
if plugin.seenUserID != "user_123" {
|
||||
t.Fatalf("governance user id = %q, want %q", plugin.seenUserID, "user_123")
|
||||
}
|
||||
if plugin.seenVirtualKey != "sk-bf-123" {
|
||||
t.Fatalf("virtual key = %q, want %q", plugin.seenVirtualKey, "sk-bf-123")
|
||||
}
|
||||
if plugin.seenProvider != schemas.OpenAI {
|
||||
t.Fatalf("provider = %q, want %q", plugin.seenProvider, schemas.OpenAI)
|
||||
}
|
||||
if plugin.seenModel != "gpt-realtime" {
|
||||
t.Fatalf("model = %q, want %q", plugin.seenModel, "gpt-realtime")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealtimeClientSecretsEvaluateMintingGovernance_ContinuesWithoutGovernance(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewRealtimeClientSecretsHandler(nil, &lib.Config{})
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
defer bifrostCtx.Done()
|
||||
|
||||
if err := handler.evaluateMintingGovernance(bifrostCtx, schemas.OpenAI, "gpt-realtime"); err != nil {
|
||||
t.Fatalf("unexpected governance error without plugin: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user