2577 lines
78 KiB
Go
2577 lines
78 KiB
Go
package bifrost
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
mistralprovider "github.com/maximhq/bifrost/core/providers/mistral"
|
|
schemas "github.com/maximhq/bifrost/core/schemas"
|
|
"golang.org/x/text/cases"
|
|
"golang.org/x/text/language"
|
|
)
|
|
|
|
// mockSleep is a package-level sleep hook for the tests in this file. It does
// not replace time.Sleep by itself; it only holds whatever function is assigned
// to it (init installs a no-op).
var mockSleep func(delay time.Duration)
|
|
|
|
// Override time.Sleep in tests and setup logger
|
|
func init() {
|
|
mockSleep = func(d time.Duration) {
|
|
// Do nothing in tests to avoid real delays
|
|
}
|
|
}
|
|
|
|
// Helper function to create test config with specific retry settings
|
|
func createTestConfig(maxRetries int, initialBackoff, maxBackoff time.Duration) *schemas.ProviderConfig {
|
|
return &schemas.ProviderConfig{
|
|
NetworkConfig: schemas.NetworkConfig{
|
|
MaxRetries: maxRetries,
|
|
RetryBackoffInitial: initialBackoff,
|
|
RetryBackoffMax: maxBackoff,
|
|
},
|
|
}
|
|
}
|
|
|
|
// Helper function to create a BifrostError
|
|
func createBifrostError(message string, statusCode *int, errorType *string, isBifrostError bool) *schemas.BifrostError {
|
|
return &schemas.BifrostError{
|
|
IsBifrostError: isBifrostError,
|
|
StatusCode: statusCode,
|
|
Error: &schemas.ErrorField{
|
|
Message: message,
|
|
Type: errorType,
|
|
},
|
|
}
|
|
}
|
|
|
|
// Test executeRequestWithRetries - success scenarios
|
|
func TestExecuteRequestWithRetries_SuccessScenarios(t *testing.T) {
|
|
config := createTestConfig(3, 100*time.Millisecond, 1*time.Second)
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
logger := NewDefaultLogger(schemas.LogLevelError)
|
|
// Adding dummy tracer to the context
|
|
ctx.SetValue(schemas.BifrostContextKeyTracer, &schemas.NoOpTracer{})
|
|
// Test immediate success
|
|
t.Run("ImmediateSuccess", func(t *testing.T) {
|
|
callCount := 0
|
|
handler := func(_ schemas.Key) (string, *schemas.BifrostError) {
|
|
callCount++
|
|
return "success", nil
|
|
}
|
|
|
|
result, err := executeRequestWithRetries(
|
|
ctx,
|
|
config,
|
|
handler,
|
|
nil,
|
|
schemas.ChatCompletionRequest,
|
|
schemas.OpenAI,
|
|
"gpt-4",
|
|
nil,
|
|
logger,
|
|
)
|
|
|
|
if callCount != 1 {
|
|
t.Errorf("Expected 1 call, got %d", callCount)
|
|
}
|
|
if result != "success" {
|
|
t.Errorf("Expected 'success', got %s", result)
|
|
}
|
|
if err != nil {
|
|
t.Errorf("Expected no error, got %v", err)
|
|
}
|
|
})
|
|
|
|
// Test success after retries
|
|
t.Run("SuccessAfterRetries", func(t *testing.T) {
|
|
callCount := 0
|
|
handler := func(_ schemas.Key) (string, *schemas.BifrostError) {
|
|
callCount++
|
|
if callCount <= 2 {
|
|
// First two calls fail with retryable error
|
|
return "", createBifrostError("rate limit exceeded", Ptr(429), nil, false)
|
|
}
|
|
// Third call succeeds
|
|
return "success", nil
|
|
}
|
|
|
|
result, err := executeRequestWithRetries(
|
|
ctx,
|
|
config,
|
|
handler,
|
|
nil,
|
|
schemas.ChatCompletionRequest,
|
|
schemas.OpenAI,
|
|
"gpt-4",
|
|
nil,
|
|
logger,
|
|
)
|
|
|
|
if callCount != 3 {
|
|
t.Errorf("Expected 3 calls, got %d", callCount)
|
|
}
|
|
if result != "success" {
|
|
t.Errorf("Expected 'success', got %s", result)
|
|
}
|
|
if err != nil {
|
|
t.Errorf("Expected no error, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
// Test executeRequestWithRetries - retry limits
|
|
func TestExecuteRequestWithRetries_RetryLimits(t *testing.T) {
|
|
config := createTestConfig(2, 100*time.Millisecond, 1*time.Second)
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
ctx.SetValue(schemas.BifrostContextKeyTracer, &schemas.NoOpTracer{})
|
|
logger := NewDefaultLogger(schemas.LogLevelError)
|
|
t.Run("ExceedsMaxRetries", func(t *testing.T) {
|
|
callCount := 0
|
|
handler := func(_ schemas.Key) (string, *schemas.BifrostError) {
|
|
callCount++
|
|
// Always fail with retryable error
|
|
return "", createBifrostError("rate limit exceeded", Ptr(429), nil, false)
|
|
}
|
|
|
|
result, err := executeRequestWithRetries(
|
|
ctx,
|
|
config,
|
|
handler,
|
|
nil,
|
|
schemas.ChatCompletionRequest,
|
|
schemas.OpenAI,
|
|
"gpt-4",
|
|
nil,
|
|
logger,
|
|
)
|
|
|
|
// Should try: initial + 2 retries = 3 total attempts
|
|
if callCount != 3 {
|
|
t.Errorf("Expected 3 calls (initial + 2 retries), got %d", callCount)
|
|
}
|
|
if result != "" {
|
|
t.Errorf("Expected empty result, got %s", result)
|
|
}
|
|
if err == nil {
|
|
t.Fatal("Expected error after exceeding max retries")
|
|
}
|
|
if err.Error == nil {
|
|
t.Fatal("Expected error structure, got nil")
|
|
}
|
|
if err.Error.Message != "rate limit exceeded" {
|
|
t.Errorf("Expected rate limit error, got %s", err.Error.Message)
|
|
}
|
|
})
|
|
}
|
|
|
|
// Test executeRequestWithRetries - non-retryable errors
|
|
func TestExecuteRequestWithRetries_NonRetryableErrors(t *testing.T) {
|
|
config := createTestConfig(3, 100*time.Millisecond, 1*time.Second)
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
ctx.SetValue(schemas.BifrostContextKeyTracer, &schemas.NoOpTracer{})
|
|
testCases := []struct {
|
|
name string
|
|
error *schemas.BifrostError
|
|
}{
|
|
{
|
|
name: "BifrostError",
|
|
error: createBifrostError("validation error", nil, nil, true),
|
|
},
|
|
{
|
|
name: "RequestCancelled",
|
|
error: createBifrostError("request cancelled", nil, Ptr(schemas.ErrRequestCancelled), false),
|
|
},
|
|
{
|
|
name: "Non-retryable status code",
|
|
error: createBifrostError("bad request", Ptr(400), nil, false),
|
|
},
|
|
{
|
|
name: "Non-retryable error message",
|
|
error: createBifrostError("invalid model", nil, nil, false),
|
|
},
|
|
}
|
|
logger := NewDefaultLogger(schemas.LogLevelError)
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
callCount := 0
|
|
handler := func(_ schemas.Key) (string, *schemas.BifrostError) {
|
|
callCount++
|
|
return "", tc.error
|
|
}
|
|
|
|
result, err := executeRequestWithRetries(
|
|
ctx,
|
|
config,
|
|
handler,
|
|
nil,
|
|
schemas.ChatCompletionRequest,
|
|
schemas.OpenAI,
|
|
"gpt-4",
|
|
nil,
|
|
logger,
|
|
)
|
|
|
|
if callCount != 1 {
|
|
t.Errorf("Expected 1 call (no retries), got %d", callCount)
|
|
}
|
|
if result != "" {
|
|
t.Errorf("Expected empty result, got %s", result)
|
|
}
|
|
if err != tc.error {
|
|
t.Error("Expected original error to be returned")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Test executeRequestWithRetries - retryable conditions
|
|
func TestExecuteRequestWithRetries_RetryableConditions(t *testing.T) {
|
|
config := createTestConfig(1, 100*time.Millisecond, 1*time.Second)
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
ctx.SetValue(schemas.BifrostContextKeyTracer, &schemas.NoOpTracer{})
|
|
testCases := []struct {
|
|
name string
|
|
error *schemas.BifrostError
|
|
}{
|
|
{
|
|
name: "StatusCode_500",
|
|
error: createBifrostError("internal server error", Ptr(500), nil, false),
|
|
},
|
|
{
|
|
name: "StatusCode_502",
|
|
error: createBifrostError("bad gateway", Ptr(502), nil, false),
|
|
},
|
|
{
|
|
name: "StatusCode_503",
|
|
error: createBifrostError("service unavailable", Ptr(503), nil, false),
|
|
},
|
|
{
|
|
name: "StatusCode_504",
|
|
error: createBifrostError("gateway timeout", Ptr(504), nil, false),
|
|
},
|
|
{
|
|
name: "StatusCode_429",
|
|
error: createBifrostError("too many requests", Ptr(429), nil, false),
|
|
},
|
|
{
|
|
name: "ErrProviderDoRequest",
|
|
error: createBifrostError(schemas.ErrProviderDoRequest, nil, nil, false),
|
|
},
|
|
{
|
|
name: "RateLimitMessage",
|
|
error: createBifrostError("rate limit exceeded", nil, nil, false),
|
|
},
|
|
{
|
|
name: "RateLimitType",
|
|
error: createBifrostError("some error", nil, Ptr("rate_limit"), false),
|
|
},
|
|
}
|
|
logger := NewDefaultLogger(schemas.LogLevelError)
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
callCount := 0
|
|
handler := func(_ schemas.Key) (string, *schemas.BifrostError) {
|
|
callCount++
|
|
return "", tc.error
|
|
}
|
|
|
|
result, err := executeRequestWithRetries(
|
|
ctx,
|
|
config,
|
|
handler,
|
|
nil,
|
|
schemas.ChatCompletionRequest,
|
|
schemas.OpenAI,
|
|
"gpt-4",
|
|
nil,
|
|
logger,
|
|
)
|
|
|
|
// Should try: initial + 1 retry = 2 total attempts
|
|
if callCount != 2 {
|
|
t.Errorf("Expected 2 calls (initial + 1 retry), got %d", callCount)
|
|
}
|
|
if result != "" {
|
|
t.Errorf("Expected empty result, got %s", result)
|
|
}
|
|
if err != tc.error {
|
|
t.Error("Expected original error to be returned")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Test calculateBackoff - exponential growth (base calculations without jitter)
|
|
func TestCalculateBackoff_ExponentialGrowth(t *testing.T) {
|
|
config := createTestConfig(5, 100*time.Millisecond, 5*time.Second)
|
|
|
|
// Test the base exponential calculation by checking that results fall within expected ranges
|
|
// Since we can't easily mock rand.Float64, we'll test the bounds instead
|
|
testCases := []struct {
|
|
attempt int
|
|
minExpected time.Duration
|
|
maxExpected time.Duration
|
|
}{
|
|
{0, 80 * time.Millisecond, 120 * time.Millisecond}, // 100ms ± 20%
|
|
{1, 160 * time.Millisecond, 240 * time.Millisecond}, // 200ms ± 20%
|
|
{2, 320 * time.Millisecond, 480 * time.Millisecond}, // 400ms ± 20%
|
|
{3, 640 * time.Millisecond, 960 * time.Millisecond}, // 800ms ± 20%
|
|
{4, 1280 * time.Millisecond, 1920 * time.Millisecond}, // 1600ms ± 20%
|
|
{5, 2560 * time.Millisecond, 3840 * time.Millisecond}, // 3200ms ± 20%
|
|
{10, 4 * time.Second, 6 * time.Second}, // should be capped at max (5s) ± 20%
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(fmt.Sprintf("Attempt_%d", tc.attempt), func(t *testing.T) {
|
|
backoff := calculateBackoff(tc.attempt, config)
|
|
if backoff < tc.minExpected || backoff > tc.maxExpected {
|
|
t.Errorf("Backoff %v outside expected range [%v, %v]", backoff, tc.minExpected, tc.maxExpected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Test calculateBackoff - jitter bounds
|
|
func TestCalculateBackoff_JitterBounds(t *testing.T) {
|
|
config := createTestConfig(3, 100*time.Millisecond, 5*time.Second)
|
|
|
|
// Test jitter bounds for multiple attempts
|
|
for attempt := 0; attempt < 3; attempt++ {
|
|
t.Run(fmt.Sprintf("Attempt_%d_JitterBounds", attempt), func(t *testing.T) {
|
|
// Calculate expected base backoff
|
|
baseBackoff := config.NetworkConfig.RetryBackoffInitial * time.Duration(1<<uint(attempt))
|
|
if baseBackoff > config.NetworkConfig.RetryBackoffMax {
|
|
baseBackoff = config.NetworkConfig.RetryBackoffMax
|
|
}
|
|
|
|
// Test multiple samples to verify jitter bounds
|
|
for i := 0; i < 100; i++ {
|
|
backoff := calculateBackoff(attempt, config)
|
|
|
|
// Jitter should be ±20% (0.8 to 1.2 multiplier), but capped at configured max
|
|
minExpected := time.Duration(float64(baseBackoff) * 0.8)
|
|
maxExpected := min(time.Duration(float64(baseBackoff)*1.2), config.NetworkConfig.RetryBackoffMax)
|
|
|
|
if backoff < minExpected || backoff > maxExpected {
|
|
t.Errorf("Backoff %v outside expected range [%v, %v] for attempt %d",
|
|
backoff, minExpected, maxExpected, attempt)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Test calculateBackoff - max backoff cap
|
|
func TestCalculateBackoff_MaxBackoffCap(t *testing.T) {
|
|
config := createTestConfig(10, 100*time.Millisecond, 500*time.Millisecond)
|
|
|
|
// High attempt numbers should be capped at max backoff
|
|
for attempt := 5; attempt < 10; attempt++ {
|
|
backoff := calculateBackoff(attempt, config)
|
|
|
|
// Jitter should never exceed the configured maximum
|
|
if backoff > config.NetworkConfig.RetryBackoffMax {
|
|
t.Errorf("Backoff %v exceeds configured max %v for attempt %d",
|
|
backoff, config.NetworkConfig.RetryBackoffMax, attempt)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Test IsRateLimitErrorMessage - all patterns
|
|
func TestIsRateLimitError_AllPatterns(t *testing.T) {
|
|
// Test all patterns from rateLimitPatterns
|
|
patterns := []string{
|
|
"rate limit",
|
|
"rate_limit",
|
|
"ratelimit",
|
|
"too many requests",
|
|
"quota exceeded",
|
|
"quota_exceeded",
|
|
"request limit",
|
|
"throttled",
|
|
"throttling",
|
|
"rate exceeded",
|
|
"limit exceeded",
|
|
"requests per",
|
|
"rpm exceeded",
|
|
"tpm exceeded",
|
|
"tokens per minute",
|
|
"requests per minute",
|
|
"requests per second",
|
|
"api rate limit",
|
|
"usage limit",
|
|
"concurrent requests limit",
|
|
"burst_rate",
|
|
"rate increased",
|
|
}
|
|
|
|
for _, pattern := range patterns {
|
|
t.Run(fmt.Sprintf("Pattern_%s", strings.ReplaceAll(pattern, " ", "_")), func(t *testing.T) {
|
|
// Test exact match
|
|
if !IsRateLimitErrorMessage(pattern) {
|
|
t.Errorf("Pattern '%s' should be detected as rate limit error", pattern)
|
|
}
|
|
|
|
// Test case insensitive - uppercase
|
|
if !IsRateLimitErrorMessage(strings.ToUpper(pattern)) {
|
|
t.Errorf("Uppercase pattern '%s' should be detected as rate limit error", strings.ToUpper(pattern))
|
|
}
|
|
|
|
// Test case insensitive - mixed case
|
|
if !IsRateLimitErrorMessage(cases.Title(language.English).String(pattern)) {
|
|
t.Errorf("Title case pattern '%s' should be detected as rate limit error", cases.Title(language.English).String(pattern))
|
|
}
|
|
|
|
// Test as part of larger message
|
|
message := fmt.Sprintf("Error: %s occurred", pattern)
|
|
if !IsRateLimitErrorMessage(message) {
|
|
t.Errorf("Pattern '%s' in message '%s' should be detected", pattern, message)
|
|
}
|
|
|
|
// Test with prefix and suffix
|
|
message = fmt.Sprintf("API call failed due to %s - please retry later", pattern)
|
|
if !IsRateLimitErrorMessage(message) {
|
|
t.Errorf("Pattern '%s' in complex message should be detected", pattern)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Test IsRateLimitErrorMessage - negative cases
|
|
func TestIsRateLimitError_NegativeCases(t *testing.T) {
|
|
negativeCases := []string{
|
|
"",
|
|
"invalid request",
|
|
"authentication failed",
|
|
"model not found",
|
|
"internal server error",
|
|
"bad gateway",
|
|
"service unavailable",
|
|
"timeout",
|
|
"connection refused",
|
|
"rate", // partial match shouldn't trigger
|
|
"limit", // partial match shouldn't trigger
|
|
"quota", // partial match shouldn't trigger
|
|
"throttle", // partial match shouldn't trigger (need 'throttled' or 'throttling')
|
|
}
|
|
|
|
for _, testCase := range negativeCases {
|
|
t.Run(fmt.Sprintf("Negative_%s", strings.ReplaceAll(testCase, " ", "_")), func(t *testing.T) {
|
|
if IsRateLimitErrorMessage(testCase) {
|
|
t.Errorf("Message '%s' should NOT be detected as rate limit error", testCase)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Test IsRateLimitErrorMessage - edge cases
|
|
func TestIsRateLimitError_EdgeCases(t *testing.T) {
|
|
t.Run("EmptyString", func(t *testing.T) {
|
|
if IsRateLimitErrorMessage("") {
|
|
t.Error("Empty string should not be detected as rate limit error")
|
|
}
|
|
})
|
|
|
|
t.Run("OnlyWhitespace", func(t *testing.T) {
|
|
if IsRateLimitErrorMessage(" \t\n ") {
|
|
t.Error("Whitespace-only string should not be detected as rate limit error")
|
|
}
|
|
})
|
|
|
|
t.Run("UnicodeCharacters", func(t *testing.T) {
|
|
// Test with unicode characters that might affect case conversion
|
|
message := "RATE LIMIT exceeded 🚫"
|
|
if !IsRateLimitErrorMessage(message) {
|
|
t.Error("Message with unicode should still detect rate limit pattern")
|
|
}
|
|
})
|
|
|
|
t.Run("DashScopeErrorCode", func(t *testing.T) {
|
|
// DashScope returns "limit_burst_rate" as the error code
|
|
if !IsRateLimitErrorMessage("limit_burst_rate") {
|
|
t.Error("DashScope error code 'limit_burst_rate' should be detected as rate limit error")
|
|
}
|
|
})
|
|
|
|
t.Run("DashScopeErrorMessage", func(t *testing.T) {
|
|
// DashScope returns this as the error message
|
|
if !IsRateLimitErrorMessage("Request rate increased too quickly, please slow down and try again") {
|
|
t.Error("DashScope error message should be detected as rate limit error")
|
|
}
|
|
})
|
|
}
|
|
|
|
// Test retry logging and attempt counting
|
|
func TestExecuteRequestWithRetries_LoggingAndCounting(t *testing.T) {
|
|
config := createTestConfig(2, 50*time.Millisecond, 1*time.Second)
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
ctx.SetValue(schemas.BifrostContextKeyTracer, &schemas.NoOpTracer{})
|
|
// Capture calls and timing for verification
|
|
var attemptCounts []int
|
|
callCount := 0
|
|
|
|
handler := func(_ schemas.Key) (string, *schemas.BifrostError) {
|
|
callCount++
|
|
attemptCounts = append(attemptCounts, callCount)
|
|
|
|
if callCount <= 2 {
|
|
// First two calls fail with retryable error
|
|
return "", createBifrostError("rate limit exceeded", Ptr(429), nil, false)
|
|
}
|
|
// Third call succeeds
|
|
return "success", nil
|
|
}
|
|
logger := NewDefaultLogger(schemas.LogLevelError)
|
|
|
|
result, err := executeRequestWithRetries(
|
|
ctx,
|
|
config,
|
|
handler,
|
|
nil,
|
|
schemas.ChatCompletionRequest,
|
|
schemas.OpenAI,
|
|
"gpt-4",
|
|
nil,
|
|
logger,
|
|
)
|
|
|
|
// Verify call progression
|
|
if len(attemptCounts) != 3 {
|
|
t.Errorf("Expected 3 attempts, got %d", len(attemptCounts))
|
|
}
|
|
|
|
for i, count := range attemptCounts {
|
|
if count != i+1 {
|
|
t.Errorf("Attempt %d should have call count %d, got %d", i, i+1, count)
|
|
}
|
|
}
|
|
|
|
if result != "success" {
|
|
t.Errorf("Expected success result, got %s", result)
|
|
}
|
|
|
|
if err != nil {
|
|
t.Errorf("Expected no error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHandleProviderRequest_OCROperationNotAllowed(t *testing.T) {
|
|
providerConfig := &schemas.ProviderConfig{
|
|
NetworkConfig: schemas.NetworkConfig{
|
|
BaseURL: "http://127.0.0.1:1",
|
|
DefaultRequestTimeoutInSeconds: 1,
|
|
},
|
|
CustomProviderConfig: &schemas.CustomProviderConfig{
|
|
CustomProviderKey: "custom-mistral",
|
|
BaseProviderType: schemas.Mistral,
|
|
AllowedRequests: &schemas.AllowedRequests{},
|
|
},
|
|
}
|
|
provider := mistralprovider.NewMistralProvider(providerConfig, NewDefaultLogger(schemas.LogLevelError))
|
|
if provider.GetProviderKey() != schemas.ModelProvider("custom-mistral") {
|
|
t.Fatalf("expected custom provider key, got %q", provider.GetProviderKey())
|
|
}
|
|
bifrost := &Bifrost{}
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
|
|
request := &ChannelMessage{
|
|
Context: ctx,
|
|
BifrostRequest: schemas.BifrostRequest{
|
|
RequestType: schemas.OCRRequest,
|
|
OCRRequest: &schemas.BifrostOCRRequest{
|
|
Model: "custom-mistral/mistral-ocr-latest",
|
|
Document: schemas.OCRDocument{
|
|
Type: schemas.OCRDocumentTypeDocumentURL,
|
|
DocumentURL: Ptr("https://example.com/doc.pdf"),
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
response, err := bifrost.handleProviderRequest(provider, providerConfig, request, schemas.Key{}, nil)
|
|
if response != nil {
|
|
t.Fatalf("expected nil response, got %#v", response)
|
|
}
|
|
if err == nil {
|
|
t.Fatal("expected unsupported operation error, got nil")
|
|
}
|
|
if err.Error == nil {
|
|
t.Fatal("expected detailed error, got nil")
|
|
}
|
|
if err.Error.Code == nil || *err.Error.Code != "unsupported_operation" {
|
|
t.Fatalf("expected unsupported_operation code, got %#v", err.Error.Code)
|
|
}
|
|
if err.ExtraFields.Provider != schemas.ModelProvider("custom-mistral") {
|
|
t.Fatalf("expected custom provider name, got %q", err.ExtraFields.Provider)
|
|
}
|
|
if err.ExtraFields.RequestType != schemas.OCRRequest {
|
|
t.Fatalf("expected OCR request type, got %q", err.ExtraFields.RequestType)
|
|
}
|
|
if err.ExtraFields.OriginalModelRequested != "custom-mistral/mistral-ocr-latest" {
|
|
t.Fatalf("expected model to be preserved, got %q", err.ExtraFields.OriginalModelRequested)
|
|
}
|
|
}
|
|
|
|
// Test that retryableStatusCodes are properly defined
|
|
func TestRetryableStatusCodes(t *testing.T) {
|
|
expectedCodes := map[int]bool{
|
|
500: true, // Internal Server Error
|
|
502: true, // Bad Gateway
|
|
503: true, // Service Unavailable
|
|
504: true, // Gateway Timeout
|
|
429: true, // Too Many Requests
|
|
}
|
|
|
|
for code, expected := range expectedCodes {
|
|
if retryableStatusCodes[code] != expected {
|
|
t.Errorf("Status code %d should be retryable=%v, got %v", code, expected, retryableStatusCodes[code])
|
|
}
|
|
}
|
|
|
|
// Test non-retryable codes
|
|
nonRetryableCodes := []int{200, 201, 400, 401, 403, 404, 422}
|
|
for _, code := range nonRetryableCodes {
|
|
if retryableStatusCodes[code] {
|
|
t.Errorf("Status code %d should not be retryable", code)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Benchmark calculateBackoff performance
|
|
func BenchmarkCalculateBackoff(b *testing.B) {
|
|
config := createTestConfig(10, 100*time.Millisecond, 5*time.Second)
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
calculateBackoff(i%10, config)
|
|
}
|
|
}
|
|
|
|
// Benchmark IsRateLimitErrorMessage performance
|
|
func BenchmarkIsRateLimitError(b *testing.B) {
|
|
messages := []string{
|
|
"rate limit exceeded",
|
|
"too many requests",
|
|
"quota exceeded",
|
|
"throttled by provider",
|
|
"API rate limit reached",
|
|
"not a rate limit error",
|
|
"authentication failed",
|
|
"model not found",
|
|
}
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
IsRateLimitErrorMessage(messages[i%len(messages)])
|
|
}
|
|
}
|
|
|
|
// Mock Account implementation for testing UpdateProvider
|
|
type MockAccount struct {
|
|
mu sync.RWMutex
|
|
configs map[schemas.ModelProvider]*schemas.ProviderConfig
|
|
keys map[schemas.ModelProvider][]schemas.Key
|
|
}
|
|
|
|
func NewMockAccount() *MockAccount {
|
|
return &MockAccount{
|
|
configs: make(map[schemas.ModelProvider]*schemas.ProviderConfig),
|
|
keys: make(map[schemas.ModelProvider][]schemas.Key),
|
|
}
|
|
}
|
|
|
|
func (ma *MockAccount) AddProvider(provider schemas.ModelProvider, concurrency int, bufferSize int) {
|
|
ma.AddProviderWithBaseURL(provider, concurrency, bufferSize, "")
|
|
}
|
|
|
|
func (ma *MockAccount) AddProviderWithBaseURL(provider schemas.ModelProvider, concurrency int, bufferSize int, baseURL string) {
|
|
ma.mu.Lock()
|
|
defer ma.mu.Unlock()
|
|
ma.configs[provider] = &schemas.ProviderConfig{
|
|
NetworkConfig: schemas.NetworkConfig{
|
|
BaseURL: baseURL,
|
|
DefaultRequestTimeoutInSeconds: 30,
|
|
MaxRetries: 3,
|
|
RetryBackoffInitial: 500 * time.Millisecond,
|
|
RetryBackoffMax: 5 * time.Second,
|
|
},
|
|
ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{
|
|
Concurrency: concurrency,
|
|
BufferSize: bufferSize,
|
|
},
|
|
}
|
|
|
|
ma.keys[provider] = []schemas.Key{
|
|
{
|
|
ID: fmt.Sprintf("test-key-%s", provider),
|
|
Value: *schemas.NewEnvVar(fmt.Sprintf("sk-test-%s", provider)),
|
|
Weight: 100,
|
|
},
|
|
}
|
|
}
|
|
|
|
func (ma *MockAccount) UpdateProviderConfig(provider schemas.ModelProvider, concurrency int, bufferSize int) {
|
|
ma.mu.Lock()
|
|
defer ma.mu.Unlock()
|
|
if config, exists := ma.configs[provider]; exists {
|
|
config.ConcurrencyAndBufferSize.Concurrency = concurrency
|
|
config.ConcurrencyAndBufferSize.BufferSize = bufferSize
|
|
}
|
|
}
|
|
|
|
func (ma *MockAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) {
|
|
ma.mu.RLock()
|
|
defer ma.mu.RUnlock()
|
|
providers := make([]schemas.ModelProvider, 0, len(ma.configs))
|
|
for provider := range ma.configs {
|
|
providers = append(providers, provider)
|
|
}
|
|
return providers, nil
|
|
}
|
|
|
|
func (ma *MockAccount) GetConfigForProvider(provider schemas.ModelProvider) (*schemas.ProviderConfig, error) {
|
|
ma.mu.RLock()
|
|
defer ma.mu.RUnlock()
|
|
if config, exists := ma.configs[provider]; exists {
|
|
// Return a copy to simulate real behavior
|
|
configCopy := *config
|
|
return &configCopy, nil
|
|
}
|
|
return nil, fmt.Errorf("provider %s not configured", provider)
|
|
}
|
|
|
|
func (ma *MockAccount) GetKeysForProvider(ctx context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
|
|
ma.mu.RLock()
|
|
defer ma.mu.RUnlock()
|
|
if keys, exists := ma.keys[provider]; exists {
|
|
return keys, nil
|
|
}
|
|
return nil, fmt.Errorf("no keys for provider %s", provider)
|
|
}
|
|
|
|
func (ma *MockAccount) SetKeysForProvider(provider schemas.ModelProvider, keys []schemas.Key) {
|
|
ma.mu.Lock()
|
|
defer ma.mu.Unlock()
|
|
ma.keys[provider] = keys
|
|
}
|
|
|
|
// mockKVStore is an in-memory key-value store used by the session stickiness
// tests (it is passed as the KVStore of a Bifrost config). Each entry keeps the
// value together with the TTL it was written with; the TTL is only recorded,
// entries never expire.
type mockKVStore struct {
	mu   sync.RWMutex
	data map[string]struct {
		value any
		ttl   time.Duration
	}
}

// newMockKVStore returns an empty, ready-to-use store.
func newMockKVStore() *mockKVStore {
	store := new(mockKVStore)
	store.data = make(map[string]struct {
		value any
		ttl   time.Duration
	})
	return store
}

// Get returns the value stored under key, or a "key not found" error.
func (m *mockKVStore) Get(key string) (any, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	entry, found := m.data[key]
	if !found {
		return nil, fmt.Errorf("key not found")
	}
	return entry.value, nil
}

// SetWithTTL stores value under key, overwriting any existing entry. It always
// returns nil.
func (m *mockKVStore) SetWithTTL(key string, value any, ttl time.Duration) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Both fields are overwritten, so the previous entry (if any) is irrelevant.
	entry := m.data[key]
	entry.value, entry.ttl = value, ttl
	m.data[key] = entry
	return nil
}

// SetNXWithTTL stores value under key only when the key is absent. It reports
// whether the write happened; the error is always nil.
func (m *mockKVStore) SetNXWithTTL(key string, value any, ttl time.Duration) (bool, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	entry, taken := m.data[key]
	if taken {
		return false, nil
	}
	// entry is the zero value here because the key was absent.
	entry.value, entry.ttl = value, ttl
	m.data[key] = entry
	return true, nil
}

// Delete removes key and reports whether it was present. The error is always
// nil.
func (m *mockKVStore) Delete(key string) (bool, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	_, existed := m.data[key]
	delete(m.data, key) // no-op when the key is absent
	return existed, nil
}
|
|
|
|
// Test selectKeyFromProviderForModelWithPool with session stickiness
|
|
func TestSelectKeyFromProviderForModel_SessionStickiness(t *testing.T) {
|
|
kvStore := newMockKVStore()
|
|
account := NewMockAccount()
|
|
account.AddProvider(schemas.OpenAI, 5, 1000)
|
|
// Use 2 keys so we hit the keySelector path (single key returns early)
|
|
account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{
|
|
{ID: "key-a", Name: "Key A", Value: *schemas.NewEnvVar("sk-a"), Models: schemas.WhiteList{"*"}, Weight: 1},
|
|
{ID: "key-b", Name: "Key B", Value: *schemas.NewEnvVar("sk-b"), Models: schemas.WhiteList{"*"}, Weight: 1},
|
|
})
|
|
|
|
var keySelectorCalls int
|
|
deterministicSelector := func(ctx *schemas.BifrostContext, keys []schemas.Key, _ schemas.ModelProvider, _ string) (schemas.Key, error) {
|
|
keySelectorCalls++
|
|
return keys[0], nil // always return first key
|
|
}
|
|
|
|
ctx := context.Background()
|
|
bifrost, err := Init(ctx, schemas.BifrostConfig{
|
|
Account: account,
|
|
Logger: NewDefaultLogger(schemas.LogLevelError),
|
|
KVStore: kvStore,
|
|
KeySelector: deterministicSelector,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Init failed: %v", err)
|
|
}
|
|
|
|
bfCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
bfCtx.SetValue(schemas.BifrostContextKeySessionID, "sess-123")
|
|
|
|
// First call: cache miss, keySelector runs, key stored; returns single-element pool (canRotate=false)
|
|
keys1, canRotate1, err := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI)
|
|
if err != nil {
|
|
t.Fatalf("first selectKeyFromProviderForModelWithPool: %v", err)
|
|
}
|
|
if canRotate1 {
|
|
t.Error("first call: canRotate should be false for session-sticky request")
|
|
}
|
|
if len(keys1) != 1 || keys1[0].ID != "key-a" {
|
|
t.Errorf("first call: expected [key-a], got %v", keys1)
|
|
}
|
|
if keySelectorCalls != 1 {
|
|
t.Errorf("first call: expected 1 keySelector call, got %d", keySelectorCalls)
|
|
}
|
|
|
|
// Verify kvstore was written
|
|
kvKey := buildSessionKey(schemas.OpenAI, "sess-123", "gpt-4")
|
|
if raw, err := kvStore.Get(kvKey); err != nil || raw != "key-a" {
|
|
t.Errorf("kvstore after first call: expected key-a, got %v (err=%v)", raw, err)
|
|
}
|
|
|
|
// Second call: cache hit, same key returned, keySelector NOT called
|
|
keys2, canRotate2, err := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI)
|
|
if err != nil {
|
|
t.Fatalf("second selectKeyFromProviderForModelWithPool: %v", err)
|
|
}
|
|
if canRotate2 {
|
|
t.Error("second call: canRotate should be false for session-sticky request")
|
|
}
|
|
if len(keys2) != 1 || keys2[0].ID != "key-a" {
|
|
t.Errorf("second call: expected [key-a] (sticky), got %v", keys2)
|
|
}
|
|
if keySelectorCalls != 1 {
|
|
t.Errorf("second call: keySelector should not run (cache hit), got %d calls", keySelectorCalls)
|
|
}
|
|
}
|
|
|
|
// Test selectKeyFromProviderForModelWithPool - no stickiness when session ID absent
|
|
func TestSelectKeyFromProviderForModel_NoStickinessWithoutSessionID(t *testing.T) {
|
|
kvStore := newMockKVStore()
|
|
account := NewMockAccount()
|
|
account.AddProvider(schemas.OpenAI, 5, 1000)
|
|
account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{
|
|
{ID: "key-a", Name: "Key A", Value: *schemas.NewEnvVar("sk-a"), Models: schemas.WhiteList{"*"}, Weight: 1},
|
|
{ID: "key-b", Name: "Key B", Value: *schemas.NewEnvVar("sk-b"), Models: schemas.WhiteList{"*"}, Weight: 1},
|
|
})
|
|
|
|
var keySelectorCalls int
|
|
deterministicSelector := func(ctx *schemas.BifrostContext, keys []schemas.Key, _ schemas.ModelProvider, _ string) (schemas.Key, error) {
|
|
keySelectorCalls++
|
|
return keys[0], nil
|
|
}
|
|
|
|
ctx := context.Background()
|
|
bifrost, err := Init(ctx, schemas.BifrostConfig{
|
|
Account: account,
|
|
Logger: NewDefaultLogger(schemas.LogLevelError),
|
|
KVStore: kvStore,
|
|
KeySelector: deterministicSelector,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Init failed: %v", err)
|
|
}
|
|
|
|
bfCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
// No session ID set — pool is returned with canRotate=true; keySelector is called each time.
|
|
|
|
for i := 0; i < 2; i++ {
|
|
pool, canRotate, err := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI)
|
|
if err != nil {
|
|
t.Fatalf("selectKeyFromProviderForModelWithPool call %d: %v", i+1, err)
|
|
}
|
|
if !canRotate {
|
|
t.Fatalf("call %d: canRotate should be true without a session id", i+1)
|
|
}
|
|
if len(pool) == 0 {
|
|
t.Fatalf("call %d: expected non-empty pool", i+1)
|
|
}
|
|
}
|
|
if keySelectorCalls != 0 {
|
|
t.Errorf("expected 0 keySelector calls from pool building (no session id), got %d", keySelectorCalls)
|
|
}
|
|
// KVStore should not have a sticky entry for an empty session id
|
|
if _, err := kvStore.Get(buildSessionKey(schemas.OpenAI, "", "gpt-4")); err == nil {
|
|
t.Error("kvstore should not have a sticky entry for an empty session id")
|
|
}
|
|
}
|
|
|
|
// TestSelectKeyFromProviderForModel_SessionStickinessNoRotation verifies that when a session ID
|
|
// is present, rate-limit retries reuse the sticky key rather than rotating to another key.
|
|
func TestSelectKeyFromProviderForModel_SessionStickinessNoRotation(t *testing.T) {
|
|
kvStore := newMockKVStore()
|
|
account := NewMockAccount()
|
|
account.AddProvider(schemas.OpenAI, 5, 1000)
|
|
account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{
|
|
{ID: "key-a", Name: "Key A", Value: *schemas.NewEnvVar("sk-a"), Models: schemas.WhiteList{"*"}, Weight: 1},
|
|
{ID: "key-b", Name: "Key B", Value: *schemas.NewEnvVar("sk-b"), Models: schemas.WhiteList{"*"}, Weight: 1},
|
|
})
|
|
|
|
deterministicSelector := func(ctx *schemas.BifrostContext, keys []schemas.Key, _ schemas.ModelProvider, _ string) (schemas.Key, error) {
|
|
return keys[0], nil // always picks key-a when pool includes it
|
|
}
|
|
|
|
ctx := context.Background()
|
|
bifrost, err := Init(ctx, schemas.BifrostConfig{
|
|
Account: account,
|
|
Logger: NewDefaultLogger(schemas.LogLevelError),
|
|
KVStore: kvStore,
|
|
KeySelector: deterministicSelector,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Init failed: %v", err)
|
|
}
|
|
|
|
bfCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
bfCtx.SetValue(schemas.BifrostContextKeySessionID, "sess-sticky")
|
|
bfCtx.SetValue(schemas.BifrostContextKeyTracer, &schemas.NoOpTracer{})
|
|
|
|
config := createTestConfig(3, 0, 0)
|
|
logger := NewDefaultLogger(schemas.LogLevelError)
|
|
|
|
// Build keyProvider the same way requestWorker does.
|
|
pool, canRotate, poolErr := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI)
|
|
if poolErr != nil {
|
|
t.Fatalf("pool build failed: %v", poolErr)
|
|
}
|
|
if canRotate {
|
|
t.Fatal("expected canRotate=false for session-sticky request")
|
|
}
|
|
if len(pool) != 1 || pool[0].ID != "key-a" {
|
|
t.Fatalf("expected sticky pool=[key-a], got %v", pool)
|
|
}
|
|
|
|
fixedKey := pool[0]
|
|
keyProvider := func(_ map[string]bool) (schemas.Key, error) { return fixedKey, nil }
|
|
|
|
// Simulate 3 rate-limit failures then success; all attempts must use key-a.
|
|
var usedKeyIDs []string
|
|
callCount := 0
|
|
handler := func(k schemas.Key) (string, *schemas.BifrostError) {
|
|
usedKeyIDs = append(usedKeyIDs, k.ID)
|
|
callCount++
|
|
if callCount <= 3 {
|
|
return "", createBifrostError("rate limit exceeded", Ptr(429), nil, false)
|
|
}
|
|
return "ok", nil
|
|
}
|
|
|
|
result, retryErr := executeRequestWithRetries(bfCtx, config, handler, keyProvider,
|
|
schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", nil, logger)
|
|
|
|
if retryErr != nil {
|
|
t.Fatalf("expected success, got error: %v", retryErr)
|
|
}
|
|
if result != "ok" {
|
|
t.Errorf("expected 'ok', got %s", result)
|
|
}
|
|
for i, id := range usedKeyIDs {
|
|
if id != "key-a" {
|
|
t.Errorf("attempt %d: expected sticky key-a, got %s (full sequence: %v)", i, id, usedKeyIDs)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSelectKeyFromProviderForModel_BlacklistedModels(t *testing.T) {
|
|
account := NewMockAccount()
|
|
account.AddProvider(schemas.OpenAI, 5, 1000)
|
|
|
|
ctx := context.Background()
|
|
bifrost, err := Init(ctx, schemas.BifrostConfig{
|
|
Account: account,
|
|
Logger: NewDefaultLogger(schemas.LogLevelError),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Init failed: %v", err)
|
|
}
|
|
bfCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
|
|
t.Run("all keys blacklist model", func(t *testing.T) {
|
|
account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{
|
|
{ID: "k1", Name: "K1", Value: *schemas.NewEnvVar("sk-1"), Weight: 1, BlacklistedModels: []string{"gpt-4"}},
|
|
})
|
|
_, _, err := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI)
|
|
if err == nil {
|
|
t.Fatal("expected error when model is only blacklisted")
|
|
}
|
|
if !strings.Contains(err.Error(), "no keys found that support model") {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("blacklist wins over models allow list", func(t *testing.T) {
|
|
account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{
|
|
{
|
|
ID: "k1", Name: "K1", Value: *schemas.NewEnvVar("sk-1"), Weight: 1,
|
|
Models: []string{"gpt-4"},
|
|
BlacklistedModels: []string{"gpt-4"},
|
|
},
|
|
})
|
|
_, _, err := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI)
|
|
if err == nil {
|
|
t.Fatal("expected error when model is both allowed and blacklisted")
|
|
}
|
|
})
|
|
|
|
t.Run("second key used when first blacklists", func(t *testing.T) {
|
|
account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{
|
|
{ID: "k1", Name: "K1", Value: *schemas.NewEnvVar("sk-1"), Weight: 1, BlacklistedModels: []string{"gpt-4"}},
|
|
{ID: "k2", Name: "K2", Value: *schemas.NewEnvVar("sk-2"), Weight: 1, Models: []string{"*"}},
|
|
})
|
|
pool, canRotate, err := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
// After filtering, only k2 remains — single key returns canRotate=false.
|
|
if canRotate {
|
|
t.Fatal("expected canRotate=false for single-key pool after filtering")
|
|
}
|
|
if len(pool) != 1 || pool[0].ID != "k2" {
|
|
t.Fatalf("expected pool=[k2], got %v", pool)
|
|
}
|
|
})
|
|
}
|
|
|
|
// Test key rotation in executeRequestWithRetries on rate-limit errors
|
|
func TestExecuteRequestWithRetries_KeyRotation(t *testing.T) {
|
|
config := createTestConfig(3, 0, 0)
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
ctx.SetValue(schemas.BifrostContextKeyTracer, &schemas.NoOpTracer{})
|
|
logger := NewDefaultLogger(schemas.LogLevelError)
|
|
|
|
keys := []schemas.Key{
|
|
{ID: "k1", Name: "K1"},
|
|
{ID: "k2", Name: "K2"},
|
|
{ID: "k3", Name: "K3"},
|
|
}
|
|
|
|
t.Run("RotatesKeyOnRateLimitRetry", func(t *testing.T) {
|
|
var selectedKeyIDs []string
|
|
keyProvider := func(usedKeyIDs map[string]bool) (schemas.Key, error) {
|
|
for _, k := range keys {
|
|
if !usedKeyIDs[k.ID] {
|
|
return k, nil
|
|
}
|
|
}
|
|
// Fresh round
|
|
for id := range usedKeyIDs {
|
|
delete(usedKeyIDs, id)
|
|
}
|
|
return keys[0], nil
|
|
}
|
|
|
|
handler := func(k schemas.Key) (string, *schemas.BifrostError) {
|
|
selectedKeyIDs = append(selectedKeyIDs, k.ID)
|
|
// First two calls rate-limit, third succeeds
|
|
if len(selectedKeyIDs) <= 2 {
|
|
return "", createBifrostError("rate limit exceeded", Ptr(429), nil, false)
|
|
}
|
|
return "success", nil
|
|
}
|
|
|
|
result, err := executeRequestWithRetries(ctx, config, handler, keyProvider,
|
|
schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", nil, logger)
|
|
|
|
if err != nil {
|
|
t.Fatalf("expected success, got error: %v", err)
|
|
}
|
|
if result != "success" {
|
|
t.Errorf("expected 'success', got %s", result)
|
|
}
|
|
if len(selectedKeyIDs) != 3 {
|
|
t.Fatalf("expected 3 attempts, got %d", len(selectedKeyIDs))
|
|
}
|
|
// Each attempt should use a different key
|
|
seen := map[string]struct{}{}
|
|
for _, id := range selectedKeyIDs {
|
|
seen[id] = struct{}{}
|
|
}
|
|
if len(seen) != len(selectedKeyIDs) {
|
|
t.Errorf("expected distinct keys per rate-limit retry, got %v", selectedKeyIDs)
|
|
}
|
|
})
|
|
|
|
t.Run("SameKeyOnNetworkError", func(t *testing.T) {
|
|
var selectedKeyIDs []string
|
|
keyProviderCalls := 0
|
|
keyProvider := func(usedKeyIDs map[string]bool) (schemas.Key, error) {
|
|
keyProviderCalls++
|
|
for _, k := range keys {
|
|
if !usedKeyIDs[k.ID] {
|
|
return k, nil
|
|
}
|
|
}
|
|
for id := range usedKeyIDs {
|
|
delete(usedKeyIDs, id)
|
|
}
|
|
return keys[0], nil
|
|
}
|
|
|
|
callCount := 0
|
|
handler := func(k schemas.Key) (string, *schemas.BifrostError) {
|
|
selectedKeyIDs = append(selectedKeyIDs, k.ID)
|
|
callCount++
|
|
if callCount <= 2 {
|
|
return "", createBifrostError(schemas.ErrProviderDoRequest, nil, nil, false)
|
|
}
|
|
return "success", nil
|
|
}
|
|
|
|
result, err := executeRequestWithRetries(ctx, config, handler, keyProvider,
|
|
schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", nil, logger)
|
|
|
|
if err != nil {
|
|
t.Fatalf("expected success, got error: %v", err)
|
|
}
|
|
if result != "success" {
|
|
t.Errorf("expected 'success', got %s", result)
|
|
}
|
|
if len(selectedKeyIDs) != 3 {
|
|
t.Fatalf("expected 3 attempts, got %d", len(selectedKeyIDs))
|
|
}
|
|
if keyProviderCalls != 1 {
|
|
t.Fatalf("expected keyProvider to be called once for network retries, got %d", keyProviderCalls)
|
|
}
|
|
// All attempts should use the same key (network error = same key)
|
|
for i := 1; i < len(selectedKeyIDs); i++ {
|
|
if selectedKeyIDs[i] != selectedKeyIDs[0] {
|
|
t.Errorf("expected same key for all network-error retries, got %v", selectedKeyIDs)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("CyclesFreshRoundWhenPoolExhausted", func(t *testing.T) {
|
|
var selectedKeyIDs []string
|
|
// 3 keys, 6 retries — should cycle through all 3 keys twice
|
|
config6 := createTestConfig(5, 0, 0) // 5 retries = 6 total attempts
|
|
keyProvider := func(usedKeyIDs map[string]bool) (schemas.Key, error) {
|
|
available := make([]schemas.Key, 0)
|
|
for _, k := range keys {
|
|
if !usedKeyIDs[k.ID] {
|
|
available = append(available, k)
|
|
}
|
|
}
|
|
if len(available) == 0 {
|
|
for id := range usedKeyIDs {
|
|
delete(usedKeyIDs, id)
|
|
}
|
|
available = keys
|
|
}
|
|
return available[0], nil
|
|
}
|
|
|
|
handler := func(k schemas.Key) (string, *schemas.BifrostError) {
|
|
selectedKeyIDs = append(selectedKeyIDs, k.ID)
|
|
return "", createBifrostError("rate limit exceeded", Ptr(429), nil, false)
|
|
}
|
|
|
|
executeRequestWithRetries(ctx, config6, handler, keyProvider,
|
|
schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", nil, logger)
|
|
|
|
if len(selectedKeyIDs) != 6 {
|
|
t.Fatalf("expected 6 attempts (1 initial + 5 retries), got %d", len(selectedKeyIDs))
|
|
}
|
|
// First cycle: k1, k2, k3; second cycle: k1, k2, k3
|
|
expected := []string{"k1", "k2", "k3", "k1", "k2", "k3"}
|
|
for i, id := range selectedKeyIDs {
|
|
if id != expected[i] {
|
|
t.Errorf("attempt %d: expected key %s, got %s (full sequence: %v)", i, expected[i], id, selectedKeyIDs)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("NilKeyProviderUsesZeroKey", func(t *testing.T) {
|
|
cleanCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
cleanCtx.SetValue(schemas.BifrostContextKeyTracer, &schemas.NoOpTracer{})
|
|
|
|
var receivedKey schemas.Key
|
|
handler := func(k schemas.Key) (string, *schemas.BifrostError) {
|
|
receivedKey = k
|
|
return "ok", nil
|
|
}
|
|
|
|
result, err := executeRequestWithRetries(cleanCtx, config, handler, nil,
|
|
schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", nil, logger)
|
|
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if result != "ok" {
|
|
t.Errorf("expected 'ok', got %s", result)
|
|
}
|
|
if receivedKey.ID != "" {
|
|
t.Errorf("expected zero Key when keyProvider is nil, got ID=%s", receivedKey.ID)
|
|
}
|
|
if trail, ok := cleanCtx.Value(schemas.BifrostContextKeyAttemptTrail).([]schemas.KeyAttemptRecord); ok && len(trail) > 0 {
|
|
t.Fatalf("expected no attempt trail for nil keyProvider, got %v", trail)
|
|
}
|
|
if selectedID, _ := cleanCtx.Value(schemas.BifrostContextKeySelectedKeyID).(string); selectedID != "" {
|
|
t.Fatalf("expected empty selected key id, got %q", selectedID)
|
|
}
|
|
if selectedName, _ := cleanCtx.Value(schemas.BifrostContextKeySelectedKeyName).(string); selectedName != "" {
|
|
t.Fatalf("expected empty selected key name, got %q", selectedName)
|
|
}
|
|
})
|
|
}
|
|
|
|
// Test UpdateProvider functionality
|
|
func TestUpdateProvider(t *testing.T) {
|
|
t.Run("SuccessfulUpdate", func(t *testing.T) {
|
|
// Setup mock account with initial configuration
|
|
account := NewMockAccount()
|
|
account.AddProvider(schemas.OpenAI, 5, 1000)
|
|
|
|
// Initialize Bifrost
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
bifrost, err := Init(ctx, schemas.BifrostConfig{
|
|
Account: account,
|
|
Logger: NewDefaultLogger(schemas.LogLevelError), // Keep tests quiet
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to initialize Bifrost: %v", err)
|
|
}
|
|
|
|
// Verify initial provider exists
|
|
initialProvider := bifrost.getProviderByKey(schemas.OpenAI)
|
|
if initialProvider == nil {
|
|
t.Fatalf("Initial provider not found")
|
|
}
|
|
|
|
// Update configuration
|
|
account.UpdateProviderConfig(schemas.OpenAI, 10, 2000)
|
|
|
|
// Perform update
|
|
err = bifrost.UpdateProvider(schemas.OpenAI)
|
|
if err != nil {
|
|
t.Fatalf("UpdateProvider failed: %v", err)
|
|
}
|
|
|
|
// Verify provider was replaced
|
|
updatedProvider := bifrost.getProviderByKey(schemas.OpenAI)
|
|
if updatedProvider == nil {
|
|
t.Fatalf("Updated provider not found")
|
|
}
|
|
|
|
// Verify it's a different instance (provider should have been recreated)
|
|
if initialProvider == updatedProvider {
|
|
t.Errorf("Provider instance was not replaced - same memory address")
|
|
}
|
|
|
|
// Verify provider key is still correct
|
|
if updatedProvider.GetProviderKey() != schemas.OpenAI {
|
|
t.Errorf("Updated provider has wrong key: got %s, want %s",
|
|
updatedProvider.GetProviderKey(), schemas.OpenAI)
|
|
}
|
|
})
|
|
|
|
t.Run("UpdateNonExistentProvider", func(t *testing.T) {
|
|
// Setup account without the provider we'll try to update
|
|
account := NewMockAccount()
|
|
account.AddProvider(schemas.OpenAI, 5, 1000)
|
|
|
|
ctx := context.Background()
|
|
bifrost, err := Init(ctx, schemas.BifrostConfig{
|
|
Account: account,
|
|
Logger: NewDefaultLogger(schemas.LogLevelError),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to initialize Bifrost: %v", err)
|
|
}
|
|
|
|
// Try to update a provider not in the account
|
|
err = bifrost.UpdateProvider(schemas.Anthropic)
|
|
if err == nil {
|
|
t.Errorf("Expected error when updating non-existent provider, got nil")
|
|
}
|
|
|
|
// Verify error message
|
|
expectedErrMsg := "failed to get updated config for provider anthropic"
|
|
if err != nil && !strings.Contains(err.Error(), expectedErrMsg) {
|
|
t.Errorf("Expected error containing '%s', got: %v", expectedErrMsg, err)
|
|
}
|
|
})
|
|
|
|
t.Run("UpdateInactiveProvider", func(t *testing.T) {
|
|
// Setup account with provider but don't initialize it in Bifrost
|
|
account := NewMockAccount()
|
|
|
|
ctx := context.Background()
|
|
bifrost, err := Init(ctx, schemas.BifrostConfig{
|
|
Account: account,
|
|
Logger: NewDefaultLogger(schemas.LogLevelError),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to initialize Bifrost: %v", err)
|
|
}
|
|
|
|
// Verify provider doesn't exist initially
|
|
// Note: Use Ollama (not in dynamicallyConfigurableProviders) to test truly inactive provider
|
|
if bifrost.getProviderByKey(schemas.Ollama) != nil {
|
|
t.Fatal("Provider should not exist initially")
|
|
}
|
|
|
|
// Add provider to account after bifrost initialization
|
|
// Note: Ollama requires a BaseURL
|
|
account.AddProviderWithBaseURL(schemas.Ollama, 3, 500, "http://localhost:11434")
|
|
|
|
// Update should succeed and initialize the provider
|
|
err = bifrost.UpdateProvider(schemas.Ollama)
|
|
if err != nil {
|
|
t.Fatalf("UpdateProvider should succeed for inactive provider: %v", err)
|
|
}
|
|
|
|
// Verify provider now exists
|
|
provider := bifrost.getProviderByKey(schemas.Ollama)
|
|
if provider == nil {
|
|
t.Fatal("Provider should exist after update")
|
|
}
|
|
|
|
if provider.GetProviderKey() != schemas.Ollama {
|
|
t.Errorf("Provider has wrong key: got %s, want %s",
|
|
provider.GetProviderKey(), schemas.Ollama)
|
|
}
|
|
})
|
|
|
|
t.Run("MultipleProviderUpdates", func(t *testing.T) {
|
|
// Test updating multiple different providers
|
|
account := NewMockAccount()
|
|
account.AddProvider(schemas.OpenAI, 5, 1000)
|
|
account.AddProvider(schemas.Anthropic, 3, 500)
|
|
account.AddProvider(schemas.Cohere, 2, 200)
|
|
|
|
ctx := context.Background()
|
|
bifrost, err := Init(ctx, schemas.BifrostConfig{
|
|
Account: account,
|
|
Logger: NewDefaultLogger(schemas.LogLevelError),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to initialize Bifrost: %v", err)
|
|
}
|
|
|
|
// Get initial provider references
|
|
initialOpenAI := bifrost.getProviderByKey(schemas.OpenAI)
|
|
initialAnthropic := bifrost.getProviderByKey(schemas.Anthropic)
|
|
initialCohere := bifrost.getProviderByKey(schemas.Cohere)
|
|
|
|
// Update configurations
|
|
account.UpdateProviderConfig(schemas.OpenAI, 10, 2000)
|
|
account.UpdateProviderConfig(schemas.Anthropic, 6, 1000)
|
|
account.UpdateProviderConfig(schemas.Cohere, 4, 400)
|
|
|
|
// Update all providers
|
|
providers := []schemas.ModelProvider{schemas.OpenAI, schemas.Anthropic, schemas.Cohere}
|
|
for _, provider := range providers {
|
|
err = bifrost.UpdateProvider(provider)
|
|
if err != nil {
|
|
t.Fatalf("Failed to update provider %s: %v", provider, err)
|
|
}
|
|
}
|
|
|
|
// Verify all providers were replaced
|
|
newOpenAI := bifrost.getProviderByKey(schemas.OpenAI)
|
|
newAnthropic := bifrost.getProviderByKey(schemas.Anthropic)
|
|
newCohere := bifrost.getProviderByKey(schemas.Cohere)
|
|
|
|
if initialOpenAI == newOpenAI {
|
|
t.Error("OpenAI provider was not replaced")
|
|
}
|
|
if initialAnthropic == newAnthropic {
|
|
t.Error("Anthropic provider was not replaced")
|
|
}
|
|
if initialCohere == newCohere {
|
|
t.Error("Cohere provider was not replaced")
|
|
}
|
|
|
|
// Verify all providers still have correct keys
|
|
if newOpenAI.GetProviderKey() != schemas.OpenAI {
|
|
t.Error("OpenAI provider has wrong key after update")
|
|
}
|
|
if newAnthropic.GetProviderKey() != schemas.Anthropic {
|
|
t.Error("Anthropic provider has wrong key after update")
|
|
}
|
|
if newCohere.GetProviderKey() != schemas.Cohere {
|
|
t.Error("Cohere provider has wrong key after update")
|
|
}
|
|
})
|
|
|
|
t.Run("ConcurrentProviderUpdates", func(t *testing.T) {
|
|
// Test updating the same provider concurrently (should be serialized by mutex)
|
|
account := NewMockAccount()
|
|
account.AddProvider(schemas.OpenAI, 5, 1000)
|
|
|
|
ctx := context.Background()
|
|
bifrost, err := Init(ctx, schemas.BifrostConfig{
|
|
Account: account,
|
|
Logger: NewDefaultLogger(schemas.LogLevelError),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to initialize Bifrost: %v", err)
|
|
}
|
|
|
|
// Launch concurrent updates
|
|
const numConcurrentUpdates = 5
|
|
errChan := make(chan error, numConcurrentUpdates)
|
|
|
|
for i := 0; i < numConcurrentUpdates; i++ {
|
|
go func(updateNum int) {
|
|
// Update with slightly different config each time
|
|
account.UpdateProviderConfig(schemas.OpenAI, 5+updateNum, 1000+updateNum*100)
|
|
err := bifrost.UpdateProvider(schemas.OpenAI)
|
|
errChan <- err
|
|
}(i)
|
|
}
|
|
|
|
// Collect results
|
|
var errors []error
|
|
for i := 0; i < numConcurrentUpdates; i++ {
|
|
if err := <-errChan; err != nil {
|
|
errors = append(errors, err)
|
|
}
|
|
}
|
|
|
|
// All updates should succeed (mutex should serialize them)
|
|
if len(errors) > 0 {
|
|
t.Fatalf("Expected no errors from concurrent updates, got: %v", errors)
|
|
}
|
|
|
|
// Verify provider still exists and has correct key
|
|
provider := bifrost.getProviderByKey(schemas.OpenAI)
|
|
if provider == nil {
|
|
t.Fatal("Provider should exist after concurrent updates")
|
|
}
|
|
if provider.GetProviderKey() != schemas.OpenAI {
|
|
t.Error("Provider has wrong key after concurrent updates")
|
|
}
|
|
})
|
|
}
|
|
|
|
// Test provider slice management during updates
|
|
func TestUpdateProvider_ProviderSliceIntegrity(t *testing.T) {
|
|
t.Run("ProviderSliceConsistency", func(t *testing.T) {
|
|
account := NewMockAccount()
|
|
account.AddProvider(schemas.OpenAI, 5, 1000)
|
|
account.AddProvider(schemas.Anthropic, 3, 500)
|
|
|
|
ctx := context.Background()
|
|
bifrost, err := Init(ctx, schemas.BifrostConfig{
|
|
Account: account,
|
|
Logger: NewDefaultLogger(schemas.LogLevelError),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to initialize Bifrost: %v", err)
|
|
}
|
|
|
|
// Get initial provider count
|
|
initialProviders := bifrost.providers.Load()
|
|
initialCount := len(*initialProviders)
|
|
|
|
// Update one provider
|
|
account.UpdateProviderConfig(schemas.OpenAI, 10, 2000)
|
|
err = bifrost.UpdateProvider(schemas.OpenAI)
|
|
if err != nil {
|
|
t.Fatalf("UpdateProvider failed: %v", err)
|
|
}
|
|
|
|
// Verify provider count is the same (replacement, not addition)
|
|
updatedProviders := bifrost.providers.Load()
|
|
updatedCount := len(*updatedProviders)
|
|
|
|
if initialCount != updatedCount {
|
|
t.Errorf("Provider count changed: initial=%d, updated=%d", initialCount, updatedCount)
|
|
}
|
|
|
|
// Verify both providers still exist with correct keys
|
|
foundOpenAI := false
|
|
foundAnthropic := false
|
|
|
|
for _, provider := range *updatedProviders {
|
|
switch provider.GetProviderKey() {
|
|
case schemas.OpenAI:
|
|
foundOpenAI = true
|
|
case schemas.Anthropic:
|
|
foundAnthropic = true
|
|
}
|
|
}
|
|
|
|
if !foundOpenAI {
|
|
t.Error("OpenAI provider not found in providers slice after update")
|
|
}
|
|
if !foundAnthropic {
|
|
t.Error("Anthropic provider not found in providers slice after update")
|
|
}
|
|
})
|
|
|
|
t.Run("ProviderSliceNoMemoryLeaks", func(t *testing.T) {
|
|
account := NewMockAccount()
|
|
account.AddProvider(schemas.OpenAI, 5, 1000)
|
|
|
|
ctx := context.Background()
|
|
bifrost, err := Init(ctx, schemas.BifrostConfig{
|
|
Account: account,
|
|
Logger: NewDefaultLogger(schemas.LogLevelError),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to initialize Bifrost: %v", err)
|
|
}
|
|
|
|
// Perform multiple updates to ensure no memory leaks in provider slice
|
|
for i := 0; i < 10; i++ {
|
|
account.UpdateProviderConfig(schemas.OpenAI, 5+i, 1000+i*100)
|
|
err = bifrost.UpdateProvider(schemas.OpenAI)
|
|
if err != nil {
|
|
t.Fatalf("UpdateProvider failed on iteration %d: %v", i, err)
|
|
}
|
|
|
|
// Verify only one OpenAI provider exists
|
|
providers := bifrost.providers.Load()
|
|
openAICount := 0
|
|
for _, provider := range *providers {
|
|
if provider.GetProviderKey() == schemas.OpenAI {
|
|
openAICount++
|
|
}
|
|
}
|
|
|
|
if openAICount != 1 {
|
|
t.Fatalf("Expected exactly 1 OpenAI provider, found %d on iteration %d", openAICount, i)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestProviderQueue_SendOnClosedChannel_Race demonstrates the TOCTOU race that
|
|
// caused the "send on closed channel" production panic in the OLD code.
|
|
//
|
|
// The old code called close(pq.queue) during provider shutdown. The sequence:
|
|
// 1. Producer calls isClosing() → false (queue is still open)
|
|
// 2. Concurrently: shutdown calls signalClosing() then close(pq.queue)
|
|
// 3. Producer enters select { case pq.queue <- msg: ... case <-pq.done: ... }
|
|
// → PANIC: Go's selectgo iterates cases in a randomised pollorder. When the
|
|
// closed-channel send case is checked first, it immediately panics via
|
|
// goto sclose — before it can reach the done case.
|
|
// The case <-pq.done: guard only saves you when done happens to be checked
|
|
// first in that random ordering (≈50 % of the time with two cases).
|
|
//
|
|
// THE FIX: pq.queue is never closed. See the ProviderQueue struct comment for
|
|
// the full explanation. This test is kept as a proof-of-concept showing why
|
|
// closing pq.queue is unsafe; the fix is validated by TestProviderQueue_NoPanicWithoutCloseQueue.
|
|
//
|
|
// We run many iterations so that the panic is statistically certain to surface
|
|
// at least once, confirming the hypothesis.
|
|
func TestProviderQueue_SendOnClosedChannel_Race(t *testing.T) {
|
|
// With two select cases each iteration has a ~50 % chance of panicking.
|
|
// The probability of never panicking in 200 iterations is (0.5)^200 ≈ 0.
|
|
const iterations = 200
|
|
panicCount := 0
|
|
|
|
for i := 0; i < iterations; i++ {
|
|
func() {
|
|
pq := &ProviderQueue{
|
|
queue: make(chan *ChannelMessage, 10),
|
|
done: make(chan struct{}),
|
|
signalOnce: sync.Once{},
|
|
}
|
|
|
|
// Synchronization barriers to force the exact race interleaving.
|
|
passedIsClosingCheck := make(chan struct{})
|
|
queueClosed := make(chan struct{})
|
|
|
|
var panicked bool
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
|
|
// Producer — mirrors the hot path in tryRequest.
|
|
go func() {
|
|
defer wg.Done()
|
|
defer func() {
|
|
if r := recover(); r != nil && fmt.Sprint(r) == "send on closed channel" {
|
|
panicked = true
|
|
}
|
|
}()
|
|
|
|
// Step 1: isClosing() passes — queue is open.
|
|
if pq.isClosing() {
|
|
return
|
|
}
|
|
|
|
// Signal: past the isClosing() gate.
|
|
close(passedIsClosingCheck)
|
|
|
|
// Wait for the queue to be closed. This represents the real work
|
|
// tryRequest does between the isClosing() check and the select
|
|
// (MCP setup, tracer lookup, plugin pipeline acquisition).
|
|
<-queueClosed
|
|
|
|
// Step 2: enter the exact select guard used in production.
|
|
// pq.queue is closed AND pq.done is closed.
|
|
// When selectgo picks the send case first in its random pollorder
|
|
// it hits goto sclose and panics — the done case cannot save it.
|
|
msg := &ChannelMessage{}
|
|
select {
|
|
case pq.queue <- msg: // panics ~50 % of iterations
|
|
case <-pq.done: // selected the other ~50 %
|
|
}
|
|
}()
|
|
|
|
// Closer — mirrors UpdateProvider / RemoveProvider.
|
|
go func() {
|
|
<-passedIsClosingCheck
|
|
pq.signalClosing() // closes done, sets closing = 1
|
|
close(pq.queue)
|
|
close(queueClosed) // release the producer into the select
|
|
}()
|
|
|
|
wg.Wait()
|
|
if panicked {
|
|
panicCount++
|
|
}
|
|
}()
|
|
}
|
|
|
|
if panicCount == 0 {
|
|
t.Fatalf("expected at least one 'send on closed channel' panic across %d iterations, got none", iterations)
|
|
}
|
|
t.Logf("confirmed: panic triggered in %d / %d iterations — hypothesis is correct", panicCount, iterations)
|
|
}
|
|
|
|
// =============================================================================
|
|
// ProviderQueue Unit Tests
|
|
//
|
|
// These tests exercise the ProviderQueue lifecycle in isolation — no full
|
|
// Bifrost instance required. They validate the core safety invariants that
|
|
// prevent the "send on closed channel" panic.
|
|
// =============================================================================
|
|
|
|
// newTestChannelMessage creates a minimal ChannelMessage suitable for drain tests.
|
|
// The Err channel is buffered (size 1) so the worker can send without blocking.
|
|
func newTestChannelMessage(ctx *schemas.BifrostContext) *ChannelMessage {
|
|
return &ChannelMessage{
|
|
BifrostRequest: schemas.BifrostRequest{
|
|
RequestType: schemas.ChatCompletionRequest,
|
|
ChatRequest: &schemas.BifrostChatRequest{
|
|
Provider: schemas.OpenAI,
|
|
Model: "gpt-4",
|
|
},
|
|
},
|
|
Context: ctx,
|
|
Response: make(chan *schemas.BifrostResponse, 1),
|
|
Err: make(chan schemas.BifrostError, 1),
|
|
}
|
|
}
|
|
|
|
// TestProviderQueue_IsClosingStateTransition verifies the atomic state flag:
|
|
// isClosing() must return false before signalClosing() and true after.
|
|
func TestProviderQueue_IsClosingStateTransition(t *testing.T) {
|
|
pq := &ProviderQueue{
|
|
queue: make(chan *ChannelMessage, 10),
|
|
done: make(chan struct{}),
|
|
signalOnce: sync.Once{},
|
|
}
|
|
|
|
if pq.isClosing() {
|
|
t.Fatal("isClosing() must be false before signalClosing() is called")
|
|
}
|
|
|
|
pq.signalClosing()
|
|
|
|
if !pq.isClosing() {
|
|
t.Fatal("isClosing() must be true after signalClosing() is called")
|
|
}
|
|
|
|
// done channel must also be closed
|
|
select {
|
|
case <-pq.done:
|
|
// correct: done is closed
|
|
default:
|
|
t.Fatal("pq.done must be closed after signalClosing()")
|
|
}
|
|
|
|
// queue channel must remain OPEN — this is the core of the fix
|
|
// (sending should not panic even though done is closed)
|
|
panicked := false
|
|
func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
panicked = true
|
|
}
|
|
}()
|
|
select {
|
|
case pq.queue <- &ChannelMessage{}:
|
|
case <-pq.done: // done is closed so this is always ready — no panic
|
|
}
|
|
}()
|
|
if panicked {
|
|
t.Fatal("queue channel must stay open after signalClosing() — sending to it must not panic")
|
|
}
|
|
}
|
|
|
|
// TestProviderQueue_SignalOnceIdempotent verifies that calling signalClosing()
|
|
// multiple times is safe. sync.Once ensures done is only closed once and the
|
|
// atomic store only happens once — no "close of closed channel" panic.
|
|
func TestProviderQueue_SignalOnceIdempotent(t *testing.T) {
|
|
pq := &ProviderQueue{
|
|
queue: make(chan *ChannelMessage, 10),
|
|
done: make(chan struct{}),
|
|
signalOnce: sync.Once{},
|
|
}
|
|
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
t.Fatalf("unexpected panic from multiple signalClosing() calls: %v", r)
|
|
}
|
|
}()
|
|
|
|
pq.signalClosing()
|
|
pq.signalClosing()
|
|
pq.signalClosing()
|
|
|
|
if !pq.isClosing() {
|
|
t.Fatal("isClosing() must be true after multiple signalClosing() calls")
|
|
}
|
|
}
|
|
|
|
// TestProviderQueue_WorkerExitsViaDone verifies that a worker running the
|
|
// fixed select loop exits cleanly after signalClosing() without closeQueue().
|
|
// Before the fix, workers used `for req := range pq.queue` which required
|
|
// the channel to be closed. After the fix, done is the exit signal.
|
|
func TestProviderQueue_WorkerExitsViaDone(t *testing.T) {
|
|
pq := &ProviderQueue{
|
|
queue: make(chan *ChannelMessage, 10),
|
|
done: make(chan struct{}),
|
|
signalOnce: sync.Once{},
|
|
}
|
|
|
|
workerExited := make(chan struct{})
|
|
|
|
// Minimal worker loop — mirrors the exact select pattern in requestWorker
|
|
go func() {
|
|
defer close(workerExited)
|
|
for {
|
|
select {
|
|
case r, ok := <-pq.queue:
|
|
if !ok {
|
|
return
|
|
}
|
|
_ = r // process (no-op in this test)
|
|
case <-pq.done:
|
|
// Drain remaining buffered items (queue is empty here)
|
|
for {
|
|
select {
|
|
case <-pq.queue:
|
|
default:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Worker is now blocked on the select. Signal shutdown WITHOUT closing queue.
|
|
pq.signalClosing()
|
|
|
|
select {
|
|
case <-workerExited:
|
|
// correct: worker exited via done
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("worker did not exit after signalClosing() — it may be stuck on range over unclosed channel")
|
|
}
|
|
}
|
|
|
|
// TestProviderQueue_WorkerDrainSendsErrors verifies the drain behaviour when
|
|
// done fires while items are still buffered: every buffered ChannelMessage must
|
|
// receive a "provider is shutting down" error on its Err channel. No client
|
|
// should be left blocked waiting for a response that will never come.
|
|
//
|
|
// This test exercises the drain path directly — same code as requestWorker's
|
|
// case <-pq.done: branch — to avoid a non-deterministic select race between the
|
|
// normal processing path and the done path.
|
|
func TestProviderQueue_WorkerDrainSendsErrors(t *testing.T) {
|
|
const numBuffered = 5
|
|
|
|
pq := &ProviderQueue{
|
|
queue: make(chan *ChannelMessage, numBuffered+2),
|
|
done: make(chan struct{}),
|
|
signalOnce: sync.Once{},
|
|
}
|
|
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
|
|
// Pre-fill queue — simulates requests buffered when done fires
|
|
msgs := make([]*ChannelMessage, numBuffered)
|
|
for i := 0; i < numBuffered; i++ {
|
|
msgs[i] = newTestChannelMessage(ctx)
|
|
pq.queue <- msgs[i]
|
|
}
|
|
|
|
// Signal closing: done is now closed
|
|
pq.signalClosing()
|
|
|
|
// Execute the drain path synchronously — exactly what requestWorker does in
|
|
// the case <-pq.done: branch. This is deterministic: we know done is closed
|
|
// and the queue has numBuffered items.
|
|
<-pq.done // fires immediately since signalClosing was already called
|
|
drainLoop:
|
|
for {
|
|
select {
|
|
case r := <-pq.queue:
|
|
provKey, mod, _ := r.GetRequestFields()
|
|
r.Err <- schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{
|
|
Message: "provider is shutting down",
|
|
},
|
|
ExtraFields: schemas.BifrostErrorExtraFields{
|
|
RequestType: r.RequestType,
|
|
Provider: provKey,
|
|
OriginalModelRequested: mod,
|
|
},
|
|
}
|
|
default:
|
|
break drainLoop
|
|
}
|
|
}
|
|
|
|
// Verify every message received a shutdown error
|
|
for i, msg := range msgs {
|
|
select {
|
|
case bifrostErr := <-msg.Err:
|
|
if bifrostErr.Error == nil {
|
|
t.Errorf("message %d: received nil Error field", i)
|
|
continue
|
|
}
|
|
if bifrostErr.Error.Message != "provider is shutting down" {
|
|
t.Errorf("message %d: expected 'provider is shutting down', got %q",
|
|
i, bifrostErr.Error.Message)
|
|
}
|
|
if bifrostErr.ExtraFields.Provider != schemas.OpenAI {
|
|
t.Errorf("message %d: expected provider %s, got %s",
|
|
i, schemas.OpenAI, bifrostErr.ExtraFields.Provider)
|
|
}
|
|
if bifrostErr.ExtraFields.RequestType != schemas.ChatCompletionRequest {
|
|
t.Errorf("message %d: expected requestType %v, got %v",
|
|
i, schemas.ChatCompletionRequest, bifrostErr.ExtraFields.RequestType)
|
|
}
|
|
default:
|
|
t.Errorf("message %d: no error received — client would be left hanging indefinitely", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestProviderQueue_NoPanicWithoutCloseQueue verifies that the fixed hot path
|
|
// — select { case pq.queue <- msg | case <-pq.done } — never panics when
|
|
// signalClosing() fires but the queue channel is NOT closed.
|
|
//
|
|
// This is the direct inverse of TestProviderQueue_SendOnClosedChannel_Race:
|
|
// that test proves the old code panics ~50% of the time; this test proves
|
|
// the fixed code panics 0% of the time.
|
|
func TestProviderQueue_NoPanicWithoutCloseQueue(t *testing.T) {
|
|
const iterations = 500
|
|
|
|
for i := 0; i < iterations; i++ {
|
|
func() {
|
|
pq := &ProviderQueue{
|
|
queue: make(chan *ChannelMessage, 10),
|
|
done: make(chan struct{}),
|
|
signalOnce: sync.Once{},
|
|
}
|
|
|
|
passedIsClosingCheck := make(chan struct{})
|
|
shutdownDone := make(chan struct{})
|
|
|
|
var panicked bool
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
|
|
// Producer: mirrors the tryRequest hot path after the fix.
|
|
// Passes isClosing(), waits for signalClosing, then sends.
|
|
// The queue channel is NEVER closed — only done is closed.
|
|
go func() {
|
|
defer wg.Done()
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
panicked = true
|
|
}
|
|
}()
|
|
|
|
if pq.isClosing() {
|
|
return
|
|
}
|
|
close(passedIsClosingCheck)
|
|
<-shutdownDone
|
|
|
|
msg := &ChannelMessage{}
|
|
select {
|
|
case pq.queue <- msg: // queue is open → safe to send
|
|
case <-pq.done: // done is closed → selected immediately
|
|
}
|
|
}()
|
|
|
|
// Closer: signal shutdown but never close the queue channel
|
|
go func() {
|
|
<-passedIsClosingCheck
|
|
pq.signalClosing() // closes done; does NOT close queue
|
|
close(shutdownDone)
|
|
}()
|
|
|
|
wg.Wait()
|
|
|
|
if panicked {
|
|
t.Errorf("iteration %d: unexpected panic — queue must not be closed in the fixed path", i)
|
|
}
|
|
}()
|
|
|
|
if t.Failed() {
|
|
return
|
|
}
|
|
}
|
|
|
|
t.Logf("confirmed: zero panics in %d iterations with the fix applied", iterations)
|
|
}
|
|
|
|
// =============================================================================
|
|
// UpdateProvider Lifecycle Tests
|
|
//
|
|
// These tests verify the three key invariants of the UpdateProvider fix:
|
|
// 1. New queue is stored BEFORE signalClosing fires (stale producers re-route)
|
|
// 2. Transfer happens BEFORE signalClosing (items go to new workers, not errored)
|
|
// 3. Concurrent producers + UpdateProvider produce zero panics
|
|
// =============================================================================
|
|
|
|
// TestUpdateProvider_StaleProducerReroutes verifies that a "stale producer" —
|
|
// a goroutine that fetched oldPq before UpdateProvider atomically replaced it —
|
|
// can transparently re-route to newPq when it later detects isClosing().
|
|
//
|
|
// The re-routing logic in tryRequest is:
|
|
//
|
|
// if pq.isClosing() {
|
|
// if newPq, err := bifrost.getProviderQueue(provider); err == nil && newPq != pq {
|
|
// pq = newPq // transparent re-route
|
|
// }
|
|
// }
|
|
//
|
|
// This test exercises that exact sequence without a full Bifrost instance.
|
|
func TestUpdateProvider_StaleProducerReroutes(t *testing.T) {
|
|
var requestQueues sync.Map
|
|
provider := schemas.OpenAI
|
|
|
|
oldPq := &ProviderQueue{
|
|
queue: make(chan *ChannelMessage, 10),
|
|
done: make(chan struct{}),
|
|
signalOnce: sync.Once{},
|
|
}
|
|
newPq := &ProviderQueue{
|
|
queue: make(chan *ChannelMessage, 10),
|
|
done: make(chan struct{}),
|
|
signalOnce: sync.Once{},
|
|
}
|
|
|
|
// Initial state: requestQueues holds oldPq
|
|
requestQueues.Store(provider, oldPq)
|
|
|
|
// Stale producer: fetched its reference before UpdateProvider ran
|
|
stalePq := oldPq
|
|
|
|
// Simulate UpdateProvider steps 2 + 4:
|
|
// Step 2: atomically replace — new producers now get newPq
|
|
requestQueues.Store(provider, newPq)
|
|
// Step 4: signal old closing — stale producers will detect this
|
|
oldPq.signalClosing()
|
|
|
|
// --- Stale producer detects isClosing and attempts re-route ---
|
|
var reroutedPq *ProviderQueue
|
|
if stalePq.isClosing() {
|
|
if val, ok := requestQueues.Load(provider); ok {
|
|
candidate := val.(*ProviderQueue)
|
|
if candidate != stalePq {
|
|
reroutedPq = candidate
|
|
}
|
|
}
|
|
}
|
|
|
|
if reroutedPq == nil {
|
|
t.Fatal("stale producer failed to re-route: re-route returned nil (check step ordering)")
|
|
}
|
|
if reroutedPq != newPq {
|
|
t.Fatal("stale producer re-routed to wrong queue: expected newPq")
|
|
}
|
|
if reroutedPq.isClosing() {
|
|
t.Fatal("re-routed queue is already closing — re-route is useless (newPq must be fresh)")
|
|
}
|
|
|
|
// Verify: sending to re-routed queue succeeds without panic
|
|
panicked := false
|
|
func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
panicked = true
|
|
}
|
|
}()
|
|
msg := &ChannelMessage{}
|
|
select {
|
|
case reroutedPq.queue <- msg:
|
|
case <-reroutedPq.done:
|
|
t.Error("newPq.done fired — newPq should be open")
|
|
}
|
|
}()
|
|
if panicked {
|
|
t.Fatal("panic while sending to re-routed queue — queue must not be closed")
|
|
}
|
|
}
|
|
|
|
// TestUpdateProvider_TransferOrdering verifies the ordering invariant:
|
|
// items are moved from oldPq to newPq BEFORE signalClosing(oldPq) is called.
|
|
//
|
|
// Observable consequence: during the entire transfer loop, oldPq.isClosing()
|
|
// must remain false. Only after transfer completes does signalClosing fire.
|
|
func TestUpdateProvider_TransferOrdering(t *testing.T) {
|
|
const numMessages = 8
|
|
|
|
oldPq := &ProviderQueue{
|
|
queue: make(chan *ChannelMessage, numMessages+2),
|
|
done: make(chan struct{}),
|
|
signalOnce: sync.Once{},
|
|
}
|
|
newPq := &ProviderQueue{
|
|
queue: make(chan *ChannelMessage, numMessages+2),
|
|
done: make(chan struct{}),
|
|
signalOnce: sync.Once{},
|
|
}
|
|
|
|
// Pre-fill oldPq — simulates buffered requests at the moment UpdateProvider runs
|
|
for i := 0; i < numMessages; i++ {
|
|
oldPq.queue <- &ChannelMessage{}
|
|
}
|
|
|
|
// Invariant check before transfer begins
|
|
if oldPq.isClosing() {
|
|
t.Fatal("invariant violated: oldPq already closing before transfer begins")
|
|
}
|
|
|
|
// Perform transfer, mirroring UpdateProvider step 3.
|
|
// Record whether isClosing() ever fired during the loop.
|
|
closingDuringTransfer := false
|
|
transferred := 0
|
|
for {
|
|
select {
|
|
case msg := <-oldPq.queue:
|
|
if oldPq.isClosing() {
|
|
closingDuringTransfer = true
|
|
}
|
|
newPq.queue <- msg
|
|
transferred++
|
|
default:
|
|
goto transferComplete
|
|
}
|
|
}
|
|
transferComplete:
|
|
|
|
if closingDuringTransfer {
|
|
t.Error("invariant violated: oldPq was already closing during transfer — " +
|
|
"signalClosing must fire AFTER the transfer loop completes")
|
|
}
|
|
|
|
// NOW signal closing, mirroring UpdateProvider step 4
|
|
oldPq.signalClosing()
|
|
|
|
if !oldPq.isClosing() {
|
|
t.Error("expected isClosing() == true after signalClosing()")
|
|
}
|
|
|
|
// All messages must have moved to newPq
|
|
if transferred != numMessages {
|
|
t.Errorf("expected %d messages transferred, got %d", numMessages, transferred)
|
|
}
|
|
if len(newPq.queue) != numMessages {
|
|
t.Errorf("expected %d messages in newPq after transfer, got %d", numMessages, len(newPq.queue))
|
|
}
|
|
if len(oldPq.queue) != 0 {
|
|
t.Errorf("expected 0 messages remaining in oldPq after transfer, got %d", len(oldPq.queue))
|
|
}
|
|
}
|
|
|
|
// TestUpdateProvider_NoPanicConcurrentAccess verifies that concurrent producers
|
|
// sending to a queue that is being replaced (UpdateProvider-style) never cause
|
|
// a "send on closed channel" panic.
|
|
//
|
|
// This test directly models the production scenario that triggered the bug:
|
|
// many goroutines continuously send to a ProviderQueue while UpdateProvider
|
|
// atomically swaps the queue and signals the old one closing. With the fix
|
|
// (queue channel is never closed), the select in producers is always safe.
|
|
func TestUpdateProvider_NoPanicConcurrentAccess(t *testing.T) {
|
|
const (
|
|
numProducers = 10
|
|
numUpdates = 30
|
|
producerRunTime = 300 * time.Millisecond
|
|
)
|
|
|
|
var requestQueues sync.Map
|
|
provider := schemas.OpenAI
|
|
|
|
makePq := func() *ProviderQueue {
|
|
return &ProviderQueue{
|
|
queue: make(chan *ChannelMessage, 200),
|
|
done: make(chan struct{}),
|
|
signalOnce: sync.Once{},
|
|
}
|
|
}
|
|
|
|
initialPq := makePq()
|
|
requestQueues.Store(provider, initialPq)
|
|
|
|
var panicCount int64
|
|
var transferDropCount int64
|
|
|
|
stop := make(chan struct{})
|
|
var producerWg sync.WaitGroup
|
|
|
|
// Drainer: continuously empties queues so producers never block on a full queue
|
|
drainStop := make(chan struct{})
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-drainStop:
|
|
return
|
|
default:
|
|
if val, ok := requestQueues.Load(provider); ok {
|
|
pq := val.(*ProviderQueue)
|
|
select {
|
|
case <-pq.queue:
|
|
default:
|
|
}
|
|
}
|
|
runtime.Gosched()
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Producers: continuously simulate the tryRequest hot path
|
|
for i := 0; i < numProducers; i++ {
|
|
producerWg.Add(1)
|
|
go func() {
|
|
defer producerWg.Done()
|
|
for {
|
|
select {
|
|
case <-stop:
|
|
return
|
|
default:
|
|
}
|
|
|
|
val, ok := requestQueues.Load(provider)
|
|
if !ok {
|
|
runtime.Gosched()
|
|
continue
|
|
}
|
|
pq := val.(*ProviderQueue)
|
|
|
|
func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
atomic.AddInt64(&panicCount, 1)
|
|
}
|
|
}()
|
|
|
|
// Re-route check (mirrors tryRequest)
|
|
if pq.isClosing() {
|
|
if newVal, ok2 := requestQueues.Load(provider); ok2 {
|
|
if candidate := newVal.(*ProviderQueue); candidate != pq {
|
|
pq = candidate
|
|
}
|
|
}
|
|
// If still closing (RemoveProvider path), just return
|
|
if pq.isClosing() {
|
|
return
|
|
}
|
|
}
|
|
|
|
msg := &ChannelMessage{}
|
|
select {
|
|
case pq.queue <- msg:
|
|
case <-pq.done:
|
|
case <-stop: // unblock immediately when the test signals stop
|
|
}
|
|
}()
|
|
|
|
runtime.Gosched()
|
|
}
|
|
}()
|
|
}
|
|
|
|
// Updater: repeatedly performs UpdateProvider-style queue replacements
|
|
var updaterWg sync.WaitGroup
|
|
updaterWg.Add(1)
|
|
go func() {
|
|
defer updaterWg.Done()
|
|
for i := 0; i < numUpdates; i++ {
|
|
val, ok := requestQueues.Load(provider)
|
|
if !ok {
|
|
continue
|
|
}
|
|
oldPq := val.(*ProviderQueue)
|
|
newPq := makePq()
|
|
|
|
// Mirror production UpdateProvider step order exactly:
|
|
// Step 2: expose newPq first so stale producers can re-route to it
|
|
// once they see oldPq is closing.
|
|
requestQueues.Store(provider, newPq)
|
|
|
|
// Step 3: transfer buffered messages oldPq → newPq.
|
|
drain:
|
|
for {
|
|
select {
|
|
case msg := <-oldPq.queue:
|
|
select {
|
|
case newPq.queue <- msg:
|
|
default:
|
|
// newPq full during transfer — mirrors production cancel path.
|
|
atomic.AddInt64(&transferDropCount, 1)
|
|
}
|
|
default:
|
|
break drain
|
|
}
|
|
}
|
|
|
|
// Step 4: signal closing — producers holding a stale oldPq ref now
|
|
// re-route to newPq (already in the map from step 2).
|
|
oldPq.signalClosing()
|
|
|
|
time.Sleep(5 * time.Millisecond)
|
|
}
|
|
}()
|
|
|
|
time.Sleep(producerRunTime)
|
|
close(stop)
|
|
close(drainStop)
|
|
producerWg.Wait()
|
|
updaterWg.Wait()
|
|
|
|
if n := atomic.LoadInt64(&panicCount); n > 0 {
|
|
t.Errorf("detected %d panic(s) — fix did not eliminate the concurrent-access race", n)
|
|
} else {
|
|
t.Logf("confirmed: zero panics across %d producers + %d queue replacements over %v",
|
|
numProducers, numUpdates, producerRunTime)
|
|
}
|
|
if drops := atomic.LoadInt64(&transferDropCount); drops > 0 {
|
|
t.Logf("note: %d message(s) dropped during transfer (oldPq had >200 buffered items) — does not affect panic correctness", drops)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// RemoveProvider Lifecycle Tests
|
|
//
|
|
// These tests verify the behavioral contract of RemoveProvider:
|
|
// 1. signalClosing() blocks new producers (isClosing() → true)
|
|
// 2. Buffered items in the queue get "provider is shutting down" errors
|
|
// 3. Workers exit cleanly and the WaitGroup reaches zero
|
|
// =============================================================================
|
|
|
|
// TestRemoveProvider_BlocksNewProducers verifies that after signalClosing(),
|
|
// isClosing() returns true. Producers check this flag before sending and return
|
|
// a "provider is shutting down" error rather than trying to enqueue.
|
|
func TestRemoveProvider_BlocksNewProducers(t *testing.T) {
|
|
pq := &ProviderQueue{
|
|
queue: make(chan *ChannelMessage, 10),
|
|
done: make(chan struct{}),
|
|
signalOnce: sync.Once{},
|
|
}
|
|
|
|
// Sanity: before shutdown, producers can proceed
|
|
if pq.isClosing() {
|
|
t.Fatal("isClosing() must be false before RemoveProvider runs")
|
|
}
|
|
|
|
// RemoveProvider step 2: signal closing
|
|
pq.signalClosing()
|
|
|
|
// New producers must see isClosing() == true and abort
|
|
if !pq.isClosing() {
|
|
t.Fatal("isClosing() must be true after signalClosing() (RemoveProvider)")
|
|
}
|
|
|
|
// done must be closed so any producer blocked in the select unblocks immediately
|
|
select {
|
|
case <-pq.done:
|
|
// correct
|
|
default:
|
|
t.Fatal("pq.done must be closed after signalClosing() so blocking producers unblock")
|
|
}
|
|
|
|
// CRITICAL: queue channel must remain OPEN — closing it would cause panics in
|
|
// any producer that entered the select before seeing isClosing().
|
|
// With the fix, we NEVER close the queue channel.
|
|
panicked := false
|
|
func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
panicked = true
|
|
}
|
|
}()
|
|
// A select with done closed always takes the done case — safe, no panic
|
|
select {
|
|
case pq.queue <- &ChannelMessage{}:
|
|
case <-pq.done:
|
|
}
|
|
}()
|
|
if panicked {
|
|
t.Fatal("queue channel must stay open after signalClosing() — closing it causes panics")
|
|
}
|
|
}
|
|
|
|
// TestRemoveProvider_BufferedRequestsGetErrors verifies the drain contract:
|
|
// items queued BEFORE signalClosing fires must each receive a
|
|
// "provider is shutting down" error on their Err channel. No client should be
|
|
// left hanging.
|
|
//
|
|
// This test exercises the drain logic directly — the same code path that
|
|
// requestWorker executes in its case <-pq.done: branch — to avoid the
|
|
// non-deterministic select race where the normal processing path can pick up
|
|
// items before done fires.
|
|
func TestRemoveProvider_BufferedRequestsGetErrors(t *testing.T) {
|
|
const numBuffered = 8
|
|
|
|
pq := &ProviderQueue{
|
|
queue: make(chan *ChannelMessage, numBuffered+5),
|
|
done: make(chan struct{}),
|
|
signalOnce: sync.Once{},
|
|
}
|
|
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
|
|
// Buffer requests — simulates requests already queued when RemoveProvider runs
|
|
msgs := make([]*ChannelMessage, numBuffered)
|
|
for i := 0; i < numBuffered; i++ {
|
|
msgs[i] = newTestChannelMessage(ctx)
|
|
pq.queue <- msgs[i]
|
|
}
|
|
|
|
// RemoveProvider step 2: signal closing
|
|
pq.signalClosing()
|
|
|
|
// Execute the drain path — exactly what requestWorker does in case <-pq.done:
|
|
<-pq.done // fires immediately since signalClosing was already called
|
|
drainLoop:
|
|
for {
|
|
select {
|
|
case r := <-pq.queue:
|
|
provKey, mod, _ := r.GetRequestFields()
|
|
r.Err <- schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{
|
|
Message: "provider is shutting down",
|
|
},
|
|
ExtraFields: schemas.BifrostErrorExtraFields{
|
|
RequestType: r.RequestType,
|
|
Provider: provKey,
|
|
OriginalModelRequested: mod,
|
|
},
|
|
}
|
|
default:
|
|
break drainLoop
|
|
}
|
|
}
|
|
|
|
// Every buffered message must have received a shutdown error
|
|
for i, msg := range msgs {
|
|
select {
|
|
case bifrostErr := <-msg.Err:
|
|
if bifrostErr.Error == nil {
|
|
t.Errorf("message %d: got nil Error field in BifrostError", i)
|
|
continue
|
|
}
|
|
if bifrostErr.Error.Message != "provider is shutting down" {
|
|
t.Errorf("message %d: expected 'provider is shutting down', got %q",
|
|
i, bifrostErr.Error.Message)
|
|
}
|
|
if bifrostErr.ExtraFields.Provider != schemas.OpenAI {
|
|
t.Errorf("message %d: expected provider %s, got %s",
|
|
i, schemas.OpenAI, bifrostErr.ExtraFields.Provider)
|
|
}
|
|
if bifrostErr.ExtraFields.RequestType != schemas.ChatCompletionRequest {
|
|
t.Errorf("message %d: expected requestType %v, got %v",
|
|
i, schemas.ChatCompletionRequest, bifrostErr.ExtraFields.RequestType)
|
|
}
|
|
default:
|
|
t.Errorf("message %d: no error received — client would be left hanging indefinitely", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestRemoveProvider_WorkerWaitGroupCompletes verifies that after signalClosing(),
|
|
// the worker goroutine decrements the WaitGroup and wg.Wait() returns promptly.
|
|
// This mirrors what RemoveProvider does: signal, then Wait() before cleanup.
|
|
func TestRemoveProvider_WorkerWaitGroupCompletes(t *testing.T) {
|
|
pq := &ProviderQueue{
|
|
queue: make(chan *ChannelMessage, 10),
|
|
done: make(chan struct{}),
|
|
signalOnce: sync.Once{},
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
|
|
// Worker goroutine — mirrors requestWorker's WaitGroup contract
|
|
go func() {
|
|
defer wg.Done()
|
|
for {
|
|
select {
|
|
case r, ok := <-pq.queue:
|
|
if !ok {
|
|
return
|
|
}
|
|
_ = r
|
|
case <-pq.done:
|
|
// Drain remaining (empty in this test)
|
|
for {
|
|
select {
|
|
case <-pq.queue:
|
|
default:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Tiny sleep to ensure worker is parked on select before we signal
|
|
time.Sleep(10 * time.Millisecond)
|
|
|
|
// RemoveProvider step 2: signal closing
|
|
pq.signalClosing()
|
|
|
|
// RemoveProvider step 3: wait for workers — must complete promptly
|
|
waitReturned := make(chan struct{})
|
|
go func() {
|
|
wg.Wait()
|
|
close(waitReturned)
|
|
}()
|
|
|
|
select {
|
|
case <-waitReturned:
|
|
// correct: WaitGroup reached zero after signalClosing()
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("wg.Wait() did not return after signalClosing() — worker is stuck (would deadlock RemoveProvider)")
|
|
}
|
|
}
|
|
|
|
// TestRemoveProvider_ConcurrentNewProducersDuringShutdown verifies that
|
|
// concurrent producers trying to enqueue after RemoveProvider calls
|
|
// signalClosing() all get safe "provider is shutting down" errors — none panic.
|
|
// This tests the TOCTOU window: producer passes isClosing() check, then done fires.
|
|
func TestRemoveProvider_ConcurrentNewProducersDuringShutdown(t *testing.T) {
|
|
const numProducers = 50
|
|
|
|
pq := &ProviderQueue{
|
|
queue: make(chan *ChannelMessage, numProducers+10),
|
|
done: make(chan struct{}),
|
|
signalOnce: sync.Once{},
|
|
}
|
|
|
|
var panicCount int64
|
|
var shutdownErrors int64
|
|
var successfulSends int64
|
|
|
|
// Gate: all producers start together after isClosing() passes
|
|
passedGate := make(chan struct{})
|
|
var gateOnce sync.Once
|
|
shutdownFired := make(chan struct{})
|
|
|
|
var producerWg sync.WaitGroup
|
|
|
|
for i := 0; i < numProducers; i++ {
|
|
producerWg.Add(1)
|
|
go func() {
|
|
defer producerWg.Done()
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
atomic.AddInt64(&panicCount, 1)
|
|
}
|
|
}()
|
|
|
|
// Each producer checks isClosing() first (mirrors tryRequest)
|
|
if pq.isClosing() {
|
|
atomic.AddInt64(&shutdownErrors, 1)
|
|
return
|
|
}
|
|
|
|
// Signal that at least one producer passed the isClosing() check
|
|
gateOnce.Do(func() { close(passedGate) })
|
|
|
|
// Wait for shutdown to be signaled (the TOCTOU window)
|
|
<-shutdownFired
|
|
|
|
// Producers now enter the select — with the fix, done is closed but
|
|
// queue is NOT closed, so this select is always safe (no panic)
|
|
msg := &ChannelMessage{}
|
|
select {
|
|
case pq.queue <- msg:
|
|
atomic.AddInt64(&successfulSends, 1)
|
|
case <-pq.done:
|
|
atomic.AddInt64(&shutdownErrors, 1)
|
|
}
|
|
}()
|
|
}
|
|
|
|
// Wait for at least one producer to pass the isClosing() gate
|
|
select {
|
|
case <-passedGate:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("no producer passed the isClosing() check within timeout")
|
|
}
|
|
|
|
// Signal shutdown (RemoveProvider step 2) — this is the TOCTOU race
|
|
pq.signalClosing()
|
|
close(shutdownFired)
|
|
|
|
producerWg.Wait()
|
|
|
|
if n := atomic.LoadInt64(&panicCount); n > 0 {
|
|
t.Errorf("detected %d panic(s) — queue must not be closed during concurrent shutdown", n)
|
|
}
|
|
|
|
t.Logf("result: %d successful sends, %d shutdown errors, %d panics across %d producers",
|
|
atomic.LoadInt64(&successfulSends),
|
|
atomic.LoadInt64(&shutdownErrors),
|
|
atomic.LoadInt64(&panicCount),
|
|
numProducers)
|
|
}
|