first commit
This commit is contained in:
551
tests/governance/test_utils.go
Normal file
551
tests/governance/test_utils.go
Normal file
@@ -0,0 +1,551 @@
|
||||
package governance
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ModelCost defines the cost structure for a model
|
||||
type ModelCost struct {
|
||||
Provider string
|
||||
InputCostPerToken float64
|
||||
OutputCostPerToken float64
|
||||
MaxInputTokens int
|
||||
MaxOutputTokens int
|
||||
}
|
||||
|
||||
// TestModels defines all models used for testing
|
||||
var TestModels = map[string]ModelCost{
|
||||
"openai/gpt-4o": {
|
||||
Provider: "openai",
|
||||
InputCostPerToken: 0.0000025,
|
||||
OutputCostPerToken: 0.00001,
|
||||
MaxInputTokens: 128000,
|
||||
MaxOutputTokens: 16384,
|
||||
},
|
||||
"anthropic/claude-3-7-sonnet-20250219": {
|
||||
Provider: "anthropic",
|
||||
InputCostPerToken: 0.000003,
|
||||
OutputCostPerToken: 0.000015,
|
||||
MaxInputTokens: 200000,
|
||||
MaxOutputTokens: 128000,
|
||||
},
|
||||
"anthropic/claude-4-opus-20250514": {
|
||||
Provider: "anthropic",
|
||||
InputCostPerToken: 0.000015,
|
||||
OutputCostPerToken: 0.000075,
|
||||
MaxInputTokens: 200000,
|
||||
MaxOutputTokens: 32000,
|
||||
},
|
||||
"openrouter/anthropic/claude-3.7-sonnet": {
|
||||
Provider: "openrouter",
|
||||
InputCostPerToken: 0.000003,
|
||||
OutputCostPerToken: 0.000015,
|
||||
MaxInputTokens: 200000,
|
||||
MaxOutputTokens: 128000,
|
||||
},
|
||||
"openrouter/openai/gpt-4o": {
|
||||
Provider: "openrouter",
|
||||
InputCostPerToken: 0.0000025,
|
||||
OutputCostPerToken: 0.00001,
|
||||
MaxInputTokens: 128000,
|
||||
MaxOutputTokens: 4096,
|
||||
},
|
||||
}
|
||||
|
||||
// CalculateCost calculates the cost based on input and output tokens
|
||||
func CalculateCost(model string, inputTokens, outputTokens int) (float64, error) {
|
||||
modelInfo, ok := TestModels[model]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("unknown model: %s", model)
|
||||
}
|
||||
|
||||
inputCost := float64(inputTokens) * modelInfo.InputCostPerToken
|
||||
outputCost := float64(outputTokens) * modelInfo.OutputCostPerToken
|
||||
return inputCost + outputCost, nil
|
||||
}
|
||||
|
||||
// APIRequest represents a request to the Bifrost API
|
||||
type APIRequest struct {
|
||||
Method string
|
||||
Path string
|
||||
Body interface{}
|
||||
VKHeader *string
|
||||
}
|
||||
|
||||
// APIResponse represents a response from the Bifrost API
|
||||
type APIResponse struct {
|
||||
StatusCode int
|
||||
Body map[string]interface{}
|
||||
RawBody []byte
|
||||
}
|
||||
|
||||
// MakeRequest makes an HTTP request to the Bifrost API
|
||||
func MakeRequest(t *testing.T, req APIRequest) *APIResponse {
|
||||
client := &http.Client{}
|
||||
url := fmt.Sprintf("http://localhost:8080%s", req.Path)
|
||||
|
||||
var body io.Reader
|
||||
if req.Body != nil {
|
||||
bodyBytes, err := json.Marshal(req.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
body = bytes.NewReader(bodyBytes)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest(req.Method, url, body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create HTTP request: %v", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Add virtual key header if provided
|
||||
if req.VKHeader != nil {
|
||||
httpReq.Header.Set("x-bf-vk", *req.VKHeader)
|
||||
}
|
||||
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute HTTP request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
rawBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
var responseBody map[string]interface{}
|
||||
if len(rawBody) > 0 {
|
||||
err = json.Unmarshal(rawBody, &responseBody)
|
||||
if err != nil {
|
||||
// If unmarshaling fails, store the raw response
|
||||
responseBody = map[string]interface{}{"raw": string(rawBody)}
|
||||
}
|
||||
}
|
||||
|
||||
return &APIResponse{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: responseBody,
|
||||
RawBody: rawBody,
|
||||
}
|
||||
}
|
||||
|
||||
// MakeRequestWithCustomHeaders makes an HTTP request with custom headers
|
||||
// Use this when you need to test specific header formats (e.g., Authorization, x-api-key)
|
||||
func MakeRequestWithCustomHeaders(t *testing.T, req APIRequest, customHeaders map[string]string) *APIResponse {
|
||||
client := &http.Client{}
|
||||
url := fmt.Sprintf("http://localhost:8080%s", req.Path)
|
||||
|
||||
var body io.Reader
|
||||
if req.Body != nil {
|
||||
bodyBytes, err := json.Marshal(req.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
body = bytes.NewReader(bodyBytes)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest(req.Method, url, body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create HTTP request: %v", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Add custom headers
|
||||
for key, value := range customHeaders {
|
||||
httpReq.Header.Set(key, value)
|
||||
}
|
||||
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute HTTP request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
rawBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
var responseBody map[string]interface{}
|
||||
if len(rawBody) > 0 {
|
||||
err = json.Unmarshal(rawBody, &responseBody)
|
||||
if err != nil {
|
||||
// If unmarshaling fails, store the raw response
|
||||
responseBody = map[string]interface{}{"raw": string(rawBody)}
|
||||
}
|
||||
}
|
||||
|
||||
return &APIResponse{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: responseBody,
|
||||
RawBody: rawBody,
|
||||
}
|
||||
}
|
||||
|
||||
// generateRandomID generates a random ID for test resources
|
||||
func generateRandomID() string {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
const letters = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
b := make([]byte, 8)
|
||||
for i := range b {
|
||||
b[i] = letters[rand.Intn(len(letters))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// CreateVirtualKeyRequest represents a request to create a virtual key
|
||||
type CreateVirtualKeyRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
IsActive *bool `json:"is_active,omitempty"`
|
||||
TeamID *string `json:"team_id,omitempty"`
|
||||
CustomerID *string `json:"customer_id,omitempty"`
|
||||
Budget *BudgetRequest `json:"budget,omitempty"`
|
||||
RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"`
|
||||
ProviderConfigs []ProviderConfigRequest `json:"provider_configs,omitempty"`
|
||||
}
|
||||
|
||||
// ProviderConfigRequest represents a provider configuration for a virtual key
|
||||
type ProviderConfigRequest struct {
|
||||
ID *uint `json:"id,omitempty"`
|
||||
Provider string `json:"provider"`
|
||||
Weight *float64 `json:"weight,omitempty"`
|
||||
AllowedModels []string `json:"allowed_models,omitempty"`
|
||||
KeyIDs []string `json:"key_ids,omitempty"`
|
||||
Budget *BudgetRequest `json:"budget,omitempty"`
|
||||
RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"`
|
||||
}
|
||||
|
||||
// float64Ptr returns a pointer to a float64 value
|
||||
func float64Ptr(v float64) *float64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// BudgetRequest represents a budget request
|
||||
type BudgetRequest struct {
|
||||
MaxLimit float64 `json:"max_limit"`
|
||||
ResetDuration string `json:"reset_duration"`
|
||||
}
|
||||
|
||||
// CreateTeamRequest represents a request to create a team
|
||||
type CreateTeamRequest struct {
|
||||
Name string `json:"name"`
|
||||
CustomerID *string `json:"customer_id,omitempty"`
|
||||
Budgets []BudgetRequest `json:"budgets,omitempty"`
|
||||
}
|
||||
|
||||
// CreateCustomerRequest represents a request to create a customer
|
||||
type CreateCustomerRequest struct {
|
||||
Name string `json:"name"`
|
||||
Budget *BudgetRequest `json:"budget,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateBudgetRequest represents a request to update a budget
|
||||
type UpdateBudgetRequest struct {
|
||||
MaxLimit *float64 `json:"max_limit,omitempty"`
|
||||
ResetDuration *string `json:"reset_duration,omitempty"`
|
||||
}
|
||||
|
||||
// CreateRateLimitRequest represents a request to create a rate limit
|
||||
type CreateRateLimitRequest struct {
|
||||
TokenMaxLimit *int64 `json:"token_max_limit,omitempty"`
|
||||
TokenResetDuration *string `json:"token_reset_duration,omitempty"`
|
||||
RequestMaxLimit *int64 `json:"request_max_limit,omitempty"`
|
||||
RequestResetDuration *string `json:"request_reset_duration,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateVirtualKeyRequest represents a request to update a virtual key
|
||||
type UpdateVirtualKeyRequest struct {
|
||||
Name *string `json:"name,omitempty"`
|
||||
TeamID *string `json:"team_id,omitempty"`
|
||||
CustomerID *string `json:"customer_id,omitempty"`
|
||||
Budget *UpdateBudgetRequest `json:"budget,omitempty"`
|
||||
RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"`
|
||||
IsActive *bool `json:"is_active,omitempty"`
|
||||
ProviderConfigs []ProviderConfigRequest `json:"provider_configs,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateTeamRequest represents a request to update a team
|
||||
type UpdateTeamRequest struct {
|
||||
Name *string `json:"name,omitempty"`
|
||||
// Pointer-to-slice so tests can distinguish:
|
||||
// nil → field omitted (budgets untouched by server)
|
||||
// &[]BudgetRequest{} → explicit empty array (server clears all budgets)
|
||||
// &[]BudgetRequest{…} → replace with the provided budgets
|
||||
Budgets *[]BudgetRequest `json:"budgets,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateCustomerRequest represents a request to update a customer
|
||||
type UpdateCustomerRequest struct {
|
||||
Name *string `json:"name,omitempty"`
|
||||
Budget *UpdateBudgetRequest `json:"budget,omitempty"`
|
||||
}
|
||||
|
||||
// ChatCompletionRequest represents an OpenAI-compatible chat completion request
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
}
|
||||
|
||||
// ChatMessage represents a chat message in OpenAI format
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// ExtractIDFromResponse extracts the ID from a creation response
|
||||
func ExtractIDFromResponse(t *testing.T, resp *APIResponse) string {
|
||||
if resp.StatusCode >= 400 {
|
||||
t.Fatalf("Request failed with status %d: %v", resp.StatusCode, resp.Body)
|
||||
}
|
||||
|
||||
// Navigate through the response to find the ID
|
||||
data := resp.Body
|
||||
parts := []string{"virtual_key", "team", "customer"}
|
||||
for _, part := range parts {
|
||||
if val, ok := data[part]; ok {
|
||||
if nested, ok := val.(map[string]interface{}); ok {
|
||||
if id, ok := nested["id"].(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Fatalf("Could not extract ID from response: %v", resp.Body)
|
||||
return ""
|
||||
}
|
||||
|
||||
// CheckErrorMessage checks if the response error contains expected text
|
||||
// Returns true if error found, false otherwise. Asserts fail if status is not >= 400.
|
||||
func CheckErrorMessage(t *testing.T, resp *APIResponse, expectedText string) bool {
|
||||
if resp.StatusCode < 400 {
|
||||
t.Fatalf("Expected error response but got status %d. Response: %v", resp.StatusCode, resp.Body)
|
||||
}
|
||||
|
||||
// Check in various fields where errors might appear
|
||||
if msg, ok := resp.Body["message"].(string); ok && contains(msg, expectedText) {
|
||||
return true
|
||||
}
|
||||
|
||||
if err, ok := resp.Body["error"].(string); ok && contains(err, expectedText) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check raw body as fallback
|
||||
if contains(string(resp.RawBody), expectedText) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// contains checks if a string contains a substring (case-insensitive)
|
||||
func contains(haystack, needle string) bool {
|
||||
return strings.Contains(strings.ToLower(haystack), strings.ToLower(needle))
|
||||
}
|
||||
|
||||
// GlobalTestData stores IDs of created resources for cleanup
|
||||
type GlobalTestData struct {
|
||||
VirtualKeys []string
|
||||
Teams []string
|
||||
Customers []string
|
||||
}
|
||||
|
||||
// NewGlobalTestData creates a new test data holder
|
||||
func NewGlobalTestData() *GlobalTestData {
|
||||
return &GlobalTestData{
|
||||
VirtualKeys: make([]string, 0),
|
||||
Teams: make([]string, 0),
|
||||
Customers: make([]string, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// AddVirtualKey adds a virtual key ID to the test data
|
||||
func (g *GlobalTestData) AddVirtualKey(id string) {
|
||||
g.VirtualKeys = append(g.VirtualKeys, id)
|
||||
}
|
||||
|
||||
// AddTeam adds a team ID to the test data
|
||||
func (g *GlobalTestData) AddTeam(id string) {
|
||||
g.Teams = append(g.Teams, id)
|
||||
}
|
||||
|
||||
// AddCustomer adds a customer ID to the test data
|
||||
func (g *GlobalTestData) AddCustomer(id string) {
|
||||
g.Customers = append(g.Customers, id)
|
||||
}
|
||||
|
||||
// deleteWithRetry performs a DELETE request with retry logic
|
||||
// Retries up to 5 times if the response status is not 200 or 204
|
||||
// Delete requests don't require VK headers
|
||||
func deleteWithRetry(t *testing.T, path string, resourceType string, resourceID string) bool {
|
||||
maxRetries := 5
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
resp := MakeRequest(t, APIRequest{
|
||||
Method: "DELETE",
|
||||
Path: path,
|
||||
// Note: VKHeader is intentionally not set for DELETE requests
|
||||
})
|
||||
|
||||
// Success: 200 or 204 means the resource was deleted successfully
|
||||
if resp.StatusCode == 200 || resp.StatusCode == 204 {
|
||||
if attempt > 1 {
|
||||
t.Logf("Successfully deleted %s %s after %d attempts", resourceType, resourceID, attempt)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// 404 means resource doesn't exist, which is fine for cleanup
|
||||
if resp.StatusCode == 404 {
|
||||
t.Logf("%s %s not found (already deleted or never existed)", resourceType, resourceID)
|
||||
return true
|
||||
}
|
||||
|
||||
// If this is not the last attempt, log and retry
|
||||
if attempt < maxRetries {
|
||||
t.Logf("Attempt %d/%d: Failed to delete %s %s: status %d, retrying...", attempt, maxRetries, resourceType, resourceID, resp.StatusCode)
|
||||
// Progressive backoff: 100ms, 200ms, 300ms, 400ms
|
||||
time.Sleep(time.Duration(100*attempt) * time.Millisecond)
|
||||
} else {
|
||||
// Last attempt failed
|
||||
t.Logf("Warning: Failed to delete %s %s after %d attempts: status %d", resourceType, resourceID, maxRetries, resp.StatusCode)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Cleanup deletes all created resources
|
||||
// Retries up to 5 times for each delete operation if status is not 200 or 204
|
||||
// Delete requests don't require VK headers
|
||||
func (g *GlobalTestData) Cleanup(t *testing.T) {
|
||||
// Delete virtual keys
|
||||
for _, vkID := range g.VirtualKeys {
|
||||
deleteWithRetry(t, fmt.Sprintf("/api/governance/virtual-keys/%s", vkID), "virtual key", vkID)
|
||||
}
|
||||
|
||||
// Delete teams
|
||||
for _, teamID := range g.Teams {
|
||||
deleteWithRetry(t, fmt.Sprintf("/api/governance/teams/%s", teamID), "team", teamID)
|
||||
}
|
||||
|
||||
// Delete customers
|
||||
for _, customerID := range g.Customers {
|
||||
deleteWithRetry(t, fmt.Sprintf("/api/governance/customers/%s", customerID), "customer", customerID)
|
||||
}
|
||||
|
||||
t.Logf("Cleanup completed: deleted %d VKs, %d teams, %d customers",
|
||||
len(g.VirtualKeys), len(g.Teams), len(g.Customers))
|
||||
}
|
||||
|
||||
// WaitForCondition polls a condition function until it returns true or times out
|
||||
// Useful for waiting for async updates to propagate to in-memory store
|
||||
func WaitForCondition(t *testing.T, checkFunc func() bool, timeout time.Duration, description string) bool {
|
||||
deadline := time.Now().Add(timeout)
|
||||
attempt := 0
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
attempt++
|
||||
if checkFunc() {
|
||||
if attempt > 1 {
|
||||
t.Logf("Condition '%s' met after %d attempts", description, attempt)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Progressive backoff: start with 50ms, max 500ms
|
||||
sleepDuration := time.Duration(50*attempt) * time.Millisecond
|
||||
if sleepDuration > 500*time.Millisecond {
|
||||
sleepDuration = 500 * time.Millisecond
|
||||
}
|
||||
time.Sleep(sleepDuration)
|
||||
}
|
||||
|
||||
t.Logf("Timeout waiting for condition '%s' after %d attempts (%.1fs)", description, attempt, timeout.Seconds())
|
||||
return false
|
||||
}
|
||||
|
||||
// WaitForAPICondition makes repeated API requests until a condition is satisfied or times out
|
||||
// Useful for verifying async updates in API responses
|
||||
func WaitForAPICondition(t *testing.T, req APIRequest, condition func(*APIResponse) bool, timeout time.Duration, description string) (*APIResponse, bool) {
|
||||
deadline := time.Now().Add(timeout)
|
||||
attempt := 0
|
||||
var lastResp *APIResponse
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
attempt++
|
||||
lastResp = MakeRequest(t, req)
|
||||
|
||||
if condition(lastResp) {
|
||||
if attempt > 1 {
|
||||
t.Logf("API condition '%s' met after %d attempts", description, attempt)
|
||||
}
|
||||
return lastResp, true
|
||||
}
|
||||
|
||||
// Progressive backoff: start with 100ms, max 500ms
|
||||
sleepDuration := time.Duration(100*attempt) * time.Millisecond
|
||||
if sleepDuration > 500*time.Millisecond {
|
||||
sleepDuration = 500 * time.Millisecond
|
||||
}
|
||||
time.Sleep(sleepDuration)
|
||||
}
|
||||
|
||||
t.Logf("Timeout waiting for API condition '%s' after %d attempts (%.1fs)", description, attempt, timeout.Seconds())
|
||||
return lastResp, false
|
||||
}
|
||||
|
||||
// ParseDuration function to parse duration strings
|
||||
// Copied from framework/configstore/tables/utils.go
|
||||
func ParseDuration(duration string) (time.Duration, error) {
|
||||
if duration == "" {
|
||||
return 0, fmt.Errorf("duration is empty")
|
||||
}
|
||||
|
||||
// Handle special cases for days, weeks, months, years
|
||||
switch {
|
||||
case duration[len(duration)-1:] == "d":
|
||||
days := duration[:len(duration)-1]
|
||||
if d, err := time.ParseDuration(days + "h"); err == nil {
|
||||
return d * 24, nil
|
||||
}
|
||||
return 0, fmt.Errorf("invalid day duration: %s", duration)
|
||||
case duration[len(duration)-1:] == "w":
|
||||
weeks := duration[:len(duration)-1]
|
||||
if w, err := time.ParseDuration(weeks + "h"); err == nil {
|
||||
return w * 24 * 7, nil
|
||||
}
|
||||
return 0, fmt.Errorf("invalid week duration: %s", duration)
|
||||
case duration[len(duration)-1:] == "M":
|
||||
months := duration[:len(duration)-1]
|
||||
if m, err := time.ParseDuration(months + "h"); err == nil {
|
||||
return m * 24 * 30, nil // Approximate month as 30 days
|
||||
}
|
||||
return 0, fmt.Errorf("invalid month duration: %s", duration)
|
||||
case duration[len(duration)-1:] == "Y":
|
||||
years := duration[:len(duration)-1]
|
||||
if y, err := time.ParseDuration(years + "h"); err == nil {
|
||||
return y * 24 * 365, nil // Approximate year as 365 days
|
||||
}
|
||||
return 0, fmt.Errorf("invalid year duration: %s", duration)
|
||||
default:
|
||||
return time.ParseDuration(duration)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user