552 lines
17 KiB
Go
552 lines
17 KiB
Go
package governance
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"math/rand"
|
|
"net/http"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// ModelCost defines the cost structure for a model
|
|
type ModelCost struct {
|
|
Provider string
|
|
InputCostPerToken float64
|
|
OutputCostPerToken float64
|
|
MaxInputTokens int
|
|
MaxOutputTokens int
|
|
}
|
|
|
|
// TestModels defines all models used for testing
|
|
var TestModels = map[string]ModelCost{
|
|
"openai/gpt-4o": {
|
|
Provider: "openai",
|
|
InputCostPerToken: 0.0000025,
|
|
OutputCostPerToken: 0.00001,
|
|
MaxInputTokens: 128000,
|
|
MaxOutputTokens: 16384,
|
|
},
|
|
"anthropic/claude-3-7-sonnet-20250219": {
|
|
Provider: "anthropic",
|
|
InputCostPerToken: 0.000003,
|
|
OutputCostPerToken: 0.000015,
|
|
MaxInputTokens: 200000,
|
|
MaxOutputTokens: 128000,
|
|
},
|
|
"anthropic/claude-4-opus-20250514": {
|
|
Provider: "anthropic",
|
|
InputCostPerToken: 0.000015,
|
|
OutputCostPerToken: 0.000075,
|
|
MaxInputTokens: 200000,
|
|
MaxOutputTokens: 32000,
|
|
},
|
|
"openrouter/anthropic/claude-3.7-sonnet": {
|
|
Provider: "openrouter",
|
|
InputCostPerToken: 0.000003,
|
|
OutputCostPerToken: 0.000015,
|
|
MaxInputTokens: 200000,
|
|
MaxOutputTokens: 128000,
|
|
},
|
|
"openrouter/openai/gpt-4o": {
|
|
Provider: "openrouter",
|
|
InputCostPerToken: 0.0000025,
|
|
OutputCostPerToken: 0.00001,
|
|
MaxInputTokens: 128000,
|
|
MaxOutputTokens: 4096,
|
|
},
|
|
}
|
|
|
|
// CalculateCost calculates the cost based on input and output tokens
|
|
func CalculateCost(model string, inputTokens, outputTokens int) (float64, error) {
|
|
modelInfo, ok := TestModels[model]
|
|
if !ok {
|
|
return 0, fmt.Errorf("unknown model: %s", model)
|
|
}
|
|
|
|
inputCost := float64(inputTokens) * modelInfo.InputCostPerToken
|
|
outputCost := float64(outputTokens) * modelInfo.OutputCostPerToken
|
|
return inputCost + outputCost, nil
|
|
}
|
|
|
|
// APIRequest represents a request to the Bifrost API
|
|
type APIRequest struct {
|
|
Method string
|
|
Path string
|
|
Body interface{}
|
|
VKHeader *string
|
|
}
|
|
|
|
// APIResponse represents a response from the Bifrost API
|
|
type APIResponse struct {
|
|
StatusCode int
|
|
Body map[string]interface{}
|
|
RawBody []byte
|
|
}
|
|
|
|
// MakeRequest makes an HTTP request to the Bifrost API
|
|
func MakeRequest(t *testing.T, req APIRequest) *APIResponse {
|
|
client := &http.Client{}
|
|
url := fmt.Sprintf("http://localhost:8080%s", req.Path)
|
|
|
|
var body io.Reader
|
|
if req.Body != nil {
|
|
bodyBytes, err := json.Marshal(req.Body)
|
|
if err != nil {
|
|
t.Fatalf("Failed to marshal request body: %v", err)
|
|
}
|
|
body = bytes.NewReader(bodyBytes)
|
|
}
|
|
|
|
httpReq, err := http.NewRequest(req.Method, url, body)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create HTTP request: %v", err)
|
|
}
|
|
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
|
|
// Add virtual key header if provided
|
|
if req.VKHeader != nil {
|
|
httpReq.Header.Set("x-bf-vk", *req.VKHeader)
|
|
}
|
|
|
|
resp, err := client.Do(httpReq)
|
|
if err != nil {
|
|
t.Fatalf("Failed to execute HTTP request: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
rawBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("Failed to read response body: %v", err)
|
|
}
|
|
|
|
var responseBody map[string]interface{}
|
|
if len(rawBody) > 0 {
|
|
err = json.Unmarshal(rawBody, &responseBody)
|
|
if err != nil {
|
|
// If unmarshaling fails, store the raw response
|
|
responseBody = map[string]interface{}{"raw": string(rawBody)}
|
|
}
|
|
}
|
|
|
|
return &APIResponse{
|
|
StatusCode: resp.StatusCode,
|
|
Body: responseBody,
|
|
RawBody: rawBody,
|
|
}
|
|
}
|
|
|
|
// MakeRequestWithCustomHeaders makes an HTTP request with custom headers
|
|
// Use this when you need to test specific header formats (e.g., Authorization, x-api-key)
|
|
func MakeRequestWithCustomHeaders(t *testing.T, req APIRequest, customHeaders map[string]string) *APIResponse {
|
|
client := &http.Client{}
|
|
url := fmt.Sprintf("http://localhost:8080%s", req.Path)
|
|
|
|
var body io.Reader
|
|
if req.Body != nil {
|
|
bodyBytes, err := json.Marshal(req.Body)
|
|
if err != nil {
|
|
t.Fatalf("Failed to marshal request body: %v", err)
|
|
}
|
|
body = bytes.NewReader(bodyBytes)
|
|
}
|
|
|
|
httpReq, err := http.NewRequest(req.Method, url, body)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create HTTP request: %v", err)
|
|
}
|
|
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
|
|
// Add custom headers
|
|
for key, value := range customHeaders {
|
|
httpReq.Header.Set(key, value)
|
|
}
|
|
|
|
resp, err := client.Do(httpReq)
|
|
if err != nil {
|
|
t.Fatalf("Failed to execute HTTP request: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
rawBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("Failed to read response body: %v", err)
|
|
}
|
|
|
|
var responseBody map[string]interface{}
|
|
if len(rawBody) > 0 {
|
|
err = json.Unmarshal(rawBody, &responseBody)
|
|
if err != nil {
|
|
// If unmarshaling fails, store the raw response
|
|
responseBody = map[string]interface{}{"raw": string(rawBody)}
|
|
}
|
|
}
|
|
|
|
return &APIResponse{
|
|
StatusCode: resp.StatusCode,
|
|
Body: responseBody,
|
|
RawBody: rawBody,
|
|
}
|
|
}
|
|
|
|
// generateRandomID generates a random ID for test resources
|
|
func generateRandomID() string {
|
|
rand.Seed(time.Now().UnixNano())
|
|
const letters = "abcdefghijklmnopqrstuvwxyz0123456789"
|
|
b := make([]byte, 8)
|
|
for i := range b {
|
|
b[i] = letters[rand.Intn(len(letters))]
|
|
}
|
|
return string(b)
|
|
}
|
|
|
|
// CreateVirtualKeyRequest represents a request to create a virtual key
|
|
type CreateVirtualKeyRequest struct {
|
|
Name string `json:"name"`
|
|
Description string `json:"description,omitempty"`
|
|
IsActive *bool `json:"is_active,omitempty"`
|
|
TeamID *string `json:"team_id,omitempty"`
|
|
CustomerID *string `json:"customer_id,omitempty"`
|
|
Budget *BudgetRequest `json:"budget,omitempty"`
|
|
RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"`
|
|
ProviderConfigs []ProviderConfigRequest `json:"provider_configs,omitempty"`
|
|
}
|
|
|
|
// ProviderConfigRequest represents a provider configuration for a virtual key
|
|
type ProviderConfigRequest struct {
|
|
ID *uint `json:"id,omitempty"`
|
|
Provider string `json:"provider"`
|
|
Weight *float64 `json:"weight,omitempty"`
|
|
AllowedModels []string `json:"allowed_models,omitempty"`
|
|
KeyIDs []string `json:"key_ids,omitempty"`
|
|
Budget *BudgetRequest `json:"budget,omitempty"`
|
|
RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"`
|
|
}
|
|
|
|
// float64Ptr returns a pointer to a float64 value
|
|
func float64Ptr(v float64) *float64 {
|
|
return &v
|
|
}
|
|
|
|
// BudgetRequest represents a budget request
|
|
type BudgetRequest struct {
|
|
MaxLimit float64 `json:"max_limit"`
|
|
ResetDuration string `json:"reset_duration"`
|
|
}
|
|
|
|
// CreateTeamRequest represents a request to create a team
|
|
type CreateTeamRequest struct {
|
|
Name string `json:"name"`
|
|
CustomerID *string `json:"customer_id,omitempty"`
|
|
Budgets []BudgetRequest `json:"budgets,omitempty"`
|
|
}
|
|
|
|
// CreateCustomerRequest represents a request to create a customer
|
|
type CreateCustomerRequest struct {
|
|
Name string `json:"name"`
|
|
Budget *BudgetRequest `json:"budget,omitempty"`
|
|
}
|
|
|
|
// UpdateBudgetRequest represents a request to update a budget
|
|
type UpdateBudgetRequest struct {
|
|
MaxLimit *float64 `json:"max_limit,omitempty"`
|
|
ResetDuration *string `json:"reset_duration,omitempty"`
|
|
}
|
|
|
|
// CreateRateLimitRequest represents a request to create a rate limit
|
|
type CreateRateLimitRequest struct {
|
|
TokenMaxLimit *int64 `json:"token_max_limit,omitempty"`
|
|
TokenResetDuration *string `json:"token_reset_duration,omitempty"`
|
|
RequestMaxLimit *int64 `json:"request_max_limit,omitempty"`
|
|
RequestResetDuration *string `json:"request_reset_duration,omitempty"`
|
|
}
|
|
|
|
// UpdateVirtualKeyRequest represents a request to update a virtual key
|
|
type UpdateVirtualKeyRequest struct {
|
|
Name *string `json:"name,omitempty"`
|
|
TeamID *string `json:"team_id,omitempty"`
|
|
CustomerID *string `json:"customer_id,omitempty"`
|
|
Budget *UpdateBudgetRequest `json:"budget,omitempty"`
|
|
RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"`
|
|
IsActive *bool `json:"is_active,omitempty"`
|
|
ProviderConfigs []ProviderConfigRequest `json:"provider_configs,omitempty"`
|
|
}
|
|
|
|
// UpdateTeamRequest represents a request to update a team
|
|
type UpdateTeamRequest struct {
|
|
Name *string `json:"name,omitempty"`
|
|
// Pointer-to-slice so tests can distinguish:
|
|
// nil → field omitted (budgets untouched by server)
|
|
// &[]BudgetRequest{} → explicit empty array (server clears all budgets)
|
|
// &[]BudgetRequest{…} → replace with the provided budgets
|
|
Budgets *[]BudgetRequest `json:"budgets,omitempty"`
|
|
}
|
|
|
|
// UpdateCustomerRequest represents a request to update a customer
|
|
type UpdateCustomerRequest struct {
|
|
Name *string `json:"name,omitempty"`
|
|
Budget *UpdateBudgetRequest `json:"budget,omitempty"`
|
|
}
|
|
|
|
// ChatCompletionRequest represents an OpenAI-compatible chat completion request
|
|
type ChatCompletionRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []ChatMessage `json:"messages"`
|
|
Temperature *float64 `json:"temperature,omitempty"`
|
|
MaxTokens *int `json:"max_tokens,omitempty"`
|
|
TopP *float64 `json:"top_p,omitempty"`
|
|
}
|
|
|
|
// ChatMessage represents a chat message in OpenAI format
|
|
type ChatMessage struct {
|
|
Role string `json:"role"`
|
|
Content string `json:"content"`
|
|
}
|
|
|
|
// ExtractIDFromResponse extracts the ID from a creation response
|
|
func ExtractIDFromResponse(t *testing.T, resp *APIResponse) string {
|
|
if resp.StatusCode >= 400 {
|
|
t.Fatalf("Request failed with status %d: %v", resp.StatusCode, resp.Body)
|
|
}
|
|
|
|
// Navigate through the response to find the ID
|
|
data := resp.Body
|
|
parts := []string{"virtual_key", "team", "customer"}
|
|
for _, part := range parts {
|
|
if val, ok := data[part]; ok {
|
|
if nested, ok := val.(map[string]interface{}); ok {
|
|
if id, ok := nested["id"].(string); ok {
|
|
return id
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
t.Fatalf("Could not extract ID from response: %v", resp.Body)
|
|
return ""
|
|
}
|
|
|
|
// CheckErrorMessage checks if the response error contains expected text
|
|
// Returns true if error found, false otherwise. Asserts fail if status is not >= 400.
|
|
func CheckErrorMessage(t *testing.T, resp *APIResponse, expectedText string) bool {
|
|
if resp.StatusCode < 400 {
|
|
t.Fatalf("Expected error response but got status %d. Response: %v", resp.StatusCode, resp.Body)
|
|
}
|
|
|
|
// Check in various fields where errors might appear
|
|
if msg, ok := resp.Body["message"].(string); ok && contains(msg, expectedText) {
|
|
return true
|
|
}
|
|
|
|
if err, ok := resp.Body["error"].(string); ok && contains(err, expectedText) {
|
|
return true
|
|
}
|
|
|
|
// Check raw body as fallback
|
|
if contains(string(resp.RawBody), expectedText) {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// contains checks if a string contains a substring (case-insensitive)
|
|
func contains(haystack, needle string) bool {
|
|
return strings.Contains(strings.ToLower(haystack), strings.ToLower(needle))
|
|
}
|
|
|
|
// GlobalTestData stores IDs of created resources for cleanup
|
|
type GlobalTestData struct {
|
|
VirtualKeys []string
|
|
Teams []string
|
|
Customers []string
|
|
}
|
|
|
|
// NewGlobalTestData creates a new test data holder
|
|
func NewGlobalTestData() *GlobalTestData {
|
|
return &GlobalTestData{
|
|
VirtualKeys: make([]string, 0),
|
|
Teams: make([]string, 0),
|
|
Customers: make([]string, 0),
|
|
}
|
|
}
|
|
|
|
// AddVirtualKey adds a virtual key ID to the test data
|
|
func (g *GlobalTestData) AddVirtualKey(id string) {
|
|
g.VirtualKeys = append(g.VirtualKeys, id)
|
|
}
|
|
|
|
// AddTeam adds a team ID to the test data
|
|
func (g *GlobalTestData) AddTeam(id string) {
|
|
g.Teams = append(g.Teams, id)
|
|
}
|
|
|
|
// AddCustomer adds a customer ID to the test data
|
|
func (g *GlobalTestData) AddCustomer(id string) {
|
|
g.Customers = append(g.Customers, id)
|
|
}
|
|
|
|
// deleteWithRetry performs a DELETE request with retry logic
|
|
// Retries up to 5 times if the response status is not 200 or 204
|
|
// Delete requests don't require VK headers
|
|
func deleteWithRetry(t *testing.T, path string, resourceType string, resourceID string) bool {
|
|
maxRetries := 5
|
|
for attempt := 1; attempt <= maxRetries; attempt++ {
|
|
resp := MakeRequest(t, APIRequest{
|
|
Method: "DELETE",
|
|
Path: path,
|
|
// Note: VKHeader is intentionally not set for DELETE requests
|
|
})
|
|
|
|
// Success: 200 or 204 means the resource was deleted successfully
|
|
if resp.StatusCode == 200 || resp.StatusCode == 204 {
|
|
if attempt > 1 {
|
|
t.Logf("Successfully deleted %s %s after %d attempts", resourceType, resourceID, attempt)
|
|
}
|
|
return true
|
|
}
|
|
|
|
// 404 means resource doesn't exist, which is fine for cleanup
|
|
if resp.StatusCode == 404 {
|
|
t.Logf("%s %s not found (already deleted or never existed)", resourceType, resourceID)
|
|
return true
|
|
}
|
|
|
|
// If this is not the last attempt, log and retry
|
|
if attempt < maxRetries {
|
|
t.Logf("Attempt %d/%d: Failed to delete %s %s: status %d, retrying...", attempt, maxRetries, resourceType, resourceID, resp.StatusCode)
|
|
// Progressive backoff: 100ms, 200ms, 300ms, 400ms
|
|
time.Sleep(time.Duration(100*attempt) * time.Millisecond)
|
|
} else {
|
|
// Last attempt failed
|
|
t.Logf("Warning: Failed to delete %s %s after %d attempts: status %d", resourceType, resourceID, maxRetries, resp.StatusCode)
|
|
return false
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// Cleanup deletes all created resources
|
|
// Retries up to 5 times for each delete operation if status is not 200 or 204
|
|
// Delete requests don't require VK headers
|
|
func (g *GlobalTestData) Cleanup(t *testing.T) {
|
|
// Delete virtual keys
|
|
for _, vkID := range g.VirtualKeys {
|
|
deleteWithRetry(t, fmt.Sprintf("/api/governance/virtual-keys/%s", vkID), "virtual key", vkID)
|
|
}
|
|
|
|
// Delete teams
|
|
for _, teamID := range g.Teams {
|
|
deleteWithRetry(t, fmt.Sprintf("/api/governance/teams/%s", teamID), "team", teamID)
|
|
}
|
|
|
|
// Delete customers
|
|
for _, customerID := range g.Customers {
|
|
deleteWithRetry(t, fmt.Sprintf("/api/governance/customers/%s", customerID), "customer", customerID)
|
|
}
|
|
|
|
t.Logf("Cleanup completed: deleted %d VKs, %d teams, %d customers",
|
|
len(g.VirtualKeys), len(g.Teams), len(g.Customers))
|
|
}
|
|
|
|
// WaitForCondition polls a condition function until it returns true or times out
|
|
// Useful for waiting for async updates to propagate to in-memory store
|
|
func WaitForCondition(t *testing.T, checkFunc func() bool, timeout time.Duration, description string) bool {
|
|
deadline := time.Now().Add(timeout)
|
|
attempt := 0
|
|
|
|
for time.Now().Before(deadline) {
|
|
attempt++
|
|
if checkFunc() {
|
|
if attempt > 1 {
|
|
t.Logf("Condition '%s' met after %d attempts", description, attempt)
|
|
}
|
|
return true
|
|
}
|
|
|
|
// Progressive backoff: start with 50ms, max 500ms
|
|
sleepDuration := time.Duration(50*attempt) * time.Millisecond
|
|
if sleepDuration > 500*time.Millisecond {
|
|
sleepDuration = 500 * time.Millisecond
|
|
}
|
|
time.Sleep(sleepDuration)
|
|
}
|
|
|
|
t.Logf("Timeout waiting for condition '%s' after %d attempts (%.1fs)", description, attempt, timeout.Seconds())
|
|
return false
|
|
}
|
|
|
|
// WaitForAPICondition makes repeated API requests until a condition is satisfied or times out
|
|
// Useful for verifying async updates in API responses
|
|
func WaitForAPICondition(t *testing.T, req APIRequest, condition func(*APIResponse) bool, timeout time.Duration, description string) (*APIResponse, bool) {
|
|
deadline := time.Now().Add(timeout)
|
|
attempt := 0
|
|
var lastResp *APIResponse
|
|
|
|
for time.Now().Before(deadline) {
|
|
attempt++
|
|
lastResp = MakeRequest(t, req)
|
|
|
|
if condition(lastResp) {
|
|
if attempt > 1 {
|
|
t.Logf("API condition '%s' met after %d attempts", description, attempt)
|
|
}
|
|
return lastResp, true
|
|
}
|
|
|
|
// Progressive backoff: start with 100ms, max 500ms
|
|
sleepDuration := time.Duration(100*attempt) * time.Millisecond
|
|
if sleepDuration > 500*time.Millisecond {
|
|
sleepDuration = 500 * time.Millisecond
|
|
}
|
|
time.Sleep(sleepDuration)
|
|
}
|
|
|
|
t.Logf("Timeout waiting for API condition '%s' after %d attempts (%.1fs)", description, attempt, timeout.Seconds())
|
|
return lastResp, false
|
|
}
|
|
|
|
// ParseDuration function to parse duration strings
|
|
// Copied from framework/configstore/tables/utils.go
|
|
func ParseDuration(duration string) (time.Duration, error) {
|
|
if duration == "" {
|
|
return 0, fmt.Errorf("duration is empty")
|
|
}
|
|
|
|
// Handle special cases for days, weeks, months, years
|
|
switch {
|
|
case duration[len(duration)-1:] == "d":
|
|
days := duration[:len(duration)-1]
|
|
if d, err := time.ParseDuration(days + "h"); err == nil {
|
|
return d * 24, nil
|
|
}
|
|
return 0, fmt.Errorf("invalid day duration: %s", duration)
|
|
case duration[len(duration)-1:] == "w":
|
|
weeks := duration[:len(duration)-1]
|
|
if w, err := time.ParseDuration(weeks + "h"); err == nil {
|
|
return w * 24 * 7, nil
|
|
}
|
|
return 0, fmt.Errorf("invalid week duration: %s", duration)
|
|
case duration[len(duration)-1:] == "M":
|
|
months := duration[:len(duration)-1]
|
|
if m, err := time.ParseDuration(months + "h"); err == nil {
|
|
return m * 24 * 30, nil // Approximate month as 30 days
|
|
}
|
|
return 0, fmt.Errorf("invalid month duration: %s", duration)
|
|
case duration[len(duration)-1:] == "Y":
|
|
years := duration[:len(duration)-1]
|
|
if y, err := time.ParseDuration(years + "h"); err == nil {
|
|
return y * 24 * 365, nil // Approximate year as 365 days
|
|
}
|
|
return 0, fmt.Errorf("invalid year duration: %s", duration)
|
|
default:
|
|
return time.ParseDuration(duration)
|
|
}
|
|
}
|