first commit
This commit is contained in:
337
transports/bifrost-http/handlers/governance_test.go
Normal file
337
transports/bifrost-http/handlers/governance_test.go
Normal file
@@ -0,0 +1,337 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/plugins/governance"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// mockGovernanceManagerForVK embeds the interface so unimplemented methods panic.
|
||||
// Only GetGovernanceData is needed for the getVirtualKeys handler path.
|
||||
type mockGovernanceManagerForVK struct {
|
||||
GovernanceManager
|
||||
}
|
||||
|
||||
func (m *mockGovernanceManagerForVK) GetGovernanceData(ctx context.Context) *governance.GovernanceData {
|
||||
return nil
|
||||
}
|
||||
|
||||
// mockConfigStoreForVK embeds the interface so unimplemented methods panic.
|
||||
// Only GetVirtualKeysPaginated is called in the non-from_memory path.
|
||||
type mockConfigStoreForVK struct {
|
||||
configstore.ConfigStore
|
||||
}
|
||||
|
||||
func (m *mockConfigStoreForVK) GetVirtualKeysPaginated(_ context.Context, _ configstore.VirtualKeyQueryParams) ([]configstoreTables.TableVirtualKey, int64, error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (m *mockConfigStoreForVK) GetVirtualKeys(_ context.Context) ([]configstoreTables.TableVirtualKey, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// TestGetVirtualKeys_PaginatedEndpoint_ResponseShape verifies the JSON response
|
||||
// from the paginated virtual keys endpoint contains all expected fields.
|
||||
func TestGetVirtualKeys_PaginatedEndpoint_ResponseShape(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := &GovernanceHandler{
|
||||
configStore: &mockConfigStoreForVK{},
|
||||
governanceManager: &mockGovernanceManagerForVK{},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/governance/virtual-keys?limit=10&offset=0")
|
||||
|
||||
h.getVirtualKeys(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != 200 {
|
||||
t.Fatalf("expected status 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to parse JSON response: %v", err)
|
||||
}
|
||||
|
||||
// Assert expected fields exist with correct types
|
||||
requiredFields := []struct {
|
||||
key string
|
||||
wantType string
|
||||
}{
|
||||
{"virtual_keys", "array"},
|
||||
{"total_count", "number"},
|
||||
{"count", "number"},
|
||||
{"limit", "number"},
|
||||
{"offset", "number"},
|
||||
}
|
||||
|
||||
for _, f := range requiredFields {
|
||||
val, ok := resp[f.key]
|
||||
if !ok {
|
||||
t.Errorf("response missing required field %q", f.key)
|
||||
continue
|
||||
}
|
||||
switch f.wantType {
|
||||
case "array":
|
||||
if _, ok := val.([]interface{}); !ok {
|
||||
// nil decodes as nil, which is fine — JSON null for empty array
|
||||
if val != nil {
|
||||
t.Errorf("field %q: expected array, got %T", f.key, val)
|
||||
}
|
||||
}
|
||||
case "number":
|
||||
if _, ok := val.(float64); !ok {
|
||||
t.Errorf("field %q: expected number, got %T", f.key, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify no unexpected extra top-level fields
|
||||
allowedKeys := map[string]bool{
|
||||
"virtual_keys": true,
|
||||
"total_count": true,
|
||||
"count": true,
|
||||
"limit": true,
|
||||
"offset": true,
|
||||
}
|
||||
for key := range resp {
|
||||
if !allowedKeys[key] {
|
||||
t.Errorf("unexpected field %q in response", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetVirtualKeys_PaginatedEndpoint_QueryParams verifies query parameters are
|
||||
// parsed and reflected in the response.
|
||||
func TestGetVirtualKeys_PaginatedEndpoint_QueryParams(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := &GovernanceHandler{
|
||||
configStore: &mockConfigStoreForVK{},
|
||||
governanceManager: &mockGovernanceManagerForVK{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
uri string
|
||||
wantLimit float64
|
||||
wantOffset float64
|
||||
}{
|
||||
{
|
||||
name: "explicit limit and offset",
|
||||
uri: "/api/governance/virtual-keys?limit=10&offset=5",
|
||||
wantLimit: 10,
|
||||
wantOffset: 5,
|
||||
},
|
||||
{
|
||||
name: "no params uses defaults",
|
||||
uri: "/api/governance/virtual-keys",
|
||||
wantLimit: 0,
|
||||
wantOffset: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI(tt.uri)
|
||||
|
||||
h.getVirtualKeys(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != 200 {
|
||||
t.Fatalf("expected status 200, got %d", ctx.Response.StatusCode())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to parse JSON: %v", err)
|
||||
}
|
||||
|
||||
if got := resp["limit"].(float64); got != tt.wantLimit {
|
||||
t.Errorf("limit: got %v, want %v", got, tt.wantLimit)
|
||||
}
|
||||
if got := resp["offset"].(float64); got != tt.wantOffset {
|
||||
t.Errorf("offset: got %v, want %v", got, tt.wantOffset)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure mockLogger satisfies schemas.Logger (already defined in middlewares_test.go
|
||||
// but we reference it here — same package, so no redeclaration needed).
|
||||
var _ schemas.Logger = (*mockLogger)(nil)
|
||||
|
||||
func TestBudgetRemovalRequestDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *UpdateBudgetRequest
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil request is not removal",
|
||||
req: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty object is removal",
|
||||
req: &UpdateBudgetRequest{},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "max limit present is not removal",
|
||||
req: &UpdateBudgetRequest{MaxLimit: bifrostFloat(10)},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "reset duration only is not removal",
|
||||
req: &UpdateBudgetRequest{ResetDuration: bifrostString("1h")},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "calendar aligned only is treated as removal",
|
||||
req: &UpdateBudgetRequest{CalendarAligned: bifrostBool(true)},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isBudgetRemovalRequest(tt.req); got != tt.want {
|
||||
t.Fatalf("isBudgetRemovalRequest() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitRemovalRequestDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *UpdateRateLimitRequest
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil request is not removal",
|
||||
req: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty object is removal",
|
||||
req: &UpdateRateLimitRequest{},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "token limit present is not removal",
|
||||
req: &UpdateRateLimitRequest{TokenMaxLimit: bifrostInt64(100)},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "request limit present is not removal",
|
||||
req: &UpdateRateLimitRequest{RequestMaxLimit: bifrostInt64(10)},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "durations only is not removal",
|
||||
req: &UpdateRateLimitRequest{
|
||||
TokenResetDuration: bifrostString("1h"),
|
||||
RequestResetDuration: bifrostString("1h"),
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isRateLimitRemovalRequest(tt.req); got != tt.want {
|
||||
t.Fatalf("isRateLimitRemovalRequest() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectProviderConfigDeleteIDs(t *testing.T) {
|
||||
budgetID := "budget-1"
|
||||
rateLimitID := "rate-limit-1"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config configstoreTables.TableVirtualKeyProviderConfig
|
||||
initialBudgetIDs []string
|
||||
initialRateIDs []string
|
||||
wantBudgetIDs []string
|
||||
wantRateIDs []string
|
||||
}{
|
||||
{
|
||||
name: "collects both IDs",
|
||||
config: configstoreTables.TableVirtualKeyProviderConfig{
|
||||
Budgets: []configstoreTables.TableBudget{{ID: budgetID}},
|
||||
RateLimitID: &rateLimitID,
|
||||
},
|
||||
wantBudgetIDs: []string{budgetID},
|
||||
wantRateIDs: []string{rateLimitID},
|
||||
},
|
||||
{
|
||||
name: "appends to existing slices",
|
||||
config: configstoreTables.TableVirtualKeyProviderConfig{
|
||||
Budgets: []configstoreTables.TableBudget{{ID: budgetID}},
|
||||
RateLimitID: &rateLimitID,
|
||||
},
|
||||
initialBudgetIDs: []string{"budget-0"},
|
||||
initialRateIDs: []string{"rate-limit-0"},
|
||||
wantBudgetIDs: []string{"budget-0", budgetID},
|
||||
wantRateIDs: []string{"rate-limit-0", rateLimitID},
|
||||
},
|
||||
{
|
||||
name: "ignores missing IDs",
|
||||
config: configstoreTables.TableVirtualKeyProviderConfig{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotBudgetIDs, gotRateIDs := collectProviderConfigDeleteIDs(tt.config, tt.initialBudgetIDs, tt.initialRateIDs)
|
||||
|
||||
if len(gotBudgetIDs) != len(tt.wantBudgetIDs) {
|
||||
t.Fatalf("budget IDs length = %d, want %d", len(gotBudgetIDs), len(tt.wantBudgetIDs))
|
||||
}
|
||||
for i := range gotBudgetIDs {
|
||||
if gotBudgetIDs[i] != tt.wantBudgetIDs[i] {
|
||||
t.Fatalf("budget IDs[%d] = %q, want %q", i, gotBudgetIDs[i], tt.wantBudgetIDs[i])
|
||||
}
|
||||
}
|
||||
|
||||
if len(gotRateIDs) != len(tt.wantRateIDs) {
|
||||
t.Fatalf("rate limit IDs length = %d, want %d", len(gotRateIDs), len(tt.wantRateIDs))
|
||||
}
|
||||
for i := range gotRateIDs {
|
||||
if gotRateIDs[i] != tt.wantRateIDs[i] {
|
||||
t.Fatalf("rate limit IDs[%d] = %q, want %q", i, gotRateIDs[i], tt.wantRateIDs[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func bifrostFloat(v float64) *float64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
func bifrostInt64(v int64) *int64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
func bifrostString(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
func bifrostBool(v bool) *bool {
|
||||
return &v
|
||||
}
|
||||
Reference in New Issue
Block a user