first commit
This commit is contained in:
110
core/providers/vllm/chat_test.go
Normal file
110
core/providers/vllm/chat_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package vllm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// TestChatCompletion_ExtraParamsForwardedAutomatically verifies that provider-specific
|
||||
// extra params (e.g. chat_template_kwargs) are forwarded to vLLM without requiring
|
||||
// the caller to set BifrostContextKeyPassthroughExtraParams on the context.
|
||||
func TestChatCompletion_ExtraParamsForwardedAutomatically(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var capturedBody map[string]interface{}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "read error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := json.Unmarshal(body, &capturedBody); err != nil {
|
||||
http.Error(w, "json error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{
|
||||
"id": "chatcmpl-test",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "gemma",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": "Hello!"},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {"prompt_tokens": 5, "completion_tokens": 3, "total_tokens": 8}
|
||||
}`)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := newTestVLLMProvider()
|
||||
key := schemas.Key{
|
||||
ID: "test-key",
|
||||
Value: schemas.EnvVar{Val: "test-api-key"},
|
||||
VLLMKeyConfig: &schemas.VLLMKeyConfig{
|
||||
URL: schemas.EnvVar{Val: server.URL},
|
||||
},
|
||||
}
|
||||
|
||||
// Intentionally do NOT set BifrostContextKeyPassthroughExtraParams — the provider
|
||||
// should set it automatically.
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
|
||||
hello := "Hello"
|
||||
req := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.VLLM,
|
||||
Model: "gemma",
|
||||
Input: []schemas.ChatMessage{
|
||||
{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: &hello},
|
||||
},
|
||||
},
|
||||
Params: &schemas.ChatParameters{
|
||||
ExtraParams: map[string]interface{}{
|
||||
"chat_template_kwargs": map[string]interface{}{
|
||||
"enable_thinking": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, bifrostErr := provider.ChatCompletion(ctx, key, req)
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("ChatCompletion returned error: %v", bifrostErr.Error.Message)
|
||||
}
|
||||
|
||||
if capturedBody == nil {
|
||||
t.Fatal("mock server did not receive a request body")
|
||||
}
|
||||
|
||||
rawKwargs, ok := capturedBody["chat_template_kwargs"]
|
||||
if !ok {
|
||||
t.Fatalf("chat_template_kwargs missing from outgoing request body; got keys: %v", keys(capturedBody))
|
||||
}
|
||||
|
||||
kwargsMap, ok := rawKwargs.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected chat_template_kwargs to be an object, got %T", rawKwargs)
|
||||
}
|
||||
if kwargsMap["enable_thinking"] != true {
|
||||
t.Fatalf("expected enable_thinking=true, got %v", kwargsMap["enable_thinking"])
|
||||
}
|
||||
}
|
||||
|
||||
func keys(m map[string]interface{}) []string {
|
||||
out := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
out = append(out, k)
|
||||
}
|
||||
return out
|
||||
}
|
||||
216
core/providers/vllm/list_models_test.go
Normal file
216
core/providers/vllm/list_models_test.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package vllm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// newTestVLLMProvider creates a VLLMProvider suitable for unit tests.
|
||||
// It uses a short timeout and no base URL (per-key URLs are expected).
|
||||
func newTestVLLMProvider() *VLLMProvider {
|
||||
return &VLLMProvider{
|
||||
client: &fasthttp.Client{
|
||||
ReadTimeout: 5 * time.Second,
|
||||
WriteTimeout: 5 * time.Second,
|
||||
},
|
||||
networkConfig: schemas.NetworkConfig{},
|
||||
}
|
||||
}
|
||||
|
||||
// modelsJSON returns a minimal OpenAI-compatible /v1/models response listing the given model IDs.
|
||||
func modelsJSON(ids ...string) string {
|
||||
data := ""
|
||||
for i, id := range ids {
|
||||
if i > 0 {
|
||||
data += ","
|
||||
}
|
||||
data += fmt.Sprintf(`{"id":%q,"object":"model","owned_by":"vllm"}`, id)
|
||||
}
|
||||
return fmt.Sprintf(`{"object":"list","data":[%s]}`, data)
|
||||
}
|
||||
|
||||
func TestListModels_QueriesAllBackends(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Spin up two mock vLLM servers, each serving a different model.
|
||||
var hits1, hits2 atomic.Int32
|
||||
|
||||
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hits1.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, modelsJSON("model-from-backend-1"))
|
||||
}))
|
||||
defer server1.Close()
|
||||
|
||||
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hits2.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, modelsJSON("model-from-backend-2"))
|
||||
}))
|
||||
defer server2.Close()
|
||||
|
||||
provider := newTestVLLMProvider()
|
||||
|
||||
keys := []schemas.Key{
|
||||
{
|
||||
ID: "key-1",
|
||||
Value: schemas.EnvVar{Val: "test-api-key-1"},
|
||||
VLLMKeyConfig: &schemas.VLLMKeyConfig{
|
||||
URL: schemas.EnvVar{Val: server1.URL},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "key-2",
|
||||
Value: schemas.EnvVar{Val: "test-api-key-2"},
|
||||
VLLMKeyConfig: &schemas.VLLMKeyConfig{
|
||||
URL: schemas.EnvVar{Val: server2.URL},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
request := &schemas.BifrostListModelsRequest{
|
||||
Provider: schemas.VLLM,
|
||||
Unfiltered: true,
|
||||
}
|
||||
|
||||
resp, bifrostErr := provider.ListModels(ctx, keys, request)
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("ListModels returned error: %v", bifrostErr.Error)
|
||||
}
|
||||
|
||||
// Both backends must have been queried.
|
||||
if hits1.Load() != 1 {
|
||||
t.Errorf("expected backend 1 to be queried once, got %d", hits1.Load())
|
||||
}
|
||||
if hits2.Load() != 1 {
|
||||
t.Errorf("expected backend 2 to be queried once, got %d", hits2.Load())
|
||||
}
|
||||
|
||||
// Response must contain models from both backends.
|
||||
found := map[string]bool{}
|
||||
for _, m := range resp.Data {
|
||||
found[m.ID] = true
|
||||
}
|
||||
// Model IDs are prefixed with "vllm/" by ToBifrostListModelsResponse.
|
||||
if !found["vllm/model-from-backend-1"] {
|
||||
t.Errorf("response missing vllm/model-from-backend-1, got models: %v", resp.Data)
|
||||
}
|
||||
if !found["vllm/model-from-backend-2"] {
|
||||
t.Errorf("response missing vllm/model-from-backend-2, got models: %v", resp.Data)
|
||||
}
|
||||
|
||||
// KeyStatuses should report success for both keys.
|
||||
if len(resp.KeyStatuses) != 2 {
|
||||
t.Fatalf("expected 2 key statuses, got %d", len(resp.KeyStatuses))
|
||||
}
|
||||
for _, ks := range resp.KeyStatuses {
|
||||
if ks.Status != schemas.KeyStatusSuccess {
|
||||
t.Errorf("key %s status = %s, want success", ks.KeyID, ks.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModels_SingleBackendFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// One healthy backend, one that returns 500.
|
||||
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, modelsJSON("healthy-model"))
|
||||
}))
|
||||
defer server1.Close()
|
||||
|
||||
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
fmt.Fprint(w, `{"error":{"message":"internal error","type":"server_error"}}`)
|
||||
}))
|
||||
defer server2.Close()
|
||||
|
||||
provider := newTestVLLMProvider()
|
||||
|
||||
keys := []schemas.Key{
|
||||
{
|
||||
ID: "good-key",
|
||||
Value: schemas.EnvVar{Val: "key1"},
|
||||
VLLMKeyConfig: &schemas.VLLMKeyConfig{
|
||||
URL: schemas.EnvVar{Val: server1.URL},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "bad-key",
|
||||
Value: schemas.EnvVar{Val: "key2"},
|
||||
VLLMKeyConfig: &schemas.VLLMKeyConfig{
|
||||
URL: schemas.EnvVar{Val: server2.URL},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
request := &schemas.BifrostListModelsRequest{
|
||||
Provider: schemas.VLLM,
|
||||
Unfiltered: true,
|
||||
}
|
||||
|
||||
resp, bifrostErr := provider.ListModels(ctx, keys, request)
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("ListModels should not return a top-level error when one backend succeeds, got: %v", bifrostErr.Error)
|
||||
}
|
||||
|
||||
// Models from the healthy backend should still appear.
|
||||
found := false
|
||||
for _, m := range resp.Data {
|
||||
if m.ID == "vllm/healthy-model" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("response missing healthy-model from the working backend")
|
||||
}
|
||||
|
||||
// KeyStatuses should reflect partial success.
|
||||
statusByKey := map[string]schemas.KeyStatusType{}
|
||||
for _, ks := range resp.KeyStatuses {
|
||||
statusByKey[ks.KeyID] = ks.Status
|
||||
}
|
||||
if statusByKey["good-key"] != schemas.KeyStatusSuccess {
|
||||
t.Errorf("good-key status = %s, want success", statusByKey["good-key"])
|
||||
}
|
||||
if statusByKey["bad-key"] == schemas.KeyStatusSuccess {
|
||||
t.Errorf("bad-key should not have success status, got %s", statusByKey["bad-key"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModels_ErrorsWithoutPerKeyURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Keys without VLLMKeyConfig should error — there is no provider-level fallback.
|
||||
provider := newTestVLLMProvider()
|
||||
|
||||
keys := []schemas.Key{
|
||||
{
|
||||
ID: "no-config-key",
|
||||
Value: schemas.EnvVar{Val: "api-key"},
|
||||
// No VLLMKeyConfig — should produce an error.
|
||||
},
|
||||
}
|
||||
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
request := &schemas.BifrostListModelsRequest{
|
||||
Provider: schemas.VLLM,
|
||||
Unfiltered: true,
|
||||
}
|
||||
|
||||
_, bifrostErr := provider.ListModels(ctx, keys, request)
|
||||
if bifrostErr == nil {
|
||||
t.Fatal("expected error for key without vllm_key_config.url, got nil")
|
||||
}
|
||||
}
|
||||
17
core/providers/vllm/models.go
Normal file
17
core/providers/vllm/models.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package vllm
|
||||
|
||||
// vLLMRerankRequest is the vLLM rerank request body.
|
||||
type vLLMRerankRequest struct {
|
||||
Model string `json:"model"`
|
||||
Query string `json:"query"`
|
||||
Documents []string `json:"documents"`
|
||||
TopN *int `json:"top_n,omitempty"`
|
||||
MaxTokensPerDoc *int `json:"max_tokens_per_doc,omitempty"`
|
||||
Priority *int `json:"priority,omitempty"`
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// GetExtraParams returns passthrough parameters for providerUtils.CheckContextAndGetRequestBody.
|
||||
func (r *vLLMRerankRequest) GetExtraParams() map[string]interface{} {
|
||||
return r.ExtraParams
|
||||
}
|
||||
149
core/providers/vllm/rerank.go
Normal file
149
core/providers/vllm/rerank.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package vllm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToVLLMRerankRequest converts a Bifrost rerank request to vLLM format.
|
||||
func ToVLLMRerankRequest(bifrostReq *schemas.BifrostRerankRequest) *vLLMRerankRequest {
|
||||
if bifrostReq == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
vllmReq := &vLLMRerankRequest{
|
||||
Model: bifrostReq.Model,
|
||||
Query: bifrostReq.Query,
|
||||
Documents: make([]string, len(bifrostReq.Documents)),
|
||||
}
|
||||
|
||||
for i, doc := range bifrostReq.Documents {
|
||||
vllmReq.Documents[i] = doc.Text
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
vllmReq.TopN = bifrostReq.Params.TopN
|
||||
vllmReq.MaxTokensPerDoc = bifrostReq.Params.MaxTokensPerDoc
|
||||
vllmReq.Priority = bifrostReq.Params.Priority
|
||||
vllmReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
|
||||
return vllmReq
|
||||
}
|
||||
|
||||
// ToBifrostRerankResponse converts a vLLM rerank response payload to Bifrost format.
|
||||
func ToBifrostRerankResponse(payload map[string]interface{}, documents []schemas.RerankDocument, returnDocuments bool) (*schemas.BifrostRerankResponse, error) {
|
||||
if payload == nil {
|
||||
return nil, fmt.Errorf("vllm rerank response is nil")
|
||||
}
|
||||
|
||||
response := &schemas.BifrostRerankResponse{}
|
||||
|
||||
if id, ok := schemas.SafeExtractString(payload["id"]); ok {
|
||||
response.ID = id
|
||||
}
|
||||
if model, ok := schemas.SafeExtractString(payload["model"]); ok {
|
||||
response.Model = model
|
||||
}
|
||||
if usage, ok := parseVLLMUsage(payload["usage"]); ok {
|
||||
response.Usage = usage
|
||||
}
|
||||
|
||||
resultsRaw := payload["results"]
|
||||
if resultsRaw == nil {
|
||||
return nil, fmt.Errorf("invalid vllm rerank response: missing results")
|
||||
}
|
||||
|
||||
resultItems, ok := resultsRaw.([]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid vllm rerank response: results must be an array")
|
||||
}
|
||||
|
||||
seenIndices := make(map[int]struct{}, len(resultItems))
|
||||
response.Results = make([]schemas.RerankResult, 0, len(resultItems))
|
||||
|
||||
for _, item := range resultItems {
|
||||
itemMap, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid vllm rerank response: result item must be an object")
|
||||
}
|
||||
|
||||
index, ok := schemas.SafeExtractInt(itemMap["index"])
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid vllm rerank response: result index is required")
|
||||
}
|
||||
if index < 0 || index >= len(documents) {
|
||||
return nil, fmt.Errorf("invalid vllm rerank response: result index %d out of range", index)
|
||||
}
|
||||
if _, exists := seenIndices[index]; exists {
|
||||
return nil, fmt.Errorf("invalid vllm rerank response: duplicate index %d", index)
|
||||
}
|
||||
seenIndices[index] = struct{}{}
|
||||
|
||||
relevanceScore, ok := schemas.SafeExtractFloat64(itemMap["relevance_score"])
|
||||
if !ok {
|
||||
relevanceScore, ok = schemas.SafeExtractFloat64(itemMap["score"])
|
||||
}
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid vllm rerank response: relevance_score/score is required")
|
||||
}
|
||||
|
||||
result := schemas.RerankResult{
|
||||
Index: index,
|
||||
RelevanceScore: relevanceScore,
|
||||
}
|
||||
|
||||
if returnDocuments {
|
||||
doc := documents[index]
|
||||
result.Document = &doc
|
||||
}
|
||||
|
||||
response.Results = append(response.Results, result)
|
||||
}
|
||||
|
||||
sort.SliceStable(response.Results, func(i, j int) bool {
|
||||
if response.Results[i].RelevanceScore == response.Results[j].RelevanceScore {
|
||||
return response.Results[i].Index < response.Results[j].Index
|
||||
}
|
||||
return response.Results[i].RelevanceScore > response.Results[j].RelevanceScore
|
||||
})
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func parseVLLMUsage(rawUsage interface{}) (*schemas.BifrostLLMUsage, bool) {
|
||||
usageMap, ok := rawUsage.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
promptTokens := 0
|
||||
if _, hasPromptTokens := usageMap["prompt_tokens"]; hasPromptTokens {
|
||||
promptTokens, _ = schemas.SafeExtractInt(usageMap["prompt_tokens"])
|
||||
} else {
|
||||
promptTokens, _ = schemas.SafeExtractInt(usageMap["input_tokens"])
|
||||
}
|
||||
|
||||
completionTokens := 0
|
||||
if _, hasCompletionTokens := usageMap["completion_tokens"]; hasCompletionTokens {
|
||||
completionTokens, _ = schemas.SafeExtractInt(usageMap["completion_tokens"])
|
||||
} else {
|
||||
completionTokens, _ = schemas.SafeExtractInt(usageMap["output_tokens"])
|
||||
}
|
||||
|
||||
totalTokens, ok := schemas.SafeExtractInt(usageMap["total_tokens"])
|
||||
if !ok {
|
||||
totalTokens = promptTokens + completionTokens
|
||||
}
|
||||
if promptTokens == 0 && completionTokens == 0 && totalTokens == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return &schemas.BifrostLLMUsage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: totalTokens,
|
||||
}, true
|
||||
}
|
||||
154
core/providers/vllm/rerank_test.go
Normal file
154
core/providers/vllm/rerank_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package vllm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRerankToVLLMRerankRequestNil(t *testing.T) {
|
||||
req := ToVLLMRerankRequest(nil)
|
||||
assert.Nil(t, req)
|
||||
}
|
||||
|
||||
func TestRerankToVLLMRerankRequest(t *testing.T) {
|
||||
topN := 2
|
||||
maxTokens := 128
|
||||
priority := 5
|
||||
|
||||
req := ToVLLMRerankRequest(&schemas.BifrostRerankRequest{
|
||||
Model: "BAAI/bge-reranker-v2-m3",
|
||||
Query: "what is machine learning",
|
||||
Documents: []schemas.RerankDocument{
|
||||
{Text: "Machine learning is a subset of AI."},
|
||||
{Text: "The weather is sunny."},
|
||||
},
|
||||
Params: &schemas.RerankParameters{
|
||||
TopN: &topN,
|
||||
MaxTokensPerDoc: &maxTokens,
|
||||
Priority: &priority,
|
||||
ExtraParams: map[string]interface{}{
|
||||
"user": "test-user",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.NotNil(t, req)
|
||||
assert.Equal(t, "BAAI/bge-reranker-v2-m3", req.Model)
|
||||
assert.Equal(t, "what is machine learning", req.Query)
|
||||
assert.Equal(t, []string{"Machine learning is a subset of AI.", "The weather is sunny."}, req.Documents)
|
||||
require.NotNil(t, req.TopN)
|
||||
assert.Equal(t, 2, *req.TopN)
|
||||
require.NotNil(t, req.MaxTokensPerDoc)
|
||||
assert.Equal(t, 128, *req.MaxTokensPerDoc)
|
||||
require.NotNil(t, req.Priority)
|
||||
assert.Equal(t, 5, *req.Priority)
|
||||
assert.Equal(t, "test-user", req.ExtraParams["user"])
|
||||
}
|
||||
|
||||
func TestRerankToBifrostRerankResponse(t *testing.T) {
|
||||
documents := []schemas.RerankDocument{
|
||||
{Text: "doc-0"},
|
||||
{Text: "doc-1"},
|
||||
{Text: "doc-2"},
|
||||
}
|
||||
|
||||
response, err := ToBifrostRerankResponse(map[string]interface{}{
|
||||
"id": "rerank-id",
|
||||
"model": "BAAI/bge-reranker-v2-m3",
|
||||
"usage": map[string]interface{}{
|
||||
"prompt_tokens": 10,
|
||||
"total_tokens": 10,
|
||||
},
|
||||
"results": []interface{}{
|
||||
map[string]interface{}{"index": 1, "relevance_score": 0.1},
|
||||
map[string]interface{}{"index": 0, "relevance_score": 0.9},
|
||||
},
|
||||
}, documents, true)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
assert.Equal(t, "rerank-id", response.ID)
|
||||
assert.Equal(t, "BAAI/bge-reranker-v2-m3", response.Model)
|
||||
require.NotNil(t, response.Usage)
|
||||
assert.Equal(t, 10, response.Usage.PromptTokens)
|
||||
assert.Equal(t, 10, response.Usage.TotalTokens)
|
||||
require.Len(t, response.Results, 2)
|
||||
assert.Equal(t, 0, response.Results[0].Index)
|
||||
assert.Equal(t, 0.9, response.Results[0].RelevanceScore)
|
||||
require.NotNil(t, response.Results[0].Document)
|
||||
assert.Equal(t, "doc-0", response.Results[0].Document.Text)
|
||||
assert.Equal(t, 1, response.Results[1].Index)
|
||||
assert.Equal(t, 0.1, response.Results[1].RelevanceScore)
|
||||
}
|
||||
|
||||
func TestRerankToBifrostRerankResponseDuplicateIndices(t *testing.T) {
|
||||
documents := []schemas.RerankDocument{
|
||||
{Text: "doc-0"},
|
||||
{Text: "doc-1"},
|
||||
}
|
||||
|
||||
_, err := ToBifrostRerankResponse(map[string]interface{}{
|
||||
"results": []interface{}{
|
||||
map[string]interface{}{"index": 0, "relevance_score": 0.9},
|
||||
map[string]interface{}{"index": 0, "relevance_score": 0.8},
|
||||
},
|
||||
}, documents, true)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "duplicate index")
|
||||
}
|
||||
|
||||
func TestRerankToBifrostRerankResponseOutOfRangeIndex(t *testing.T) {
|
||||
documents := []schemas.RerankDocument{
|
||||
{Text: "doc-0"},
|
||||
}
|
||||
|
||||
_, err := ToBifrostRerankResponse(map[string]interface{}{
|
||||
"results": []interface{}{
|
||||
map[string]interface{}{"index": 1, "relevance_score": 0.9},
|
||||
},
|
||||
}, documents, true)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "out of range")
|
||||
}
|
||||
|
||||
func TestRerankToBifrostRerankResponseEmptyResults(t *testing.T) {
|
||||
documents := []schemas.RerankDocument{
|
||||
{Text: "doc-0"},
|
||||
}
|
||||
|
||||
response, err := ToBifrostRerankResponse(map[string]interface{}{
|
||||
"results": []interface{}{},
|
||||
}, documents, false)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
assert.Len(t, response.Results, 0)
|
||||
}
|
||||
|
||||
func TestRerankToBifrostRerankResponseZeroRelevanceScoreDoesNotFallback(t *testing.T) {
|
||||
documents := []schemas.RerankDocument{
|
||||
{Text: "doc-0"},
|
||||
}
|
||||
|
||||
response, err := ToBifrostRerankResponse(map[string]interface{}{
|
||||
"results": []interface{}{
|
||||
map[string]interface{}{"index": 0, "relevance_score": 0.0, "score": 0.99},
|
||||
},
|
||||
}, documents, false)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
require.Len(t, response.Results, 1)
|
||||
assert.Equal(t, 0.0, response.Results[0].RelevanceScore)
|
||||
}
|
||||
|
||||
func TestRerankParseVLLMUsageZeroUsage(t *testing.T) {
|
||||
usage, ok := parseVLLMUsage(map[string]interface{}{})
|
||||
assert.False(t, ok)
|
||||
assert.Nil(t, usage)
|
||||
}
|
||||
42
core/providers/vllm/transcription.go
Normal file
42
core/providers/vllm/transcription.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package vllm
|
||||
|
||||
import (
|
||||
"github.com/bytedance/sonic"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// parseVLLMTranscriptionStreamChunk parses vLLM's transcription stream JSON and returns
|
||||
// a BifrostTranscriptionStreamResponse. It returns (nil, false) if the payload is not
|
||||
// valid vLLM format or has no content to emit.
|
||||
func parseVLLMTranscriptionStreamChunk(jsonData []byte) (*schemas.BifrostTranscriptionStreamResponse, bool) {
|
||||
var chunk vLLMTranscriptionStreamChunk
|
||||
response := &schemas.BifrostTranscriptionStreamResponse{}
|
||||
if err := sonic.Unmarshal(jsonData, &chunk); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
// Done chunk: has usage (e.g. final event)
|
||||
if chunk.Usage != nil {
|
||||
return &schemas.BifrostTranscriptionStreamResponse{
|
||||
Type: schemas.TranscriptionStreamResponseTypeDone,
|
||||
Usage: chunk.Usage,
|
||||
}, true
|
||||
}
|
||||
// Delta chunk: has choices[].delta.content
|
||||
if len(chunk.Choices) == 0 || chunk.Choices[0].Delta.Content == nil {
|
||||
return nil, false
|
||||
}
|
||||
if len(chunk.Choices) > 0 {
|
||||
reason := chunk.Choices[0].FinishReason
|
||||
if reason == nil && chunk.Choices[0].StopReason != nil {
|
||||
reason = chunk.Choices[0].StopReason
|
||||
}
|
||||
if reason != nil && *reason == "stop" {
|
||||
response.Text = *chunk.Choices[0].Delta.Content
|
||||
response.Type = schemas.TranscriptionStreamResponseTypeDone
|
||||
} else {
|
||||
response.Type = schemas.TranscriptionStreamResponseTypeDelta
|
||||
}
|
||||
response.Delta = chunk.Choices[0].Delta.Content
|
||||
}
|
||||
return response, true
|
||||
}
|
||||
19
core/providers/vllm/types.go
Normal file
19
core/providers/vllm/types.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package vllm
|
||||
|
||||
import (
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// vLLMTranscriptionStreamChunk represents a single transcription streaming chunk from vLLM.
|
||||
type vLLMTranscriptionStreamChunk struct {
|
||||
Object string `json:"object"`
|
||||
Choices []struct {
|
||||
Delta struct {
|
||||
Content *string `json:"content"`
|
||||
ReasoningContent *string `json:"reasoning_content"`
|
||||
} `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason,omitempty"`
|
||||
StopReason *string `json:"stop_reason,omitempty"`
|
||||
} `json:"choices"`
|
||||
Usage *schemas.TranscriptionUsage `json:"usage,omitempty"`
|
||||
}
|
||||
19
core/providers/vllm/utils.go
Normal file
19
core/providers/vllm/utils.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package vllm
|
||||
|
||||
import (
|
||||
"github.com/bytedance/sonic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func HandleVLLMResponse[T any](responseBody []byte, response *T, requestBody []byte, sendBackRawRequest bool, sendBackRawResponse bool) (rawRequest interface{}, rawResponse interface{}, bifrostErr *schemas.BifrostError) {
|
||||
var errorResp schemas.BifrostError
|
||||
rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, response, requestBody, sendBackRawRequest, sendBackRawResponse)
|
||||
if bifrostErr != nil {
|
||||
return rawRequest, rawResponse, bifrostErr
|
||||
}
|
||||
if err := sonic.Unmarshal(responseBody, &errorResp); err == nil && errorResp.Error != nil && errorResp.Error.Message != "" {
|
||||
return rawRequest, rawResponse, &errorResp
|
||||
}
|
||||
return rawRequest, rawResponse, nil
|
||||
}
|
||||
801
core/providers/vllm/vllm.go
Normal file
801
core/providers/vllm/vllm.go
Normal file
@@ -0,0 +1,801 @@
|
||||
// Package vllm implements the vLLM LLM provider (OpenAI-compatible).
|
||||
package vllm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/openai"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// VLLMProvider implements the Provider interface for vLLM's OpenAI-compatible API.
|
||||
type VLLMProvider struct {
|
||||
logger schemas.Logger // Logger for provider operations
|
||||
client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response)
|
||||
streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader)
|
||||
networkConfig schemas.NetworkConfig // Network configuration including extra headers
|
||||
sendBackRawRequest bool // Whether to include raw request in BifrostResponse
|
||||
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
|
||||
}
|
||||
|
||||
// NewVLLMProvider creates a new vLLM provider instance.
|
||||
func NewVLLMProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*VLLMProvider, error) {
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
requestTimeout := time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)
|
||||
client := &fasthttp.Client{
|
||||
ReadTimeout: requestTimeout,
|
||||
WriteTimeout: requestTimeout,
|
||||
MaxConnsPerHost: config.NetworkConfig.MaxConnsPerHost,
|
||||
MaxIdleConnDuration: 30 * time.Second,
|
||||
MaxConnWaitTimeout: requestTimeout,
|
||||
MaxConnDuration: time.Second * time.Duration(schemas.DefaultMaxConnDurationInSeconds),
|
||||
ConnPoolStrategy: fasthttp.FIFO,
|
||||
}
|
||||
|
||||
client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger)
|
||||
client = providerUtils.ConfigureDialer(client)
|
||||
client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger)
|
||||
streamingClient := providerUtils.BuildStreamingClient(client)
|
||||
config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/")
|
||||
|
||||
// BaseURL is optional when keys have vllm_key_config with per-key URLs
|
||||
return &VLLMProvider{
|
||||
logger: logger,
|
||||
client: client,
|
||||
streamingClient: streamingClient,
|
||||
networkConfig: config.NetworkConfig,
|
||||
sendBackRawRequest: config.SendBackRawRequest,
|
||||
sendBackRawResponse: config.SendBackRawResponse,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetProviderKey returns the provider identifier for vLLM.
|
||||
func (provider *VLLMProvider) GetProviderKey() schemas.ModelProvider {
|
||||
return schemas.VLLM
|
||||
}
|
||||
|
||||
// getBaseURL resolves the base URL for a request from the per-key vllm_key_config.
|
||||
// Each vLLM key must have its own URL configured — there is no provider-level fallback.
|
||||
func (provider *VLLMProvider) getBaseURL(key schemas.Key) string {
|
||||
if key.VLLMKeyConfig != nil && key.VLLMKeyConfig.URL.GetValue() != "" {
|
||||
return strings.TrimRight(key.VLLMKeyConfig.URL.GetValue(), "/")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// baseURLOrError returns the resolved base URL or a BifrostError when none is configured.
|
||||
func (provider *VLLMProvider) baseURLOrError(key schemas.Key) (string, *schemas.BifrostError) {
|
||||
u := provider.getBaseURL(key)
|
||||
if u == "" {
|
||||
return "", providerUtils.NewBifrostOperationError(
|
||||
"no base URL configured: set vllm_key_config.url on the key",
|
||||
nil)
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// listModelsByKey performs a list models request for a single vLLM key,
|
||||
// resolving the per-key URL so each backend is queried individually.
|
||||
func (provider *VLLMProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
||||
baseURL, bifrostErr := provider.baseURLOrError(key)
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
url := baseURL + providerUtils.GetPathFromContext(ctx, "/v1/models")
|
||||
return openai.ListModelsByKey(
|
||||
ctx,
|
||||
provider.client,
|
||||
url,
|
||||
key,
|
||||
request.Unfiltered,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
provider.GetProviderKey(),
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
)
|
||||
}
|
||||
|
||||
// ListModels performs a list models request to vLLM's API.
|
||||
// Requests are made concurrently per key so that each backend is queried
|
||||
// with its own URL (from vllm_key_config).
|
||||
func (provider *VLLMProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
||||
return providerUtils.HandleMultipleListModelsRequests(
|
||||
ctx,
|
||||
keys,
|
||||
request,
|
||||
provider.listModelsByKey,
|
||||
)
|
||||
}
|
||||
|
||||
// TextCompletion performs a text completion request to vLLM's API.
|
||||
func (provider *VLLMProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
|
||||
ctx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
|
||||
baseURL, bifrostErr := provider.baseURLOrError(key)
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
return openai.HandleOpenAITextCompletionRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
baseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"),
|
||||
request,
|
||||
key,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
provider.GetProviderKey(),
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
HandleVLLMResponse,
|
||||
nil,
|
||||
provider.logger,
|
||||
)
|
||||
}
|
||||
|
||||
// TextCompletionStream performs a streaming text completion request to vLLM's API.
|
||||
func (provider *VLLMProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
ctx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
|
||||
baseURL, bifrostErr := provider.baseURLOrError(key)
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
var authHeader map[string]string
|
||||
if key.Value.GetValue() != "" {
|
||||
authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()}
|
||||
}
|
||||
return openai.HandleOpenAITextCompletionStreaming(
|
||||
ctx,
|
||||
provider.streamingClient,
|
||||
baseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"),
|
||||
request,
|
||||
authHeader,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
provider.GetProviderKey(),
|
||||
nil,
|
||||
postHookRunner,
|
||||
HandleVLLMResponse,
|
||||
nil,
|
||||
provider.logger,
|
||||
postHookSpanFinalizer,
|
||||
)
|
||||
}
|
||||
|
||||
// ChatCompletion performs a chat completion request to vLLM's API.
|
||||
func (provider *VLLMProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
ctx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
|
||||
baseURL, bifrostErr := provider.baseURLOrError(key)
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
return openai.HandleOpenAIChatCompletionRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
baseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
|
||||
request,
|
||||
key,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
provider.GetProviderKey(),
|
||||
HandleVLLMResponse,
|
||||
nil,
|
||||
provider.logger,
|
||||
)
|
||||
}
|
||||
|
||||
// ChatCompletionStream performs a streaming chat completion request to vLLM's API.
|
||||
func (provider *VLLMProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
ctx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
|
||||
baseURL, bifrostErr := provider.baseURLOrError(key)
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
var authHeader map[string]string
|
||||
if key.Value.GetValue() != "" {
|
||||
authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()}
|
||||
}
|
||||
return openai.HandleOpenAIChatCompletionStreaming(
|
||||
ctx,
|
||||
provider.streamingClient,
|
||||
baseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
|
||||
request,
|
||||
authHeader,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
provider.GetProviderKey(),
|
||||
postHookRunner,
|
||||
nil,
|
||||
HandleVLLMResponse,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
postHookSpanFinalizer,
|
||||
)
|
||||
}
|
||||
|
||||
// Embedding performs an embedding request to vLLM's API.
|
||||
func (provider *VLLMProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
|
||||
baseURL, bifrostErr := provider.baseURLOrError(key)
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
return openai.HandleOpenAIEmbeddingRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
baseURL+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"),
|
||||
request,
|
||||
key,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
provider.GetProviderKey(),
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
HandleVLLMResponse,
|
||||
provider.logger,
|
||||
)
|
||||
}
|
||||
|
||||
// Responses performs a responses request to vLLM's API (via chat completion).
|
||||
func (provider *VLLMProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response := chatResponse.ToBifrostResponsesResponse()
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// ResponsesStream performs a streaming responses request to vLLM's API (via chat completion stream).
|
||||
func (provider *VLLMProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
|
||||
return provider.ChatCompletionStream(
|
||||
ctx,
|
||||
postHookRunner,
|
||||
postHookSpanFinalizer,
|
||||
key,
|
||||
request.ToChatRequest(),
|
||||
)
|
||||
}
|
||||
|
||||
// Speech is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
func isRerankFallbackStatus(statusCode int) bool {
|
||||
// vLLM deployments may return 501 for unimplemented routes.
|
||||
// We fallback on 501 in addition to 404/405 for compatibility.
|
||||
return statusCode == fasthttp.StatusNotFound ||
|
||||
statusCode == fasthttp.StatusMethodNotAllowed ||
|
||||
statusCode == fasthttp.StatusNotImplemented
|
||||
}
|
||||
|
||||
func (provider *VLLMProvider) callVLLMRerankEndpoint(
|
||||
ctx *schemas.BifrostContext,
|
||||
key schemas.Key,
|
||||
request *schemas.BifrostRerankRequest,
|
||||
endpointPath string,
|
||||
jsonData []byte,
|
||||
) (map[string]interface{}, interface{}, interface{}, []byte, int, time.Duration, *schemas.BifrostError) {
|
||||
baseURL, bifrostErr := provider.baseURLOrError(key)
|
||||
if bifrostErr != nil {
|
||||
return nil, nil, nil, nil, 0, 0, bifrostErr
|
||||
}
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
||||
|
||||
req.SetRequestURI(baseURL + endpointPath)
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.Header.SetContentType("application/json")
|
||||
|
||||
if key.Value.GetValue() != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
||||
}
|
||||
if !providerUtils.ApplyLargePayloadRequestBodyWithModelNormalization(ctx, req, schemas.VLLM) {
|
||||
req.SetBody(jsonData)
|
||||
}
|
||||
|
||||
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
||||
defer wait()
|
||||
if bifrostErr != nil {
|
||||
return nil, nil, nil, nil, 0, latency, bifrostErr
|
||||
}
|
||||
|
||||
statusCode := resp.StatusCode()
|
||||
if statusCode != fasthttp.StatusOK {
|
||||
rawErrBody := append([]byte(nil), resp.Body()...)
|
||||
return nil, nil, nil, rawErrBody, statusCode, latency, openai.ParseOpenAIError(resp)
|
||||
}
|
||||
|
||||
body, err := providerUtils.CheckAndDecodeBody(resp)
|
||||
if err != nil {
|
||||
rawErrBody := append([]byte(nil), resp.Body()...)
|
||||
return nil, nil, nil, rawErrBody, statusCode, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
||||
}
|
||||
|
||||
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
||||
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
||||
|
||||
responsePayload := make(map[string]interface{})
|
||||
rawRequest, rawResponse, bifrostErr := HandleVLLMResponse(body, &responsePayload, jsonData, sendBackRawRequest, sendBackRawResponse)
|
||||
if bifrostErr != nil {
|
||||
return nil, nil, nil, body, statusCode, latency, bifrostErr
|
||||
}
|
||||
|
||||
return responsePayload, rawRequest, rawResponse, body, statusCode, latency, nil
|
||||
}
|
||||
|
||||
// Rerank performs a rerank request to vLLM's API.
|
||||
func (provider *VLLMProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
|
||||
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
||||
ctx,
|
||||
request,
|
||||
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
||||
return ToVLLMRerankRequest(request), nil
|
||||
})
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
resolvedPath := providerUtils.GetPathFromContext(ctx, "")
|
||||
hasPathOverride := resolvedPath != ""
|
||||
if !hasPathOverride {
|
||||
resolvedPath = "/v1/rerank"
|
||||
} else if !strings.HasPrefix(resolvedPath, "/") {
|
||||
resolvedPath = "/" + resolvedPath
|
||||
}
|
||||
|
||||
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
||||
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
||||
|
||||
responsePayload, rawRequest, rawResponse, responseBody, statusCode, latency, bifrostErr := provider.callVLLMRerankEndpoint(ctx, key, request, resolvedPath, jsonData)
|
||||
if bifrostErr != nil && !hasPathOverride && isRerankFallbackStatus(statusCode) {
|
||||
var fallbackLatency time.Duration
|
||||
responsePayload, rawRequest, rawResponse, responseBody, statusCode, fallbackLatency, bifrostErr = provider.callVLLMRerankEndpoint(ctx, key, request, "/rerank", jsonData)
|
||||
latency += fallbackLatency
|
||||
}
|
||||
if bifrostErr != nil {
|
||||
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, sendBackRawRequest, sendBackRawResponse)
|
||||
}
|
||||
|
||||
returnDocuments := request.Params != nil && request.Params.ReturnDocuments != nil && *request.Params.ReturnDocuments
|
||||
bifrostResponse, err := ToBifrostRerankResponse(responsePayload, request.Documents, returnDocuments)
|
||||
if err != nil {
|
||||
return nil, providerUtils.EnrichError(
|
||||
ctx,
|
||||
providerUtils.NewBifrostOperationError("error converting rerank response", err),
|
||||
jsonData,
|
||||
responseBody,
|
||||
sendBackRawRequest,
|
||||
sendBackRawResponse,
|
||||
)
|
||||
}
|
||||
|
||||
// Keep requested model as the canonical model in Bifrost response.
|
||||
bifrostResponse.Model = request.Model
|
||||
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
|
||||
|
||||
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
||||
bifrostResponse.ExtraFields.RawRequest = rawRequest
|
||||
}
|
||||
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
||||
bifrostResponse.ExtraFields.RawResponse = rawResponse
|
||||
}
|
||||
|
||||
return bifrostResponse, nil
|
||||
}
|
||||
|
||||
// OCR is not supported by the Vllm provider.
|
||||
func (provider *VLLMProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.OCRRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// SpeechStream is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Transcription performs a transcription request to vLLM's API.
|
||||
func (provider *VLLMProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
|
||||
baseURL, bifrostErr := provider.baseURLOrError(key)
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
return openai.HandleOpenAITranscriptionRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
baseURL+providerUtils.GetPathFromContext(ctx, "/v1/audio/transcriptions"),
|
||||
request,
|
||||
key,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
provider.GetProviderKey(),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
HandleVLLMResponse,
|
||||
provider.logger,
|
||||
)
|
||||
}
|
||||
|
||||
// TranscriptionStream performs a streaming transcription request to vLLM's API.
|
||||
func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
baseURL, bifrostErr := provider.baseURLOrError(key)
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
{
|
||||
logger := provider.logger
|
||||
providerName := provider.GetProviderKey()
|
||||
// Use centralized converter
|
||||
reqBody := openai.ToOpenAITranscriptionRequest(request)
|
||||
if reqBody == nil {
|
||||
return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil)
|
||||
}
|
||||
reqBody.Stream = schemas.Ptr(true)
|
||||
|
||||
// Create multipart form
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
|
||||
if bifrostErr := openai.ParseTranscriptionFormDataBodyFromRequest(writer, reqBody, providerName); bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
// Prepare OpenAI headers
|
||||
headers := map[string]string{
|
||||
"Content-Type": writer.FormDataContentType(),
|
||||
"Accept": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
}
|
||||
if key.Value.GetValue() != "" {
|
||||
headers["Authorization"] = "Bearer " + key.Value.GetValue()
|
||||
}
|
||||
|
||||
// Create HTTP request for streaming
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
resp.StreamBody = true
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
|
||||
// Set any extra headers from network config
|
||||
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
||||
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.SetRequestURI(baseURL + providerUtils.GetPathFromContext(ctx, "/v1/audio/transcriptions"))
|
||||
|
||||
// Set headers
|
||||
for key, value := range headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
req.SetBody(body.Bytes())
|
||||
|
||||
// Make the request
|
||||
err := provider.streamingClient.Do(req, resp)
|
||||
if err != nil {
|
||||
defer providerUtils.ReleaseStreamingResponse(resp)
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Type: schemas.Ptr(schemas.RequestCancelled),
|
||||
Message: schemas.ErrRequestCancelled,
|
||||
Error: err,
|
||||
},
|
||||
}
|
||||
}
|
||||
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
|
||||
}
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err)
|
||||
}
|
||||
|
||||
// Store provider response headers in context before status check so error responses also forward them
|
||||
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
||||
|
||||
// Check for HTTP errors
|
||||
if resp.StatusCode() != fasthttp.StatusOK {
|
||||
defer providerUtils.ReleaseStreamingResponse(resp)
|
||||
return nil, openai.ParseOpenAIError(resp)
|
||||
}
|
||||
|
||||
// Large payload streaming passthrough — pipe raw upstream SSE to client
|
||||
if providerUtils.SetupStreamingPassthrough(ctx, resp) {
|
||||
responseChan := make(chan *schemas.BifrostStreamChunk)
|
||||
close(responseChan)
|
||||
return responseChan, nil
|
||||
}
|
||||
|
||||
// Create response channel
|
||||
responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
|
||||
|
||||
providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds)
|
||||
|
||||
// Start streaming in a goroutine
|
||||
go func() {
|
||||
defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
|
||||
defer func() {
|
||||
if ctx.Err() == context.Canceled {
|
||||
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer)
|
||||
} else if ctx.Err() == context.DeadlineExceeded {
|
||||
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer)
|
||||
}
|
||||
close(responseChan)
|
||||
}()
|
||||
defer providerUtils.ReleaseStreamingResponse(resp)
|
||||
// Decompress gzip-encoded streams transparently (no-op for non-gzip)
|
||||
reader, releaseGzip := providerUtils.DecompressStreamBody(resp)
|
||||
defer releaseGzip()
|
||||
|
||||
// Wrap reader with idle timeout to detect stalled streams.
|
||||
reader, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(reader, resp.BodyStream(), providerUtils.GetStreamIdleTimeout(ctx))
|
||||
defer stopIdleTimeout()
|
||||
|
||||
// Setup cancellation handler to close the raw network stream on ctx cancellation,
|
||||
// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
|
||||
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger)
|
||||
defer stopCancellation()
|
||||
|
||||
sseReader := providerUtils.GetSSEDataReader(ctx, reader)
|
||||
chunkIndex := -1
|
||||
|
||||
startTime := time.Now()
|
||||
lastChunkTime := startTime
|
||||
var fullTranscriptionText strings.Builder
|
||||
|
||||
for {
|
||||
// If context was cancelled/timed out, let defer handle it
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
dataBytes, readErr := sseReader.ReadDataLine()
|
||||
if readErr != nil {
|
||||
if readErr != io.EOF {
|
||||
// If context was cancelled/timed out, let defer handle it
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
||||
logger.Warn("Error reading stream: %v", readErr)
|
||||
providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
jsonData := string(dataBytes)
|
||||
|
||||
// Skip empty data
|
||||
if strings.TrimSpace(jsonData) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var response schemas.BifrostTranscriptionStreamResponse
|
||||
var bifrostErr *schemas.BifrostError
|
||||
|
||||
_, _, bifrostErr = HandleVLLMResponse(dataBytes, &response, nil, false, false)
|
||||
if bifrostErr != nil {
|
||||
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
||||
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, body.Bytes(), dataBytes, false, providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)), responseChan, logger, postHookSpanFinalizer)
|
||||
return
|
||||
}
|
||||
|
||||
customChunk, ok := parseVLLMTranscriptionStreamChunk(dataBytes)
|
||||
if !ok || customChunk == nil {
|
||||
logger.Warn("customChunkParser returned no chunk")
|
||||
continue
|
||||
}
|
||||
response = *customChunk
|
||||
|
||||
chunkIndex++
|
||||
if response.Delta != nil {
|
||||
fullTranscriptionText.WriteString(*response.Delta)
|
||||
}
|
||||
|
||||
response.ExtraFields = schemas.BifrostResponseExtraFields{
|
||||
ChunkIndex: chunkIndex,
|
||||
Latency: time.Since(lastChunkTime).Milliseconds(),
|
||||
}
|
||||
lastChunkTime = time.Now()
|
||||
|
||||
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
||||
response.ExtraFields.RawResponse = jsonData
|
||||
}
|
||||
if response.Usage != nil || response.Type == schemas.TranscriptionStreamResponseTypeDone {
|
||||
response.ExtraFields.Latency = time.Since(startTime).Milliseconds()
|
||||
response.Text = fullTranscriptionText.String()
|
||||
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
||||
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, &response, nil), responseChan, postHookSpanFinalizer)
|
||||
return
|
||||
}
|
||||
|
||||
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, &response, nil), responseChan, postHookSpanFinalizer)
|
||||
}
|
||||
}()
|
||||
|
||||
return responseChan, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ImageGeneration is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageGenerationStream is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageEdit is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageEditStream is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageVariation is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) ImageVariation(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageVariationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoGeneration is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) VideoGeneration(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoGenerationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoRetrieve is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) VideoRetrieve(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoDownload is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) VideoDownload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDownloadRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoDelete is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) VideoDelete(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoList is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) VideoList(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoRemix is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileUpload is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileList is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileRetrieve is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileDelete is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileContent is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchCreate is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchList is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchRetrieve is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchCancel is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchDelete is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) BatchDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchResults is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// CountTokens is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerCreate is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) ContainerCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerList is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) ContainerList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerRetrieve is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) ContainerRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerDelete is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) ContainerDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileCreate is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) ContainerFileCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileList is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) ContainerFileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileRetrieve is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) ContainerFileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileContent is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) ContainerFileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileContentRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileDelete is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) ContainerFileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Passthrough is not supported by the vLLM provider.
|
||||
func (provider *VLLMProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
func (provider *VLLMProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
83
core/providers/vllm/vllm_test.go
Normal file
83
core/providers/vllm/vllm_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package vllm_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/internal/llmtests"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestVLLM(t *testing.T) {
|
||||
t.Parallel()
|
||||
baseURL := strings.TrimSpace(os.Getenv("VLLM_BASE_URL"))
|
||||
if baseURL == "" {
|
||||
t.Skip("Skipping vLLM tests because VLLM_BASE_URL is not set")
|
||||
}
|
||||
|
||||
client, ctx, cancel, err := llmtests.SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
chatModel := getEnvWithDefault("VLLM_CHAT_MODEL", "Qwen/Qwen3-0.6B")
|
||||
textModel := getEnvWithDefault("VLLM_TEXT_MODEL", "Qwen/Qwen3-0.6B")
|
||||
reasoningModel := getEnvWithDefault("VLLM_REASONING_MODEL", "Qwen/Qwen3-0.6B")
|
||||
embeddingModel := getEnvWithDefault("VLLM_EMBEDDING_MODEL", "Qwen3-Embedding-0.6B")
|
||||
rerankModel := strings.TrimSpace(os.Getenv("VLLM_RERANK_MODEL"))
|
||||
|
||||
testConfig := llmtests.ComprehensiveTestConfig{
|
||||
Provider: schemas.VLLM,
|
||||
ChatModel: chatModel,
|
||||
TextModel: textModel,
|
||||
ReasoningModel: reasoningModel,
|
||||
EmbeddingModel: embeddingModel,
|
||||
RerankModel: rerankModel,
|
||||
Scenarios: llmtests.TestScenarios{
|
||||
TextCompletion: true,
|
||||
TextCompletionStream: true,
|
||||
SimpleChat: true,
|
||||
CompletionStream: true,
|
||||
MultiTurnConversation: true,
|
||||
ToolCalls: true,
|
||||
ToolCallsStreaming: true,
|
||||
MultipleToolCalls: true,
|
||||
MultipleToolCallsStreaming: true,
|
||||
End2EndToolCalling: true,
|
||||
AutomaticFunctionCall: true,
|
||||
ImageURL: false,
|
||||
ImageBase64: false,
|
||||
MultipleImages: false,
|
||||
CompleteEnd2End: true,
|
||||
Embedding: true,
|
||||
Rerank: rerankModel != "",
|
||||
ListModels: true,
|
||||
Reasoning: true,
|
||||
PassThroughExtraParams: true,
|
||||
SpeechSynthesis: false,
|
||||
SpeechSynthesisStream: false,
|
||||
Transcription: true,
|
||||
TranscriptionStream: false,
|
||||
ImageGeneration: false,
|
||||
ImageGenerationStream: false,
|
||||
ImageEdit: false,
|
||||
ImageEditStream: false,
|
||||
ImageVariation: false,
|
||||
ImageVariationStream: false,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("VLLMTests", func(t *testing.T) {
|
||||
llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
|
||||
})
|
||||
}
|
||||
|
||||
func getEnvWithDefault(key, def string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return def
|
||||
}
|
||||
Reference in New Issue
Block a user