Files
bifrost/core/internal/llmtests/passthrough_api.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

252 lines
7.8 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package llmtests
import (
"context"
"fmt"
"os"
"testing"
"github.com/bytedance/sonic"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/providers/anthropic"
"github.com/maximhq/bifrost/core/providers/gemini"
"github.com/maximhq/bifrost/core/providers/openai"
"github.com/maximhq/bifrost/core/schemas"
)
// passthroughChatReq describes a minimal one-turn chat request expressed in a
// provider's native format, as consumed by the passthrough API tests.
type passthroughChatReq struct {
	// path is the provider-native HTTP path, e.g. "/v1/chat/completions".
	path string
	// body is the JSON-encoded provider-native request payload.
	body []byte
	// query is the raw query string sent with the request (e.g. "alt=sse");
	// it is empty when the endpoint needs no query parameters.
	query string
}
// basePassthroughChatRequest builds the smallest BifrostChatRequest the
// passthrough tests need for the given model: a single basic chat message and
// a completion-token cap of 300. The result is intended to be handed to a
// provider converter to obtain the provider-native passthrough body.
func basePassthroughChatRequest(model string) *schemas.BifrostChatRequest {
	prompt := CreateBasicChatMessage("Say hello in one word")
	params := &schemas.ChatParameters{
		MaxCompletionTokens: bifrost.Ptr(300),
	}

	req := &schemas.BifrostChatRequest{Model: model}
	req.Input = []schemas.ChatMessage{prompt}
	req.Params = params
	return req
}
// buildPassthroughChatReq converts a minimal BifrostChatRequest into the
// provider-native HTTP path and JSON body using each provider's own converter.
//
// Parameters:
//   - t: the running test; marshal failures abort it via t.Fatalf.
//   - provider: selects which native request format is produced.
//   - model: the model (or, for Azure, deployment) name embedded in the request.
//   - stream: when true, the streaming variant of the request is built.
//
// Returns (req, true) for supported providers. Returns (zero, false) to signal
// that the caller should skip: either the provider has no native format defined
// here, or the provider's converter rejected the request (the converter error
// is logged so the skip reason is visible in the test output).
func buildPassthroughChatReq(t *testing.T, provider schemas.ModelProvider, model string, stream bool) (passthroughChatReq, bool) {
	// Attribute failures raised here to the calling subtest.
	t.Helper()

	// encode marshals a provider-native request, failing the test on error.
	// label prefixes the failure message with the provider name.
	encode := func(label string, nativeReq any) []byte {
		t.Helper()
		body, err := sonic.Marshal(nativeReq)
		if err != nil {
			t.Fatalf("%s: failed to marshal passthrough chat request: %v", label, err)
		}
		return body
	}

	bfReq := basePassthroughChatRequest(model)
	ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

	// openAIBody builds the OpenAI-format body shared by OpenAI and Azure.
	openAIBody := func(label string) []byte {
		t.Helper()
		nativeReq := openai.ToOpenAIChatRequest(ctx, bfReq)
		if stream {
			nativeReq.Stream = bifrost.Ptr(true)
		}
		return encode(label, nativeReq)
	}

	switch provider {
	case schemas.OpenAI:
		return passthroughChatReq{path: "/v1/chat/completions", body: openAIBody("openai")}, true

	case schemas.Azure:
		// Azure passthrough expects the deployment-based path; api-version is
		// injected automatically by buildPassthroughURL from the key config.
		return passthroughChatReq{
			path: fmt.Sprintf("/openai/deployments/%s/chat/completions", model),
			body: openAIBody("azure"),
		}, true

	case schemas.Anthropic:
		nativeReq, err := anthropic.ToAnthropicChatRequest(ctx, bfReq)
		if err != nil {
			// Keep the skip contract, but record why the request could not be built.
			t.Logf("anthropic: failed to convert passthrough chat request: %v", err)
			return passthroughChatReq{}, false
		}
		if stream {
			nativeReq.Stream = bifrost.Ptr(true)
		}
		return passthroughChatReq{path: "/v1/messages", body: encode("anthropic", nativeReq)}, true

	case schemas.Gemini:
		nativeReq, err := gemini.ToGeminiChatCompletionRequest(bfReq)
		if err != nil {
			// Keep the skip contract, but record why the request could not be built.
			t.Logf("gemini: failed to convert passthrough chat request: %v", err)
			return passthroughChatReq{}, false
		}
		body := encode("gemini", nativeReq)

		// Streaming is selected through the endpoint name plus an SSE query
		// parameter rather than a field in the body.
		endpoint, query := ":generateContent", ""
		if stream {
			endpoint, query = ":streamGenerateContent", "alt=sse"
		}
		return passthroughChatReq{
			path:  fmt.Sprintf("/models/%s%s", model, endpoint),
			body:  body,
			query: query,
		}, true

	default:
		return passthroughChatReq{}, false
	}
}
// resolvePassthroughModel picks the model used by the passthrough tests. A
// non-empty cfg.PassthroughModel takes precedence; otherwise cfg.ChatModel is
// returned (which may itself be empty).
func resolvePassthroughModel(cfg ComprehensiveTestConfig) string {
	if override := cfg.PassthroughModel; override != "" {
		return override
	}
	return cfg.ChatModel
}
// RunPassthroughAPITest exercises Bifrost's raw HTTP passthrough API for the
// configured provider using two sub-tests:
//
//   - PassthroughAPI/NonStream calls client.Passthrough and verifies a 2xx
//     response with a non-empty body and populated ExtraFields (Provider,
//     positive Latency, RequestType == schemas.PassthroughRequest).
//   - PassthroughAPI/Stream calls client.PassthroughStream and verifies that
//     at least one chunk with body data is received and no chunk carries an
//     error.
//
// The function logs and returns without running anything when
// Scenarios.PassthroughAPI is false or no model is configured. Each sub-test is
// skipped when buildPassthroughChatReq cannot produce a native request for the
// provider. Sub-tests run in parallel unless the SKIP_PARALLEL_TESTS
// environment variable is set to "true".
func RunPassthroughAPITest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.PassthroughAPI {
		t.Logf("PassthroughAPI not enabled for provider %s, skipping", testConfig.Provider)
		return
	}

	model := resolvePassthroughModel(testConfig)
	if model == "" {
		t.Logf("No model configured for PassthroughAPI test on provider %s, skipping", testConfig.Provider)
		return
	}

	// markParallel opts the sub-test into parallel execution unless disabled
	// through the environment.
	markParallel := func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}
	}

	// newRequest wraps a provider-native request in the Bifrost passthrough
	// envelope. Both sub-tests share it so the query string is always
	// forwarded, not only on the streaming path.
	newRequest := func(req passthroughChatReq) *schemas.BifrostPassthroughRequest {
		return &schemas.BifrostPassthroughRequest{
			Method:   "POST",
			Path:     req.path,
			Body:     req.body,
			RawQuery: req.query,
			SafeHeaders: map[string]string{
				"content-type": "application/json",
			},
			Model: model,
		}
	}

	t.Run("PassthroughAPI", func(t *testing.T) {
		t.Run("NonStream", func(t *testing.T) {
			markParallel(t)

			req, ok := buildPassthroughChatReq(t, testConfig.Provider, model, false)
			if !ok {
				t.Skipf("PassthroughAPI/NonStream: no native request format defined for provider %s", testConfig.Provider)
			}

			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			resp, bifrostErr := client.Passthrough(bfCtx, testConfig.Provider, newRequest(req))
			if bifrostErr != nil {
				t.Fatalf("❌ Passthrough request failed: %s", GetErrorMessage(bifrostErr))
			}
			if resp == nil {
				t.Fatal("❌ Passthrough response is nil")
			}
			if resp.StatusCode < 200 || resp.StatusCode >= 300 {
				t.Fatalf("❌ Passthrough returned non-2xx status %d; body: %s", resp.StatusCode, string(resp.Body))
			}
			if len(resp.Body) == 0 {
				t.Fatal("❌ Passthrough response body is empty")
			}

			// Metadata problems are reported with Error so that all of them
			// show up in a single run.
			if resp.ExtraFields.Provider == "" {
				t.Error("❌ ExtraFields.Provider is empty")
			}
			if resp.ExtraFields.Latency <= 0 {
				t.Error("❌ ExtraFields.Latency is not positive")
			}
			if resp.ExtraFields.RequestType != schemas.PassthroughRequest {
				t.Errorf("❌ ExtraFields.RequestType = %q, want %q", resp.ExtraFields.RequestType, schemas.PassthroughRequest)
			}

			t.Logf("✅ Passthrough non-streaming OK: status=%d body_len=%d latency=%dms",
				resp.StatusCode, len(resp.Body), resp.ExtraFields.Latency)
		})

		t.Run("Stream", func(t *testing.T) {
			markParallel(t)

			req, ok := buildPassthroughChatReq(t, testConfig.Provider, model, true)
			if !ok {
				t.Skipf("PassthroughAPI/Stream: no native request format defined for provider %s", testConfig.Provider)
			}

			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			ch, bifrostErr := client.PassthroughStream(bfCtx, testConfig.Provider, newRequest(req))
			if bifrostErr != nil {
				t.Fatalf("❌ PassthroughStream failed: %s", GetErrorMessage(bifrostErr))
			}
			if ch == nil {
				t.Fatal("❌ PassthroughStream returned nil channel")
			}

			// Count only chunks that actually carry body bytes; nil chunks and
			// chunks without a passthrough response are ignored.
			var totalBytes, chunkCount int
			for chunk := range ch {
				if chunk == nil {
					continue
				}
				if chunk.BifrostError != nil {
					t.Fatalf("❌ Stream chunk contained error: %s", GetErrorMessage(chunk.BifrostError))
				}
				if chunk.BifrostPassthroughResponse == nil {
					continue
				}
				if n := len(chunk.BifrostPassthroughResponse.Body); n > 0 {
					totalBytes += n
					chunkCount++
				}
			}
			if chunkCount == 0 {
				t.Fatal("❌ PassthroughStream received no chunks with body data")
			}

			t.Logf("✅ Passthrough streaming OK: %d chunks, %d total bytes", chunkCount, totalBytes)
		})
	})
}