first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,144 @@
// Package huggingface provides a HuggingFace chat provider.
package huggingface
import (
"fmt"
"github.com/bytedance/sonic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// sanitizeMessagesForHuggingFace returns a copy of messages that keeps only the
// fields HuggingFace's OpenAI-compatible chat API accepts. Name, Role, Content
// and ChatToolMessage are carried over unchanged; the assistant payload is
// reduced to its ToolCalls, so fields such as reasoning_details, reasoning,
// annotations, audio, and refusal are dropped. The result has the same length
// and order as the input.
func sanitizeMessagesForHuggingFace(messages []schemas.ChatMessage) []schemas.ChatMessage {
	sanitized := make([]schemas.ChatMessage, len(messages))
	for idx := range messages {
		src := &messages[idx]
		dst := &sanitized[idx]

		dst.Name = src.Name
		dst.Role = src.Role
		dst.Content = src.Content
		dst.ChatToolMessage = src.ChatToolMessage

		// An assistant payload without tool calls has nothing left to forward.
		assistant := src.ChatAssistantMessage
		if assistant == nil || len(assistant.ToolCalls) == 0 {
			continue
		}
		dst.ChatAssistantMessage = &schemas.ChatAssistantMessage{
			ToolCalls: assistant.ToolCalls,
		}
	}
	return sanitized
}
// ToHuggingFaceChatCompletionRequest builds the HuggingFace chat completion
// payload for a Bifrost chat request.
//
// It returns (nil, nil) when bifrostReq or its Input is nil. Messages go through
// sanitizeMessagesForHuggingFace, and optional parameters are copied only when
// they are set. An error is returned only when a map-shaped response format
// carries a json_schema that cannot be re-encoded into HuggingFaceJSONSchema;
// any other response format value that fails conversion leaves ResponseFormat nil.
func ToHuggingFaceChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) (*HuggingFaceChatRequest, error) {
	if bifrostReq == nil || bifrostReq.Input == nil {
		return nil, nil
	}

	// Messages are sanitized to strip fields such as reasoning_details.
	hfReq := &HuggingFaceChatRequest{
		Messages: sanitizeMessagesForHuggingFace(bifrostReq.Input),
		Model:    bifrostReq.Model,
	}

	params := bifrostReq.Params
	if params == nil {
		return hfReq, nil
	}

	// Sampling and length controls.
	if params.FrequencyPenalty != nil {
		hfReq.FrequencyPenalty = params.FrequencyPenalty
	}
	if params.LogProbs != nil {
		hfReq.Logprobs = params.LogProbs
	}
	if params.MaxCompletionTokens != nil {
		hfReq.MaxTokens = params.MaxCompletionTokens
	}
	if params.PresencePenalty != nil {
		hfReq.PresencePenalty = params.PresencePenalty
	}
	if params.Seed != nil {
		hfReq.Seed = params.Seed
	}
	if len(params.Stop) > 0 {
		hfReq.Stop = params.Stop
	}
	if params.Temperature != nil {
		hfReq.Temperature = params.Temperature
	}
	if params.TopLogProbs != nil {
		hfReq.TopLogprobs = params.TopLogProbs
	}
	if params.TopP != nil {
		hfReq.TopP = params.TopP
	}

	// Response format: plain maps are read field by field so that only the
	// json_schema sub-document needs a marshal/unmarshal round-trip.
	if params.ResponseFormat != nil {
		var format *HuggingFaceResponseFormat
		switch raw := (*params.ResponseFormat).(type) {
		case map[string]interface{}:
			format = &HuggingFaceResponseFormat{}
			if kind, isString := raw["type"].(string); isString {
				format.Type = kind
			}
			if schemaValue, present := raw["json_schema"]; present {
				encoded, err := providerUtils.MarshalSorted(schemaValue)
				if err != nil {
					return nil, fmt.Errorf("failed to marshal json_schema: %w", err)
				}
				var schema HuggingFaceJSONSchema
				if err := sonic.Unmarshal(encoded, &schema); err != nil {
					return nil, fmt.Errorf("failed to unmarshal json_schema: %w", err)
				}
				format.JSONSchema = &schema
			}
		default:
			// Best effort for any other value: a failed conversion is not an
			// error, the request simply goes out without a response format.
			if converted, err := schemas.ConvertViaJSON[HuggingFaceResponseFormat](raw); err == nil {
				format = &converted
			}
		}
		hfReq.ResponseFormat = format
	}

	// Only IncludeUsage is forwarded from the stream options.
	if params.StreamOptions != nil {
		hfReq.StreamOptions = &schemas.ChatStreamOptions{
			IncludeUsage: params.StreamOptions.IncludeUsage,
		}
	}

	hfReq.Tools = params.Tools

	// Tool choice is either one of the string modes or a named function.
	if choice := params.ToolChoice; choice != nil {
		hfToolChoice := &HuggingFaceToolChoice{}
		switch {
		case choice.ChatToolChoiceStr != nil:
			// Unrecognised mode strings leave the tool choice unset.
			switch *choice.ChatToolChoiceStr {
			case "auto":
				mode := EnumStringTypeAuto
				hfToolChoice.EnumValue = &mode
			case "none":
				mode := EnumStringTypeNone
				hfToolChoice.EnumValue = &mode
			case "required":
				mode := EnumStringTypeRequired
				hfToolChoice.EnumValue = &mode
			}
		case choice.ChatToolChoiceStruct != nil:
			named := choice.ChatToolChoiceStruct
			if named.Type == schemas.ChatToolChoiceTypeFunction && named.Function != nil {
				hfToolChoice.Function = &schemas.ChatToolChoiceFunction{
					Name: named.Function.Name,
				}
			}
		}
		if hfToolChoice.EnumValue != nil || hfToolChoice.Function != nil {
			hfReq.ToolChoice = hfToolChoice
		}
	}

	hfReq.ExtraParams = params.ExtraParams
	return hfReq, nil
}

View File

@@ -0,0 +1,132 @@
package huggingface
import (
"encoding/json"
"testing"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestToHuggingFaceChatCompletionRequest_ResponseFormat exercises the response
// format handling of ToHuggingFaceChatCompletionRequest: an absent format, plain
// map input with and without a json_schema, a struct value that goes through the
// ConvertViaJSON fallback, and a value that cannot be converted at all.
func TestToHuggingFaceChatCompletionRequest_ResponseFormat(t *testing.T) {
	// boxed wraps v in the *interface{} shape used by ChatParameters.ResponseFormat.
	boxed := func(v interface{}) *interface{} { return &v }

	// newRequest builds a minimal single-message chat request carrying rf.
	newRequest := func(rf *interface{}) *schemas.BifrostChatRequest {
		greeting := schemas.ChatMessage{
			Role:    schemas.ChatMessageRoleUser,
			Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hello")},
		}
		return &schemas.BifrostChatRequest{
			Model:  "test-model",
			Input:  []schemas.ChatMessage{greeting},
			Params: &schemas.ChatParameters{ResponseFormat: rf},
		}
	}

	type testCase struct {
		name           string
		responseFormat *interface{}
		wantErr        bool
		validate       func(t *testing.T, result *HuggingFaceChatRequest)
	}

	cases := []testCase{
		{
			name:           "nil_response_format",
			responseFormat: nil,
			validate: func(t *testing.T, result *HuggingFaceChatRequest) {
				assert.Nil(t, result.ResponseFormat)
			},
		},
		{
			name:           "map_type_only",
			responseFormat: boxed(map[string]interface{}{"type": "json_object"}),
			validate: func(t *testing.T, result *HuggingFaceChatRequest) {
				require.NotNil(t, result.ResponseFormat)
				assert.Equal(t, "json_object", result.ResponseFormat.Type)
				assert.Nil(t, result.ResponseFormat.JSONSchema)
			},
		},
		{
			name: "map_with_json_schema",
			responseFormat: boxed(map[string]interface{}{
				"type": "json_schema",
				"json_schema": map[string]interface{}{
					"name":        "my_schema",
					"description": "A test schema",
					"strict":      true,
					"schema": map[string]interface{}{
						"type": "object",
						"properties": map[string]interface{}{
							"answer": map[string]interface{}{"type": "string"},
						},
					},
				},
			}),
			validate: func(t *testing.T, result *HuggingFaceChatRequest) {
				require.NotNil(t, result.ResponseFormat)
				assert.Equal(t, "json_schema", result.ResponseFormat.Type)

				jsonSchema := result.ResponseFormat.JSONSchema
				require.NotNil(t, jsonSchema)
				assert.Equal(t, "my_schema", jsonSchema.Name)
				assert.Equal(t, "A test schema", jsonSchema.Description)
				require.NotNil(t, jsonSchema.Strict)
				assert.True(t, *jsonSchema.Strict)
				require.NotNil(t, jsonSchema.Schema)

				// The nested schema document must survive the round-trip intact.
				var decoded map[string]interface{}
				err := json.Unmarshal(jsonSchema.Schema, &decoded)
				require.NoError(t, err)
				assert.Equal(t, "object", decoded["type"])
				properties, ok := decoded["properties"].(map[string]interface{})
				require.True(t, ok)
				assert.Contains(t, properties, "answer")
			},
		},
		{
			name: "struct_fallback_via_convert",
			responseFormat: boxed(HuggingFaceResponseFormat{
				Type: "json_schema",
				JSONSchema: &HuggingFaceJSONSchema{
					Name:   "fallback_schema",
					Strict: schemas.Ptr(true),
				},
			}),
			validate: func(t *testing.T, result *HuggingFaceChatRequest) {
				require.NotNil(t, result.ResponseFormat, "ResponseFormat should not be nil — ConvertViaJSON fallback must handle struct values")
				assert.Equal(t, "json_schema", result.ResponseFormat.Type)

				jsonSchema := result.ResponseFormat.JSONSchema
				require.NotNil(t, jsonSchema)
				assert.Equal(t, "fallback_schema", jsonSchema.Name)
				require.NotNil(t, jsonSchema.Strict)
				assert.True(t, *jsonSchema.Strict)
			},
		},
		{
			name:           "inconvertible_value_graceful_nil",
			responseFormat: boxed(42),
			validate: func(t *testing.T, result *HuggingFaceChatRequest) {
				assert.Nil(t, result.ResponseFormat, "inconvertible value should gracefully result in nil ResponseFormat")
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			result, err := ToHuggingFaceChatCompletionRequest(newRequest(tc.responseFormat))
			if tc.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			require.NotNil(t, result)
			tc.validate(t, result)
		})
	}
}

View File

@@ -0,0 +1,165 @@
package huggingface
import (
"fmt"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
)
// ToHuggingFaceEmbeddingRequest converts a Bifrost embedding request to HuggingFace format.
//
// The model string is split with splitIntoModelProvider. For every inference
// provider other than hfInference the request carries the model name and the
// provider, and the input is placed in Input; for hfInference the input is
// placed in Inputs instead. The HuggingFace-specific options normalize,
// prompt_name, truncate and truncation_direction are lifted out of ExtraParams
// into typed fields when they have the expected type.
//
// It returns (nil, nil) for a nil request and the splitIntoModelProvider error
// when the model string cannot be split. The caller's ExtraParams map is never
// modified: the lifted keys are removed from a private copy, so the same
// Bifrost request can be converted again (for example on a retry) with the
// same outcome.
func ToHuggingFaceEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*HuggingFaceEmbeddingRequest, error) {
	if bifrostReq == nil {
		return nil, nil
	}

	inferenceProvider, modelName, nameErr := splitIntoModelProvider(bifrostReq.Model)
	if nameErr != nil {
		return nil, nameErr
	}

	hfReq := &HuggingFaceEmbeddingRequest{}
	if inferenceProvider != hfInference {
		hfReq.Model = schemas.Ptr(modelName)
		hfReq.Provider = schemas.Ptr(string(inferenceProvider))
	}

	// Convert input; Text takes precedence over Texts.
	if bifrostReq.Input != nil {
		var input InputsCustomType
		if bifrostReq.Input.Text != nil {
			input = InputsCustomType{Text: bifrostReq.Input.Text}
		} else if bifrostReq.Input.Texts != nil {
			input = InputsCustomType{Texts: bifrostReq.Input.Texts}
		}
		if inferenceProvider == hfInference {
			hfReq.Inputs = &input
		} else {
			hfReq.Input = &input
		}
	}

	// Map parameters
	if bifrostReq.Params != nil {
		params := bifrostReq.Params

		// Map standard parameters
		if params.EncodingFormat != nil {
			encodingType := EncodingType(*params.EncodingFormat)
			hfReq.EncodingFormat = &encodingType
		}
		if params.Dimensions != nil {
			hfReq.Dimensions = params.Dimensions
		}

		// Check for HuggingFace-specific parameters in ExtraParams
		extraParams := params.ExtraParams
		if extraParams != nil {
			// Work on a copy: deleting from params.ExtraParams directly would
			// strip these options from the caller's request.
			extraParams = make(map[string]interface{}, len(params.ExtraParams))
			for key, value := range params.ExtraParams {
				extraParams[key] = value
			}

			if normalize, ok := extraParams["normalize"].(bool); ok {
				delete(extraParams, "normalize")
				hfReq.Normalize = &normalize
			}
			if promptName, ok := extraParams["prompt_name"].(string); ok {
				delete(extraParams, "prompt_name")
				hfReq.PromptName = &promptName
			}
			if truncate, ok := extraParams["truncate"].(bool); ok {
				delete(extraParams, "truncate")
				hfReq.Truncate = &truncate
			}
			if truncationDirection, ok := extraParams["truncation_direction"].(string); ok {
				delete(extraParams, "truncation_direction")
				hfReq.TruncationDirection = &truncationDirection
			}
		}
		hfReq.ExtraParams = extraParams
	}

	return hfReq, nil
}
// UnmarshalHuggingFaceEmbeddingResponse decodes a HuggingFace embedding response
// body straight into a BifrostEmbeddingResponse.
//
// Three payload shapes are tried in order:
//  1. an object with any of the "data", "model" or "usage" fields,
//  2. a 2D array of numbers (one embedding per row),
//  3. a 1D array of numbers (a single embedding).
//
// The model argument is used unless the object form supplies its own model.
// Usage is zero-filled whenever the payload carries none. An error is returned
// for nil data or when no shape matches.
func UnmarshalHuggingFaceEmbeddingResponse(data []byte, model string) (*schemas.BifrostEmbeddingResponse, error) {
	if data == nil {
		return nil, fmt.Errorf("response data is nil")
	}

	// listResponse wraps items in a "list" response with zeroed usage.
	listResponse := func(items []schemas.EmbeddingData) *schemas.BifrostEmbeddingResponse {
		return &schemas.BifrostEmbeddingResponse{
			Data:   items,
			Model:  model,
			Object: "list",
			Usage:  &schemas.BifrostLLMUsage{},
		}
	}

	// vectorEntry builds one embedding entry holding its own copy of values.
	vectorEntry := func(position int, values []float64) schemas.EmbeddingData {
		return schemas.EmbeddingData{
			Embedding: schemas.EmbeddingStruct{EmbeddingArray: append([]float64(nil), values...)},
			Index:     position,
			Object:    "embedding",
		}
	}

	// Shape 1: standard object format.
	var envelope struct {
		Data  []schemas.EmbeddingData  `json:"data,omitempty"`
		Model *string                  `json:"model,omitempty"`
		Usage *schemas.BifrostLLMUsage `json:"usage,omitempty"`
	}
	if err := sonic.Unmarshal(data, &envelope); err == nil {
		// An object with none of the known fields is not treated as a match.
		if envelope.Data != nil || envelope.Model != nil || envelope.Usage != nil {
			resp := listResponse(envelope.Data)
			if envelope.Model != nil {
				resp.Model = *envelope.Model
			}
			if envelope.Usage != nil {
				resp.Usage = envelope.Usage
			}
			return resp, nil
		}
	}

	// Shape 2: [[num, ...], ...]
	var matrix [][]float64
	if err := sonic.Unmarshal(data, &matrix); err == nil {
		items := make([]schemas.EmbeddingData, 0, len(matrix))
		for row, values := range matrix {
			items = append(items, vectorEntry(row, values))
		}
		return listResponse(items), nil
	}

	// Shape 3: [num, ...]
	var vector []float64
	if err := sonic.Unmarshal(data, &vector); err == nil {
		return listResponse([]schemas.EmbeddingData{vectorEntry(0, vector)}), nil
	}

	return nil, fmt.Errorf("failed to unmarshal HuggingFace embedding response: unexpected structure")
}

View File

@@ -0,0 +1,57 @@
package huggingface
import (
"fmt"
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// parseHuggingFaceImageError turns a HuggingFace error response into a
// BifrostError.
//
// The base error comes from providerUtils.HandleProviderAPIError, which also
// decodes the body into a HuggingFaceResponseError. A non-blank error type is
// copied onto the result, and the message is taken from the first of these
// that is present:
//   - the FastAPI-style validation details, one "msg at loc.path" entry per
//     detail joined with "; ",
//   - the response's message field,
//   - the response's error field.
func parseHuggingFaceImageError(resp *fasthttp.Response) *schemas.BifrostError {
	var payload HuggingFaceResponseError
	bifrostErr := providerUtils.HandleProviderAPIError(resp, &payload)

	if strings.TrimSpace(payload.Type) != "" {
		errType := payload.Type
		bifrostErr.Type = &errType
	}
	if bifrostErr.Error == nil {
		bifrostErr.Error = &schemas.ErrorField{}
	}

	switch {
	case len(payload.Detail) > 0:
		// FastAPI validation errors: one message per detail entry.
		messages := make([]string, 0, len(payload.Detail))
		for _, entry := range payload.Detail {
			text := entry.Msg

			// Location segments are strings or JSON numbers; anything else is skipped.
			var path []string
			for _, segment := range entry.Loc {
				switch value := segment.(type) {
				case string:
					path = append(path, value)
				case float64:
					path = append(path, fmt.Sprintf("%.0f", value))
				}
			}
			if len(path) > 0 {
				text = fmt.Sprintf("%s at %s", text, strings.Join(path, "."))
			}
			messages = append(messages, text)
		}
		bifrostErr.Error.Message = strings.Join(messages, "; ")
	case strings.TrimSpace(payload.Message) != "":
		bifrostErr.Error.Message = payload.Message
	case strings.TrimSpace(payload.Error) != "":
		bifrostErr.Error.Message = payload.Error
	}

	return bifrostErr
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,132 @@
package huggingface_test
import (
"os"
"testing"
"github.com/maximhq/bifrost/core/internal/llmtests"
"github.com/maximhq/bifrost/core/providers/huggingface"
"github.com/maximhq/bifrost/core/schemas"
)
// TestHuggingface runs the shared comprehensive provider test suite against the
// HuggingFace provider. It is skipped unless HUGGING_FACE_API_KEY is set.
func TestHuggingface(t *testing.T) {
t.Parallel()
// Skip when no API key is configured in the environment.
if os.Getenv("HUGGING_FACE_API_KEY") == "" {
t.Skip("Skipping HuggingFace tests because HUGGING_FACE_API_KEY is not set")
}
client, ctx, cancel, err := llmtests.SetupTest()
if err != nil {
t.Fatalf("Error initializing test setup: %v", err)
}
// Deferred calls run in reverse order: the client is shut down before cancel is called.
defer cancel()
defer client.Shutdown()
// Model identifiers appear to follow "<inference provider>/<model name>" — NOTE(review): confirm against splitIntoModelProvider.
testConfig := llmtests.ComprehensiveTestConfig{
Provider: schemas.HuggingFace,
ChatModel: "groq/meta-llama/Llama-3.3-70B-Instruct",
VisionModel: "cohere/CohereLabs/aya-vision-32b",
EmbeddingModel: "sambanova/intfloat/e5-mistral-7b-instruct",
TranscriptionModel: "fal-ai/openai/whisper-large-v3",
SpeechSynthesisModel: "fal-ai/hexgrad/Kokoro-82M",
SpeechSynthesisFallbacks: []schemas.Fallback{
{Provider: schemas.HuggingFace, Model: "fal-ai/ResembleAI/chatterbox"},
},
ReasoningModel: "groq/openai/gpt-oss-120b",
ImageGenerationModel: "fal-ai/fal-ai/flux/dev",
ImageEditModel: "fal-ai/fal-ai/flux-2/edit",
// Scenario flags: each field switches one scenario of the suite on or off.
// Text completion, embedding, streaming audio, batch and file scenarios are disabled here.
Scenarios: llmtests.TestScenarios{
TextCompletion: false,
TextCompletionStream: false,
SimpleChat: true,
CompletionStream: true,
MultiTurnConversation: true,
ToolCalls: true,
ToolCallsStreaming: true,
MultipleToolCalls: false,
End2EndToolCalling: true,
AutomaticFunctionCall: true,
ImageURL: true,
ImageBase64: true,
MultipleImages: true,
CompleteEnd2End: true,
Embedding: false,
Transcription: true,
TranscriptionStream: false,
SpeechSynthesis: true,
SpeechSynthesisStream: false,
Reasoning: true,
ListModels: true,
BatchCreate: false,
BatchList: false,
BatchRetrieve: false,
BatchCancel: false,
BatchResults: false,
FileUpload: false,
FileList: false,
FileRetrieve: false,
FileDelete: false,
FileContent: false,
FileBatchInput: false,
ImageGeneration: true,
ImageGenerationStream: true,
ImageEdit: true,
ImageEditStream: true,
},
}
// Run the whole suite as a single named subtest.
t.Run("HuggingFaceTests", func(t *testing.T) {
llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
})
}
// TestUnmarshalHuggingFaceEmbeddingResponsePreservesPrecision verifies that a
// 2D-array embedding payload is decoded at full float64 precision, i.e. the
// value is not rounded through float32 on the way.
func TestUnmarshalHuggingFaceEmbeddingResponsePreservesPrecision(t *testing.T) {
	const want = 0.12345678901234568
	payload := []byte(`[[0.12345678901234568]]`)

	resp, err := huggingface.UnmarshalHuggingFaceEmbeddingResponse(payload, "test-model")
	switch {
	case err != nil:
		t.Fatalf("UnmarshalHuggingFaceEmbeddingResponse failed: %v", err)
	case resp == nil || len(resp.Data) != 1:
		t.Fatalf("expected single embedding response, got %#v", resp)
	}

	values := resp.Data[0].Embedding.EmbeddingArray
	if len(values) != 1 {
		t.Fatalf("expected single embedding value, got %#v", values)
	}

	got := values[0]
	if got != want {
		t.Fatalf("expected %0.18f, got %0.18f", want, got)
	}
	// The literal is not exactly representable as float32, so equality with
	// the float32-rounded value would mean precision was lost.
	if rounded := float64(float32(want)); got == rounded {
		t.Fatalf("expected preserved precision, got float32-rounded value %0.18f", got)
	}
}
// TestUnmarshalHuggingFaceEmbeddingResponse1DPreservesPrecision verifies that a
// 1D-array embedding payload is decoded at full float64 precision, i.e. the
// value is not rounded through float32 on the way.
func TestUnmarshalHuggingFaceEmbeddingResponse1DPreservesPrecision(t *testing.T) {
	const want = 0.12345678901234568
	payload := []byte(`[0.12345678901234568]`)

	resp, err := huggingface.UnmarshalHuggingFaceEmbeddingResponse(payload, "test-model")
	switch {
	case err != nil:
		t.Fatalf("UnmarshalHuggingFaceEmbeddingResponse failed: %v", err)
	case resp == nil || len(resp.Data) != 1:
		t.Fatalf("expected single embedding response, got %#v", resp)
	}

	values := resp.Data[0].Embedding.EmbeddingArray
	if len(values) != 1 {
		t.Fatalf("expected single embedding value, got %#v", values)
	}

	got := values[0]
	if got != want {
		t.Fatalf("expected %0.18f, got %0.18f", want, got)
	}
	// The literal is not exactly representable as float32, so equality with
	// the float32-rounded value would mean precision was lost.
	if rounded := float64(float32(want)); got == rounded {
		t.Fatalf("expected preserved precision, got float32-rounded value %0.18f", got)
	}
}

View File

@@ -0,0 +1,603 @@
package huggingface
import (
"encoding/base64"
"fmt"
"net/http"
"strconv"
"strings"
"github.com/bytedance/sonic"
nebiusProvider "github.com/maximhq/bifrost/core/providers/nebius"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// falAIMultiImageEditModels is the set of fal-ai edit models that support
// multiple input images (the image_urls field).
// NOTE(review): keys look like model names with the inference-provider prefix
// already removed; the lookup site is not visible in this chunk — confirm.
var falAIMultiImageEditModels = map[string]bool{
"fal-ai/flux-2/edit": true,
"fal-ai/flux-2-pro/edit": true,
}
// falAISingleImageEditModels is the set of fal-ai edit models that only support
// a single input image (the image_url field).
// NOTE(review): how models absent from both sets are handled is decided by the
// caller, which is not visible in this chunk.
var falAISingleImageEditModels = map[string]bool{
"fal-ai/flux-pro/kontext": true,
"fal-ai/flux/dev/image-to-image": true,
}
// ToHuggingFaceImageGenerationRequest converts a Bifrost image generation request to provider-specific format.
//
// The model string is split with splitIntoModelProvider and the request body is
// built for the resulting inference provider: nebius, hfInference, falAI or
// together. Provider-specific options found in ExtraParams are lifted into
// typed request fields and removed from the forwarded extra parameters.
//
// Size handling differs per provider: nebius rejects a size that is not
// "WIDTHxHEIGHT" with an error, fal-ai silently ignores a size it cannot parse,
// and together forwards the size string untouched. The size "auto" (any case)
// is never forwarded to nebius or fal-ai, and the "x" separator is matched
// case-insensitively for both.
//
// The caller's ExtraParams map is never modified: the nebius and fal-ai
// branches remove lifted keys from a private copy, so the same Bifrost request
// can be converted again (for example on a retry) with the same outcome.
//
// An error is returned when the request or its input is nil, when the model
// string cannot be split, when a nebius size is malformed, or when the
// inference provider does not support image generation.
func ToHuggingFaceImageGenerationRequest(bifrostReq *schemas.BifrostImageGenerationRequest) (providerUtils.RequestBodyWithExtraParams, error) {
	if bifrostReq == nil || bifrostReq.Input == nil {
		return nil, fmt.Errorf("bifrost request is nil or input is nil")
	}

	inferenceProvider, model, nameErr := splitIntoModelProvider(bifrostReq.Model)
	if nameErr != nil {
		return nil, nameErr
	}

	// cloneExtraParams returns a shallow copy so that lifted keys can be
	// deleted without touching the map owned by the caller.
	cloneExtraParams := func(src map[string]interface{}) map[string]interface{} {
		dst := make(map[string]interface{}, len(src))
		for key, value := range src {
			dst[key] = value
		}
		return dst
	}

	params := bifrostReq.Params

	switch inferenceProvider {
	case nebius:
		req := &nebiusProvider.NebiusImageGenerationRequest{
			Model:  &model,
			Prompt: &bifrostReq.Input.Prompt,
		}
		if params != nil {
			if params.ResponseFormat != nil {
				req.ResponseFormat = params.ResponseFormat
			}
			if params.Size != nil && strings.ToLower(*params.Size) != "auto" {
				size := strings.Split(strings.ToLower(*params.Size), "x")
				if len(size) != 2 {
					return nil, fmt.Errorf("invalid size format: expected 'WIDTHxHEIGHT', got %q", *params.Size)
				}
				width, err := strconv.Atoi(size[0])
				if err != nil {
					return nil, fmt.Errorf("invalid width in size %q: %w", *params.Size, err)
				}
				height, err := strconv.Atoi(size[1])
				if err != nil {
					return nil, fmt.Errorf("invalid height in size %q: %w", *params.Size, err)
				}
				req.Width = &width
				req.Height = &height
			}
			if params.OutputFormat != nil {
				req.ResponseExtension = params.OutputFormat
			}
			// Handle nebius inconsistency - normalize ResponseExtension case-insensitively
			if req.ResponseExtension != nil && strings.ToLower(*req.ResponseExtension) == "jpeg" {
				req.ResponseExtension = schemas.Ptr("jpg")
			}
			// Map seed from direct field
			if params.Seed != nil {
				req.Seed = params.Seed
			}
			// Map negative_prompt from direct field
			if params.NegativePrompt != nil {
				req.NegativePrompt = params.NegativePrompt
			}
			// Handle extra params for nebius
			if params.ExtraParams != nil {
				req.ExtraParams = cloneExtraParams(params.ExtraParams)
				// Map num_inference_steps
				if v, ok := schemas.SafeExtractIntPointer(params.ExtraParams["num_inference_steps"]); ok {
					delete(req.ExtraParams, "num_inference_steps")
					req.NumInferenceSteps = v
				}
				// Map guidance_scale
				if v, ok := schemas.SafeExtractIntPointer(params.ExtraParams["guidance_scale"]); ok {
					delete(req.ExtraParams, "guidance_scale")
					req.GuidanceScale = v
				}
				// Map loras; entries without an extractable url and scale are skipped.
				if lorasValue, exists := params.ExtraParams["loras"]; exists && lorasValue != nil {
					delete(req.ExtraParams, "loras")
					if lorasArray, ok := lorasValue.([]interface{}); ok {
						for _, item := range lorasArray {
							loraMap, ok := item.(map[string]interface{})
							if !ok {
								continue
							}
							url, ok := schemas.SafeExtractString(loraMap["url"])
							if !ok {
								continue
							}
							scale, ok := schemas.SafeExtractInt(loraMap["scale"])
							if !ok {
								continue
							}
							req.Loras = append(req.Loras, nebiusProvider.NebiusLora{URL: url, Scale: scale})
						}
					}
				}
			}
		}
		return req, nil

	case hfInference:
		req := &HuggingFaceHFInferenceImageGenerationRequest{
			Inputs: bifrostReq.Input.Prompt,
		}
		if params != nil {
			req.ExtraParams = params.ExtraParams
		}
		return req, nil

	case falAI:
		req := &HuggingFaceFalAIImageGenerationRequest{
			Prompt: bifrostReq.Input.Prompt,
		}
		if params != nil {
			// Map n to num_images for fal-ai
			if params.N != nil {
				req.NumImages = params.N
			}
			// Pass through response_format
			if params.ResponseFormat != nil {
				req.ResponseFormat = params.ResponseFormat
			}
			// Pass through output_format, spelling "jpg" as "jpeg"
			if params.OutputFormat != nil {
				if strings.ToLower(*params.OutputFormat) == "jpg" {
					req.OutputFormat = schemas.Ptr("jpeg")
				} else {
					req.OutputFormat = params.OutputFormat
				}
			}
			// Convert size from "WxH" format to fal-ai's image_size object.
			// The size is lowercased first so that "1024X768" parses just like
			// "1024x768"; a size that still cannot be parsed is ignored.
			if params.Size != nil && strings.ToLower(*params.Size) != "auto" {
				size := strings.Split(strings.ToLower(*params.Size), "x")
				if len(size) == 2 {
					width, widthErr := strconv.Atoi(size[0])
					height, heightErr := strconv.Atoi(size[1])
					if widthErr == nil && heightErr == nil {
						req.ImageSize = &HuggingFaceFalAISize{
							Width:  width,
							Height: height,
						}
					}
				}
			}
			if params.ResponseFormat != nil && *params.ResponseFormat == "b64_json" {
				req.SyncMode = schemas.Ptr(true)
			}
			if params.Moderation != nil && *params.Moderation == "low" {
				req.EnableSafetyChecker = schemas.Ptr(false)
			}
			// Map seed from direct field
			if params.Seed != nil {
				req.Seed = params.Seed
			}
			// Map negative_prompt from direct field
			if params.NegativePrompt != nil {
				req.NegativePrompt = params.NegativePrompt
			}
			// Map num_inference_steps from direct field
			if params.NumInferenceSteps != nil {
				req.NumInferenceSteps = params.NumInferenceSteps
			}
			// Parse fal-ai specific params from ExtraParams
			if params.ExtraParams != nil {
				req.ExtraParams = cloneExtraParams(params.ExtraParams)
				// Map guidance_scale
				if v, ok := schemas.SafeExtractFloat64Pointer(params.ExtraParams["guidance_scale"]); ok {
					delete(req.ExtraParams, "guidance_scale")
					req.GuidanceScale = v
				}
				// Map acceleration
				if v, ok := schemas.SafeExtractStringPointer(params.ExtraParams["acceleration"]); ok {
					delete(req.ExtraParams, "acceleration")
					req.Acceleration = v
				}
				// Map enable_prompt_expansion
				if v, ok := schemas.SafeExtractBoolPointer(params.ExtraParams["enable_prompt_expansion"]); ok {
					delete(req.ExtraParams, "enable_prompt_expansion")
					req.EnablePromptExpansion = v
				}
				// Map enable_safety_checker; an explicit value overrides the
				// one derived from Moderation above.
				if v, ok := schemas.SafeExtractBoolPointer(params.ExtraParams["enable_safety_checker"]); ok {
					delete(req.ExtraParams, "enable_safety_checker")
					req.EnableSafetyChecker = v
				}
			}
		}
		return req, nil

	case together:
		req := &HuggingFaceTogetherImageGenerationRequest{
			Prompt: bifrostReq.Input.Prompt,
			Model:  model,
		}
		if params != nil {
			req.ExtraParams = params.ExtraParams
			if params.ResponseFormat != nil {
				req.ResponseFormat = params.ResponseFormat
			}
			if params.Size != nil {
				req.Size = params.Size
			}
			if params.N != nil {
				req.N = params.N
			}
			// together is sent "base64" in place of "b64_json".
			if params.ResponseFormat != nil && *params.ResponseFormat == "b64_json" {
				req.ResponseFormat = schemas.Ptr("base64")
			}
			if params.NumInferenceSteps != nil {
				req.Steps = params.NumInferenceSteps
			}
		}
		return req, nil

	default:
		return nil, fmt.Errorf("unsupported inference provider for image generation: %s", inferenceProvider)
	}
}
// ToHuggingFaceImageStreamRequest converts a Bifrost image generation request to fal-ai streaming format.
//
// Standard parameters are mapped onto the fal-ai fields (n becomes num_images,
// "jpg" is sent as "jpeg", a "WxH" size becomes an image_size object), and the
// fal-ai specific options guidance_scale, acceleration, enable_prompt_expansion
// and enable_safety_checker are lifted out of ExtraParams into typed fields.
// A size of "auto" (any case) or one that cannot be parsed is not forwarded;
// the "x" separator is matched case-insensitively.
//
// The caller's ExtraParams map is never modified: lifted keys are removed from
// a private copy, so the same Bifrost request can be converted again (for
// example on a retry) with the same outcome.
//
// An error is returned only when the request or its input is nil.
func ToHuggingFaceImageStreamRequest(bifrostReq *schemas.BifrostImageGenerationRequest) (*HuggingFaceFalAIImageStreamRequest, error) {
	if bifrostReq == nil || bifrostReq.Input == nil {
		return nil, fmt.Errorf("bifrost request is nil or input is nil")
	}

	req := &HuggingFaceFalAIImageStreamRequest{
		Prompt: bifrostReq.Input.Prompt,
	}

	params := bifrostReq.Params
	if params == nil {
		return req, nil
	}

	// Forward a copy of the extra parameters so that the deletes below do not
	// strip options from the caller's request.
	if params.ExtraParams != nil {
		extraParams := make(map[string]interface{}, len(params.ExtraParams))
		for key, value := range params.ExtraParams {
			extraParams[key] = value
		}
		req.ExtraParams = extraParams
	}

	// Map n to num_images for fal-ai
	if params.N != nil {
		req.NumImages = params.N
	}
	// Pass through response_format
	if params.ResponseFormat != nil {
		req.ResponseFormat = params.ResponseFormat
	}
	// Pass through output_format
	// Convert "jpg" to "jpeg" for fal-ai (fal-ai only accepts "jpeg", "png", "webp")
	if params.OutputFormat != nil {
		if strings.ToLower(*params.OutputFormat) == "jpg" {
			req.OutputFormat = schemas.Ptr("jpeg")
		} else {
			req.OutputFormat = params.OutputFormat
		}
	}
	// Convert size from "WxH" format to fal-ai's image_size object.
	// The size is lowercased first so that "1024X768" parses just like
	// "1024x768"; a size that still cannot be parsed is ignored.
	if params.Size != nil && strings.ToLower(*params.Size) != "auto" {
		size := strings.Split(strings.ToLower(*params.Size), "x")
		if len(size) == 2 {
			width, widthErr := strconv.Atoi(size[0])
			height, heightErr := strconv.Atoi(size[1])
			if widthErr == nil && heightErr == nil {
				req.ImageSize = &HuggingFaceFalAISize{
					Width:  width,
					Height: height,
				}
			}
		}
	}
	if params.Seed != nil {
		req.Seed = params.Seed
	}
	if params.NumInferenceSteps != nil {
		req.NumInferenceSteps = params.NumInferenceSteps
	}
	if params.ResponseFormat != nil && *params.ResponseFormat == "b64_json" {
		req.SyncMode = schemas.Ptr(true)
	}
	if params.Moderation != nil && *params.Moderation == "low" {
		req.EnableSafetyChecker = schemas.Ptr(false)
	}

	// Parse fal-ai specific params from ExtraParams
	if params.ExtraParams != nil {
		if v, ok := schemas.SafeExtractFloat64Pointer(params.ExtraParams["guidance_scale"]); ok {
			delete(req.ExtraParams, "guidance_scale")
			req.GuidanceScale = v
		}
		if v, ok := schemas.SafeExtractStringPointer(params.ExtraParams["acceleration"]); ok {
			delete(req.ExtraParams, "acceleration")
			req.Acceleration = v
		}
		if v, ok := schemas.SafeExtractBoolPointer(params.ExtraParams["enable_prompt_expansion"]); ok {
			delete(req.ExtraParams, "enable_prompt_expansion")
			req.EnablePromptExpansion = v
		}
		// An explicit enable_safety_checker overrides the value derived from
		// Moderation above.
		if v, ok := schemas.SafeExtractBoolPointer(params.ExtraParams["enable_safety_checker"]); ok {
			delete(req.ExtraParams, "enable_safety_checker")
			req.EnableSafetyChecker = v
		}
	}

	return req, nil
}
// UnmarshalHuggingFaceImageGenerationResponse decodes a HuggingFace image
// generation response body into the Bifrost format.
//
// The inference provider is taken from the model string and selects the
// decoding strategy: nebius responses go through the Nebius converter,
// hfInference bodies are treated as raw image bytes and base64-encoded, and
// falAI and together responses are JSON image lists copied entry by entry.
// An error is returned for nil data, a model string that cannot be split, a
// body that fails to decode, or an unsupported provider.
func UnmarshalHuggingFaceImageGenerationResponse(data []byte, model string) (*schemas.BifrostImageGenerationResponse, error) {
	if data == nil {
		return nil, fmt.Errorf("response data is nil")
	}

	inferenceProvider, _, splitErr := splitIntoModelProvider(model)
	if splitErr != nil {
		return nil, splitErr
	}

	switch inferenceProvider {
	case nebius:
		var decoded nebiusProvider.NebiusImageGenerationResponse
		if err := sonic.Unmarshal(data, &decoded); err != nil {
			return nil, fmt.Errorf("failed to unmarshal Nebius response: %w", err)
		}
		converted := nebiusProvider.ToBifrostImageResponse(&decoded)
		if converted == nil {
			return nil, fmt.Errorf("failed to convert Nebius response to Bifrost format")
		}
		// Fill in the model when the converter left it empty.
		if converted.Model == "" {
			converted.Model = model
		}
		return converted, nil

	case hfInference:
		// The body is the raw image itself; hand it back base64-encoded.
		only := schemas.ImageData{
			B64JSON: base64.StdEncoding.EncodeToString(data),
			Index:   0,
		}
		return &schemas.BifrostImageGenerationResponse{
			Model: model,
			Data:  []schemas.ImageData{only},
		}, nil

	case falAI:
		var decoded HuggingFaceFalAIImageGenerationResponse
		if err := sonic.Unmarshal(data, &decoded); err != nil {
			return nil, fmt.Errorf("failed to unmarshal fal-ai response: %w", err)
		}
		// Entries may carry a URL, base64 data, or both.
		images := make([]schemas.ImageData, 0, len(decoded.Images))
		for position, entry := range decoded.Images {
			images = append(images, schemas.ImageData{
				URL:     entry.URL,
				B64JSON: entry.B64JSON,
				Index:   position,
			})
		}
		return &schemas.BifrostImageGenerationResponse{Model: model, Data: images}, nil

	case together:
		var decoded HuggingFaceTogetherImageGenerationResponse
		if err := sonic.Unmarshal(data, &decoded); err != nil {
			return nil, fmt.Errorf("failed to unmarshal together response: %w", err)
		}
		images := make([]schemas.ImageData, 0, len(decoded.Data))
		for position, entry := range decoded.Data {
			images = append(images, schemas.ImageData{
				B64JSON: entry.B64JSON,
				URL:     entry.URL,
				Index:   position,
			})
		}
		return &schemas.BifrostImageGenerationResponse{Model: model, Data: images}, nil

	default:
		return nil, fmt.Errorf("unsupported inference provider: %s", inferenceProvider)
	}
}
// imageBytesToBase64DataURL renders raw image bytes as a
// "data:<mime>;base64,<payload>" URL. The MIME type is sniffed from the bytes
// with http.DetectContentType and the payload uses standard base64 encoding.
func imageBytesToBase64DataURL(imageBytes []byte) string {
	return fmt.Sprintf(
		"data:%s;base64,%s",
		http.DetectContentType(imageBytes),
		base64.StdEncoding.EncodeToString(imageBytes),
	)
}
// mapFalAIImageEditParams maps common parameters from Bifrost request to fal-ai request.
//
// It does nothing when bifrostReq.Params is nil. Otherwise n becomes
// num_images, "jpg" is sent as "jpeg", a "b64_json" response format turns on
// sync mode, a "WxH" size becomes an image_size object, and
// num_inference_steps and seed are passed through. A size of "auto" (any case)
// or one that cannot be parsed is not forwarded; the "x" separator is matched
// case-insensitively.
//
// The fal-ai specific options guidance_scale, acceleration and
// enable_safety_checker are read from bifrostReq.Params.ExtraParams and, when
// present with the expected type, moved into typed fields and deleted from
// req.ExtraParams. req.ExtraParams itself is not assigned here; the caller is
// expected to have set it beforehand.
func mapFalAIImageEditParams(bifrostReq *schemas.BifrostImageEditRequest, req *HuggingFaceFalAIImageEditRequest) {
	params := bifrostReq.Params
	if params == nil {
		return
	}

	// Map n to num_images for fal-ai
	if params.N != nil {
		req.NumImages = params.N
	}

	// Pass through output_format, spelling "jpg" as "jpeg"
	if params.OutputFormat != nil {
		if strings.ToLower(*params.OutputFormat) == "jpg" {
			req.OutputFormat = schemas.Ptr("jpeg")
		} else {
			req.OutputFormat = params.OutputFormat
		}
	}

	if params.ResponseFormat != nil && *params.ResponseFormat == "b64_json" {
		req.SyncMode = schemas.Ptr(true)
	}

	// Convert size from "WxH" format to fal-ai's image_size object.
	// The size is lowercased first so that "1024X768" parses just like
	// "1024x768"; a size that still cannot be parsed is ignored.
	if params.Size != nil && strings.ToLower(*params.Size) != "auto" {
		size := strings.Split(strings.ToLower(*params.Size), "x")
		if len(size) == 2 {
			width, widthErr := strconv.Atoi(size[0])
			height, heightErr := strconv.Atoi(size[1])
			if widthErr == nil && heightErr == nil {
				req.ImageSize = &HuggingFaceFalAISize{
					Width:  width,
					Height: height,
				}
			}
		}
	}

	// Pass-through num_inference_steps
	if params.NumInferenceSteps != nil {
		req.NumInferenceSteps = params.NumInferenceSteps
	}

	// Pass-through seed
	if params.Seed != nil {
		req.Seed = params.Seed
	}

	// Parse fal-ai specific params from ExtraParams
	if params.ExtraParams != nil {
		// Map guidance_scale
		if v, ok := schemas.SafeExtractFloat64Pointer(params.ExtraParams["guidance_scale"]); ok {
			delete(req.ExtraParams, "guidance_scale")
			req.GuidanceScale = v
		}
		// Map acceleration
		if v, ok := schemas.SafeExtractStringPointer(params.ExtraParams["acceleration"]); ok {
			delete(req.ExtraParams, "acceleration")
			req.Acceleration = v
		}
		// Map enable_safety_checker
		if v, ok := schemas.SafeExtractBoolPointer(params.ExtraParams["enable_safety_checker"]); ok {
			delete(req.ExtraParams, "enable_safety_checker")
			req.EnableSafetyChecker = v
		}
	}
}
// ToHuggingFaceImageEditRequest builds the fal-ai image edit payload from a
// Bifrost image edit request.
//
// Input images are sent as base64 data URLs; empty images are skipped. The
// images go into either image_url (single) or image_urls (list), chosen by, in
// order: an explicit "use_image_urls" bool in ExtraParams, the model's entry in
// falAIMultiImageEditModels / falAISingleImageEditModels, and finally the number
// of images supplied.
//
// It returns an error when the request or its input is nil, when no non-empty
// image is supplied, when the model string cannot be split, or when several
// images are supplied but only a single image can be sent.
func ToHuggingFaceImageEditRequest(bifrostReq *schemas.BifrostImageEditRequest) (*HuggingFaceFalAIImageEditRequest, error) {
	if bifrostReq == nil || bifrostReq.Input == nil {
		return nil, fmt.Errorf("bifrost request is nil or input is nil")
	}
	inputImages := bifrostReq.Input.Images
	if len(inputImages) == 0 {
		return nil, fmt.Errorf("at least one image is required")
	}

	// Encode every non-empty image as a data URL.
	imageURLs := make([]string, 0, len(inputImages))
	for _, img := range inputImages {
		if len(img.Image) > 0 {
			imageURLs = append(imageURLs, imageBytesToBase64DataURL(img.Image))
		}
	}
	if len(imageURLs) == 0 {
		return nil, fmt.Errorf("no valid images found")
	}

	// Only the model segment is needed here; it drives the image field choice.
	_, modelName, err := splitIntoModelProvider(bifrostReq.Model)
	if err != nil {
		return nil, fmt.Errorf("failed to split model name: %w", err)
	}

	req := &HuggingFaceFalAIImageEditRequest{Prompt: bifrostReq.Input.Prompt}

	// An explicit "use_image_urls" override is consumed from ExtraParams. The
	// map is shared with the Bifrost request, so the key is removed there too.
	var useMultiImage *bool
	if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil {
		req.ExtraParams = bifrostReq.Params.ExtraParams
		if override, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["use_image_urls"]); ok {
			delete(req.ExtraParams, "use_image_urls")
			useMultiImage = override
		}
	}

	switch {
	case useMultiImage != nil:
		// The caller decided explicitly.
		switch {
		case *useMultiImage:
			req.ImageURLs = imageURLs
		case len(imageURLs) == 1:
			req.ImageURL = &imageURLs[0]
		default:
			return nil, fmt.Errorf("use_image_urls is false but multiple images provided (%d images)", len(imageURLs))
		}
	case falAIMultiImageEditModels[modelName]:
		// Known multi-image model: always image_urls.
		req.ImageURLs = imageURLs
	case falAISingleImageEditModels[modelName]:
		// Known single-image model: exactly one image is accepted.
		if len(imageURLs) != 1 {
			return nil, fmt.Errorf("model %s only supports single image, got %d images", modelName, len(imageURLs))
		}
		req.ImageURL = &imageURLs[0]
	case len(imageURLs) == 1:
		// Unknown model, one image.
		req.ImageURL = &imageURLs[0]
	default:
		// Unknown model, several images.
		req.ImageURLs = imageURLs
	}

	mapFalAIImageEditParams(bifrostReq, req)
	return req, nil
}
// extractImagesFromStreamResponse returns the images carried by a fal-ai
// streaming response. Images inside the Data envelope win when present and
// non-empty; otherwise the top-level Images field (legacy flattened format) is
// returned, which may be empty.
func extractImagesFromStreamResponse(response *HuggingFaceFalAIImageStreamResponse) []FalAIImage {
	if response.Data == nil || len(response.Data.Images) == 0 {
		// No envelope payload: fall back to the flattened layout.
		return response.Images
	}
	return response.Data.Images
}

View File

@@ -0,0 +1,140 @@
package huggingface
import (
"fmt"
"slices"
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// Model-list fetch limits.
// NOTE(review): neither constant is referenced in this part of the file;
// presumably they are the default and maximum number of models requested per
// list call. Confirm at the call sites.
const (
defaultModelFetchLimit = 200
maxModelFetchLimit = 1000
)
// ToBifrostListModelsResponse converts a Hugging Face model listing into the
// Bifrost list-models response for one inference provider.
//
// Models with an empty ModelID, or for which deriveSupportedMethods finds no
// supported method, are dropped. The remaining ones go through the shared
// ListModelsPipeline (allow list, block list, aliases). Every emitted ID has the
// compound form "{providerKey}/{inferenceProvider}/{model}"; aliases apply to
// the model segment only. A nil receiver yields nil.
func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, inferenceProvider inferenceProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
	if response == nil {
		return nil
	}

	out := &schemas.BifrostListModelsResponse{
		Data: make([]schemas.Model, 0, len(response.Models)),
	}

	pipeline := &providerUtils.ListModelsPipeline{
		AllowedModels:     allowedModels,
		BlacklistedModels: blacklistedModels,
		Aliases:           aliases,
		Unfiltered:        unfiltered,
		ProviderKey:       providerKey,
		MatchFns:          providerUtils.DefaultMatchFns(),
	}
	if pipeline.ShouldEarlyExit() {
		return out
	}

	// compoundID wraps a model segment with the provider and inference provider.
	compoundID := func(modelSegment string) string {
		return fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, modelSegment)
	}

	// seen records the lower-cased resolved IDs already emitted, for backfill.
	seen := make(map[string]bool)
	for _, hfModel := range response.Models {
		if hfModel.ModelID == "" {
			continue
		}
		methods := deriveSupportedMethods(hfModel.PipelineTag, hfModel.Tags)
		if len(methods) == 0 {
			continue
		}

		// Filtering and aliasing work on the bare model ID, not the compound ID.
		for _, match := range pipeline.FilterModel(hfModel.ModelID) {
			entry := schemas.Model{
				ID:               compoundID(match.ResolvedID),
				Name:             schemas.Ptr(hfModel.ModelID),
				SupportedMethods: methods,
				HuggingFaceID:    schemas.Ptr(hfModel.ID),
			}
			if match.AliasValue != "" {
				entry.Alias = schemas.Ptr(match.AliasValue)
			}
			out.Data = append(out.Data, entry)
			seen[strings.ToLower(match.ResolvedID)] = true
		}
	}

	// Backfilled entries come back prefixed with the provider key only; strip
	// that prefix and re-wrap them so they carry the inferenceProvider segment.
	for _, backfilled := range pipeline.BackfillModels(seen) {
		backfilled.ID = compoundID(strings.TrimPrefix(backfilled.ID, string(providerKey)+"/"))
		out.Data = append(out.Data, backfilled)
	}

	return out
}
// deriveSupportedMethods maps a model's pipeline tag and free-form tags to the
// Bifrost request types it can serve. Matching is case-insensitive. Each tag
// contributes at most one method group (the first group it matches). The result
// is de-duplicated and sorted; nil is returned when nothing matches.
func deriveSupportedMethods(pipeline string, tags []string) []string {
	var (
		chatMethods = []schemas.RequestType{
			schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest,
			schemas.ResponsesRequest, schemas.ResponsesStreamRequest,
		}
		imageMethods = []schemas.RequestType{
			schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest,
		}

		// Exact-match tag names per group; substring rules are applied below.
		embeddingTags = []string{
			"text-embedding", "sentence-similarity", "feature-extraction",
			"embeddings", "sentence-transformers",
		}
		chatTags = []string{
			"text-generation", "summarization", "conversational",
			"chat-completion", "text2text-generation", "question-answering",
		}
	)

	seen := make(map[schemas.RequestType]struct{})
	record := func(methods ...schemas.RequestType) {
		for _, m := range methods {
			seen[m] = struct{}{}
		}
	}

	// The pipeline tag is matched exactly after trimming and lower-casing.
	switch strings.TrimSpace(strings.ToLower(pipeline)) {
	case "conversational", "chat-completion":
		record(chatMethods...)
	case "feature-extraction":
		record(schemas.EmbeddingRequest)
	case "text-to-speech":
		record(schemas.SpeechRequest)
	case "automatic-speech-recognition":
		record(schemas.TranscriptionRequest)
	case "text-to-image":
		record(imageMethods...)
	}

	// Tags are looser: exact names plus a few substring heuristics.
	for _, rawTag := range tags {
		tag := strings.ToLower(rawTag)
		switch {
		case slices.Contains(embeddingTags, tag) || strings.Contains(tag, "embedding"):
			record(schemas.EmbeddingRequest)
		case slices.Contains(chatTags, tag) ||
			strings.Contains(tag, "chat") || strings.Contains(tag, "completion"):
			record(chatMethods...)
		case tag == "tts" || strings.Contains(tag, "text-to-speech"):
			record(schemas.SpeechRequest)
		case tag == "automatic-speech-recognition" || tag == "speech-to-text" ||
			strings.Contains(tag, "speech-recognition"):
			record(schemas.TranscriptionRequest)
		case tag == "text-to-image" || strings.Contains(tag, "image-generation"):
			record(imageMethods...)
		}
	}

	if len(seen) == 0 {
		return nil
	}

	// Map iteration order is random; sort for a stable result.
	names := make([]string, 0, len(seen))
	for m := range seen {
		names = append(names, string(m))
	}
	slices.Sort(names)
	return names
}

View File

@@ -0,0 +1,51 @@
package huggingface
import (
"testing"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPayloadOrdering_HuggingFaceChatRequest(t *testing.T) {
req := &HuggingFaceChatRequest{
Model: "meta-llama/Llama-3-70B-Instruct",
Messages: []schemas.ChatMessage{
{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hello")}},
},
Temperature: schemas.Ptr(0.7),
Stream: schemas.Ptr(true),
Tools: []schemas.ChatTool{
{
Type: "function",
Function: &schemas.ChatToolFunction{
Name: "get_weather",
Description: schemas.Ptr("Get weather"),
Parameters: &schemas.ToolFunctionParameters{
Type: "object",
Properties: schemas.NewOrderedMapFromPairs(
schemas.KV("location", map[string]interface{}{"type": "string"}),
),
Required: []string{"location"},
},
},
},
},
}
result, err := providerUtils.MarshalSorted(req)
require.NoError(t, err)
golden := `{"messages":[{"role":"user","content":"hello"}],"model":"meta-llama/Llama-3-70B-Instruct","stream":true,"temperature":0.7,"tools":[{"type":"function","function":{"name":"get_weather","description":"Get weather","parameters":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}}]}`
assert.Equal(t, golden, string(result), "payload field ordering changed — if intentional, update the golden string")
// Determinism: 100 iterations must produce identical bytes
for i := 0; i < 100; i++ {
iter, err := providerUtils.MarshalSorted(req)
require.NoError(t, err)
assert.Equal(t, string(result), string(iter), "non-deterministic marshal output on iteration %d", i)
}
}

View File

@@ -0,0 +1,49 @@
package huggingface
import (
"fmt"
"github.com/maximhq/bifrost/core/schemas"
)
// ToHuggingFaceResponsesRequest turns a Bifrost Responses request into a
// Hugging Face chat-completions payload by first converting it to a Bifrost
// chat request and then reusing ToHuggingFaceChatCompletionRequest.
//
// A nil request yields (nil, nil). A conversion step that produces nothing is
// reported as an error.
func ToHuggingFaceResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (*HuggingFaceChatRequest, error) {
	if bifrostReq == nil {
		return nil, nil
	}

	chatReq := bifrostReq.ToChatRequest()
	if chatReq == nil {
		return nil, fmt.Errorf("failed to convert responses request to chat request")
	}

	switch hfReq, err := ToHuggingFaceChatCompletionRequest(chatReq); {
	case err != nil:
		return nil, err
	case hfReq == nil:
		// The chat converter returns (nil, nil) for empty input; surface that
		// as an error here instead of handing back a nil payload.
		return nil, fmt.Errorf("failed to convert chat request to Hugging Face request")
	default:
		return hfReq, nil
	}
}
// ToBifrostResponsesResponseFromHuggingFace converts a Bifrost chat response
// into the Bifrost Responses response shape.
//
// A nil resp yields (nil, nil). When resp.Model is empty it is set to
// requestedModel before conversion; note that this modifies resp in place. The
// returned error is currently always nil.
func ToBifrostResponsesResponseFromHuggingFace(resp *schemas.BifrostChatResponse, requestedModel string) (*schemas.BifrostResponsesResponse, error) {
	if resp == nil {
		return nil, nil
	}

	// Make sure the converted response carries a model name.
	if resp.Model == "" {
		resp.Model = requestedModel
	}

	return resp.ToBifrostResponsesResponse(), nil
}

View File

@@ -0,0 +1,134 @@
package huggingface
import (
"fmt"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// ToHuggingFaceSpeechRequest converts a Bifrost speech (text-to-speech) request
// into the Hugging Face speech payload.
//
// A nil request yields (nil, nil). An error is returned when request.Input is
// nil or when the model string cannot be split by splitIntoModelProvider.
//
// When request.Params is set, known generation parameters found in
// Params.ExtraParams are promoted to typed fields and removed from that map
// (the caller's map is modified); whatever remains is forwarded as ExtraParams.
// When request.Params is nil, only text, model and provider are populated.
func ToHuggingFaceSpeechRequest(request *schemas.BifrostSpeechRequest) (*HuggingFaceSpeechRequest, error) {
	if request == nil {
		return nil, nil
	}
	if request.Input == nil {
		return nil, fmt.Errorf("speech request input cannot be nil")
	}

	inferenceProvider, modelName, nameErr := splitIntoModelProvider(request.Model)
	if nameErr != nil {
		return nil, nameErr
	}

	// HuggingFace expects text in the Text field (for TTS - Text To Speech)
	hfRequest := &HuggingFaceSpeechRequest{
		Text:     request.Input.Input,
		Model:    modelName,
		Provider: string(inferenceProvider),
	}

	// Without params there is nothing left to map. Returning here also keeps
	// the ExtraParams hand-off below from dereferencing a nil Params.
	if request.Params == nil {
		return hfRequest, nil
	}

	hfRequest.Parameters = &HuggingFaceSpeechParameters{}

	// Map generation parameters from ExtraParams if available
	if extra := request.Params.ExtraParams; extra != nil {
		genParams := &HuggingFaceTranscriptionGenerationParameters{}

		if val, ok := extra["do_sample"].(bool); ok {
			delete(extra, "do_sample")
			genParams.DoSample = &val
		}
		if v, ok := schemas.SafeExtractIntPointer(extra["max_new_tokens"]); ok {
			delete(extra, "max_new_tokens")
			genParams.MaxNewTokens = v
		}
		if v, ok := schemas.SafeExtractIntPointer(extra["max_length"]); ok {
			delete(extra, "max_length")
			genParams.MaxLength = v
		}
		if v, ok := schemas.SafeExtractIntPointer(extra["min_length"]); ok {
			delete(extra, "min_length")
			genParams.MinLength = v
		}
		if v, ok := schemas.SafeExtractIntPointer(extra["min_new_tokens"]); ok {
			delete(extra, "min_new_tokens")
			genParams.MinNewTokens = v
		}
		if v, ok := schemas.SafeExtractIntPointer(extra["num_beams"]); ok {
			delete(extra, "num_beams")
			genParams.NumBeams = v
		}
		if v, ok := schemas.SafeExtractIntPointer(extra["num_beam_groups"]); ok {
			delete(extra, "num_beam_groups")
			genParams.NumBeamGroups = v
		}
		if val, ok := extra["penalty_alpha"].(float64); ok {
			delete(extra, "penalty_alpha")
			genParams.PenaltyAlpha = &val
		}
		if val, ok := extra["temperature"].(float64); ok {
			delete(extra, "temperature")
			genParams.Temperature = &val
		}
		if v, ok := schemas.SafeExtractIntPointer(extra["top_k"]); ok {
			delete(extra, "top_k")
			genParams.TopK = v
		}
		if val, ok := extra["top_p"].(float64); ok {
			delete(extra, "top_p")
			genParams.TopP = &val
		}
		if val, ok := extra["typical_p"].(float64); ok {
			delete(extra, "typical_p")
			genParams.TypicalP = &val
		}
		if val, ok := extra["use_cache"].(bool); ok {
			delete(extra, "use_cache")
			genParams.UseCache = &val
		}
		if val, ok := extra["epsilon_cutoff"].(float64); ok {
			delete(extra, "epsilon_cutoff")
			genParams.EpsilonCutoff = &val
		}
		if val, ok := extra["eta_cutoff"].(float64); ok {
			delete(extra, "eta_cutoff")
			genParams.EtaCutoff = &val
		}

		// Handle early_stopping (can be bool or string "never")
		if val, ok := extra["early_stopping"].(bool); ok {
			delete(extra, "early_stopping")
			genParams.EarlyStopping = &HuggingFaceTranscriptionEarlyStopping{BoolValue: &val}
		} else if val, ok := extra["early_stopping"].(string); ok {
			delete(extra, "early_stopping")
			genParams.EarlyStopping = &HuggingFaceTranscriptionEarlyStopping{StringValue: &val}
		}

		hfRequest.Parameters.GenerationParameters = genParams
	}

	// Forward whatever was not consumed above.
	hfRequest.ExtraParams = request.Params.ExtraParams

	return hfRequest, nil
}
// ToBifrostSpeechResponse wraps already-downloaded audio bytes in a Bifrost
// speech response. A nil receiver yields (nil, nil); an empty requestedModel is
// an error. Only the Audio field is populated: this conversion sets no usage or
// alignment data.
func (response *HuggingFaceSpeechResponse) ToBifrostSpeechResponse(requestedModel string, audioData []byte) (*schemas.BifrostSpeechResponse, error) {
	switch {
	case response == nil:
		return nil, nil
	case requestedModel == "":
		return nil, fmt.Errorf("model name cannot be empty")
	}

	return &schemas.BifrostSpeechResponse{Audio: audioData}, nil
}

View File

@@ -0,0 +1,170 @@
package huggingface
import (
"encoding/base64"
"fmt"
"github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// ToHuggingFaceTranscriptionRequest converts a Bifrost transcription (ASR)
// request into the Hugging Face transcription payload.
//
// A nil request yields (nil, nil). An error is returned when the input is nil,
// the audio is empty, the model string cannot be split, or the fal-ai provider
// is selected with audio detected as audio/wav.
//
// For fal-ai the audio is sent as a base64 data URL in AudioURL; for every
// other inference provider the raw bytes go into Inputs together with the model
// and provider. Known generation parameters in Params.ExtraParams are promoted
// to typed fields and removed from that map (the caller's map is modified).
func ToHuggingFaceTranscriptionRequest(request *schemas.BifrostTranscriptionRequest) (*HuggingFaceTranscriptionRequest, error) {
	if request == nil {
		return nil, nil
	}
	if request.Input == nil {
		return nil, fmt.Errorf("transcription request input cannot be nil")
	}
	audio := request.Input.File
	if len(audio) == 0 {
		return nil, fmt.Errorf("transcription request audio file cannot be empty")
	}

	inferenceProvider, modelName, nameErr := splitIntoModelProvider(request.Model)
	if nameErr != nil {
		return nil, nameErr
	}

	var hfRequest *HuggingFaceTranscriptionRequest
	if inferenceProvider == falAI {
		// fal-ai takes the audio as a data URL rather than raw bytes.
		mimeType := getMimeTypeForAudioType(utils.DetectAudioMimeType(audio))
		if mimeType == "audio/wav" {
			return nil, fmt.Errorf("fal-ai provider does not support audio/wav format; please use a different format like mp3 or ogg")
		}
		hfRequest = &HuggingFaceTranscriptionRequest{
			AudioURL: fmt.Sprintf("data:%s;base64,%s", mimeType, base64.StdEncoding.EncodeToString(audio)),
		}
	} else {
		// Everyone else receives the raw audio in Inputs.
		hfRequest = &HuggingFaceTranscriptionRequest{
			Inputs:   audio,
			Model:    schemas.Ptr(modelName),
			Provider: schemas.Ptr(string(inferenceProvider)),
		}
	}

	params := request.Params
	if params == nil {
		return hfRequest, nil
	}

	hfRequest.Parameters = &HuggingFaceTranscriptionRequestParameters{}
	generation := &HuggingFaceTranscriptionGenerationParameters{}

	// Length limits come from the typed params.
	if v, ok := schemas.SafeExtractIntPointer(params.MaxNewTokens); ok {
		generation.MaxNewTokens = v
	}
	if v, ok := schemas.SafeExtractIntPointer(params.MaxLength); ok {
		generation.MaxLength = v
	}
	if v, ok := schemas.SafeExtractIntPointer(params.MinLength); ok {
		generation.MinLength = v
	}
	if v, ok := schemas.SafeExtractIntPointer(params.MinNewTokens); ok {
		generation.MinNewTokens = v
	}

	// Everything else is picked out of ExtraParams and removed from it.
	if extra := params.ExtraParams; extra != nil {
		if val, ok := extra["do_sample"].(bool); ok {
			delete(extra, "do_sample")
			generation.DoSample = &val
		}
		if v, ok := schemas.SafeExtractIntPointer(extra["num_beams"]); ok {
			delete(extra, "num_beams")
			generation.NumBeams = v
		}
		if v, ok := schemas.SafeExtractIntPointer(extra["num_beam_groups"]); ok {
			delete(extra, "num_beam_groups")
			generation.NumBeamGroups = v
		}
		if val, ok := extra["penalty_alpha"].(float64); ok {
			delete(extra, "penalty_alpha")
			generation.PenaltyAlpha = &val
		}
		if val, ok := extra["temperature"].(float64); ok {
			delete(extra, "temperature")
			generation.Temperature = &val
		}
		if v, ok := schemas.SafeExtractIntPointer(extra["top_k"]); ok {
			delete(extra, "top_k")
			generation.TopK = v
		}
		if val, ok := extra["top_p"].(float64); ok {
			delete(extra, "top_p")
			generation.TopP = &val
		}
		if val, ok := extra["typical_p"].(float64); ok {
			delete(extra, "typical_p")
			generation.TypicalP = &val
		}
		if val, ok := extra["use_cache"].(bool); ok {
			delete(extra, "use_cache")
			generation.UseCache = &val
		}
		if val, ok := extra["epsilon_cutoff"].(float64); ok {
			delete(extra, "epsilon_cutoff")
			generation.EpsilonCutoff = &val
		}
		if val, ok := extra["eta_cutoff"].(float64); ok {
			delete(extra, "eta_cutoff")
			generation.EtaCutoff = &val
		}

		// early_stopping is either a bool or a string such as "never".
		if val, ok := extra["early_stopping"].(bool); ok {
			delete(extra, "early_stopping")
			generation.EarlyStopping = &HuggingFaceTranscriptionEarlyStopping{BoolValue: &val}
		} else if val, ok := extra["early_stopping"].(string); ok {
			delete(extra, "early_stopping")
			generation.EarlyStopping = &HuggingFaceTranscriptionEarlyStopping{StringValue: &val}
		}

		// return_timestamps sits on the request parameters, not on generation.
		if val, ok := extra["return_timestamps"].(bool); ok {
			delete(extra, "return_timestamps")
			hfRequest.Parameters.ReturnTimestamps = &val
		}
	}

	hfRequest.ExtraParams = params.ExtraParams
	hfRequest.Parameters.GenerationParameters = generation
	return hfRequest, nil
}
// ToBifrostTranscriptionResponse converts a Hugging Face transcription result
// into the Bifrost transcription response. A nil receiver yields (nil, nil); an
// empty requestedModel is an error.
//
// Each chunk becomes one segment whose ID is its index. A chunk with fewer than
// two timestamp values keeps Start and End at zero.
func (response *HuggingFaceTranscriptionResponse) ToBifrostTranscriptionResponse(requestedModel string) (*schemas.BifrostTranscriptionResponse, error) {
	if response == nil {
		return nil, nil
	}
	if requestedModel == "" {
		return nil, fmt.Errorf("model name cannot be empty")
	}

	out := &schemas.BifrostTranscriptionResponse{Text: response.Text}
	if len(response.Chunks) == 0 {
		return out, nil
	}

	segments := make([]schemas.TranscriptionSegment, 0, len(response.Chunks))
	for idx, chunk := range response.Chunks {
		segment := schemas.TranscriptionSegment{ID: idx, Text: chunk.Text}
		if len(chunk.Timestamp) >= 2 {
			segment.Start, segment.End = chunk.Timestamp[0], chunk.Timestamp[1]
		}
		segments = append(segments, segment)
	}
	out.Segments = segments

	return out, nil
}

View File

@@ -0,0 +1,529 @@
package huggingface
import (
"bytes"
"encoding/json"
"fmt"
"github.com/bytedance/sonic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// # MODELS TYPES
// HuggingFaceModel is a single model entry of the model listing.
// Referred from https://huggingface.co/api/models
// ToBifrostListModelsResponse skips entries with an empty ModelID and derives
// the supported request types from PipelineTag and Tags.
type HuggingFaceModel struct {
ID string `json:"_id"`
ModelID string `json:"modelId"`
Likes int `json:"likes"`
TrendingScore int `json:"trendingScore"`
Private bool `json:"private"`
Downloads int `json:"downloads"`
Tags []string `json:"tags"`
PipelineTag string `json:"pipeline_tag"`
LibraryName string `json:"library_name"`
CreatedAt string `json:"createdAt"`
}
// HuggingFaceListModelsResponse holds the models returned by a list-models
// call. Its custom UnmarshalJSON accepts both a top-level JSON array and an
// object with a "models" field.
type HuggingFaceListModelsResponse struct {
Models []HuggingFaceModel `json:"models"`
}
// UnmarshalJSON decodes a model listing from either of two layouts: a top-level
// JSON array `[...]` (tried first) or an object of the form `{"models": [...]}`.
// Input matching neither layout is rejected with an error.
func (r *HuggingFaceListModelsResponse) UnmarshalJSON(data []byte) error {
	// Layout 1: bare array of models.
	var list []HuggingFaceModel
	if sonic.Unmarshal(data, &list) == nil {
		r.Models = list
		return nil
	}

	// Layout 2: object wrapping the array in a `models` field.
	var wrapper struct {
		Models []HuggingFaceModel `json:"models"`
	}
	if sonic.Unmarshal(data, &wrapper) != nil {
		return fmt.Errorf("failed to unmarshal HuggingFaceListModelsResponse: unexpected JSON structure")
	}
	r.Models = wrapper.Models
	return nil
}
// HuggingFaceInferenceProviderMappingResponse is a model record that carries an
// "inferenceProviderMapping" object.
// NOTE(review): the map key is presumably the inference provider name; confirm
// against the code that consumes this type.
type HuggingFaceInferenceProviderMappingResponse struct {
ID string `json:"_id"`
ModelID string `json:"id"`
PipelineTag string `json:"pipeline_tag"`
InferenceProviderMapping map[string]HuggingFaceInferenceProviderInfo `json:"inferenceProviderMapping"`
}
// HuggingFaceInferenceProviderInfo is one entry of the inference provider
// mapping: its status, the provider-side model ID ("providerId"), the task, and
// whether the provider is the model author.
type HuggingFaceInferenceProviderInfo struct {
Status string `json:"status"`
ProviderModelID string `json:"providerId"`
Task string `json:"task"`
IsModelAuthor bool `json:"isModelAuthor"`
}
// HuggingFaceInferenceProviderMapping pairs a provider task with a provider
// model ID. It has no JSON tags.
// NOTE(review): presumably the internal, reduced form of
// HuggingFaceInferenceProviderInfo; confirm at the call sites.
type HuggingFaceInferenceProviderMapping struct {
ProviderTask string
ProviderModelID string
}
// # CHAT TYPES
// HuggingFaceChatRequest is the flexible chat-completion request payload for
// HuggingFace-like chat endpoints. Messages and Model are always serialised;
// every other JSON field is omitted when unset. ExtraParams is excluded from
// JSON (`json:"-"`) and exposed through GetExtraParams instead.
type HuggingFaceChatRequest struct {
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
Logprobs *bool `json:"logprobs,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"`
Messages []schemas.ChatMessage `json:"messages"`
Model string `json:"model" validate:"required"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
ResponseFormat *HuggingFaceResponseFormat `json:"response_format,omitempty"`
Seed *int `json:"seed,omitempty"`
Stop []string `json:"stop,omitempty"`
Stream *bool `json:"stream,omitempty"`
StreamOptions *schemas.ChatStreamOptions `json:"stream_options,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
ToolChoice *HuggingFaceToolChoice `json:"tool_choice,omitempty"`
ToolPrompt *string `json:"tool_prompt,omitempty"`
Tools []schemas.ChatTool `json:"tools,omitempty"`
TopLogprobs *int `json:"top_logprobs,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
ExtraParams map[string]interface{} `json:"-"`
}
// GetExtraParams returns the request's ExtraParams map, which is not part of
// the struct's own JSON encoding. The map is returned as-is, not copied.
func (req *HuggingFaceChatRequest) GetExtraParams() map[string]interface{} {
return req.ExtraParams
}
// HuggingFaceToolChoice represents the flexible `tool_choice` field which
// can be either one of the enum strings: "auto", "none", "required",
// or an object with a `function` sub-object containing a required `name`.
// When both fields are set, MarshalJSON emits EnumValue.
type HuggingFaceToolChoice struct {
// EnumValue holds the string form when the field is one of the enum strings.
EnumValue *EnumStringType
// Function holds the function object when the field is a JSON object.
Function *schemas.ChatToolChoiceFunction
}
// EnumStringType is the string form of a `tool_choice` value.
type EnumStringType string
// Values used for HuggingFaceToolChoice.EnumValue.
const (
EnumStringTypeAuto EnumStringType = "auto"
EnumStringTypeNone EnumStringType = "none"
EnumStringTypeRequired EnumStringType = "required"
)
// MarshalJSON encodes the tool choice as a JSON string when EnumValue is set,
// as an object with `type` ("function") and `function` keys when only Function
// is set, and as JSON null when neither is set. EnumValue takes precedence.
func (t HuggingFaceToolChoice) MarshalJSON() ([]byte, error) {
	switch {
	case t.EnumValue != nil:
		return providerUtils.MarshalSorted(*t.EnumValue)
	case t.Function != nil:
		// Object form of tool_choice.
		type functionChoice struct {
			Type     string                          `json:"type"`
			Function *schemas.ChatToolChoiceFunction `json:"function"`
		}
		return providerUtils.MarshalSorted(functionChoice{Type: "function", Function: t.Function})
	default:
		return []byte("null"), nil
	}
}
// HuggingFaceResponseFormat is the `response_format` object of a chat request:
// a type string plus an optional JSON schema description.
type HuggingFaceResponseFormat struct {
Type string `json:"type"`
JSONSchema *HuggingFaceJSONSchema `json:"json_schema,omitempty"`
}
// HuggingFaceJSONSchema is the `json_schema` part of a response format. Schema
// is kept as raw JSON, so it is passed through without being decoded into a Go
// type.
type HuggingFaceJSONSchema struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Schema json.RawMessage `json:"schema,omitempty"`
Strict *bool `json:"strict,omitempty"`
}
// # RESPONSE TYPES
// HuggingFaceHubError is an error body with `error` and `message` fields.
// NOTE(review): by its name this is the Hub API error shape; confirm where it
// is decoded.
type HuggingFaceHubError struct {
Error string `json:"error"`
Message string `json:"message"`
}
// HuggingFaceResponseError is an error body with `error`, `type` and `message`
// fields, plus an optional `detail` list of validation errors.
type HuggingFaceResponseError struct {
Error string `json:"error"`
Type string `json:"type"`
Message string `json:"message"`
Detail []HuggingFaceErrorDetail `json:"detail,omitempty"` // FastAPI validation errors
}
// HuggingFaceErrorDetail is one entry of HuggingFaceResponseError.Detail.
// Loc is typed []interface{}, so its elements can be of mixed JSON types.
type HuggingFaceErrorDetail struct {
Loc []interface{} `json:"loc"`
Msg string `json:"msg"`
Type string `json:"type"`
Ctx map[string]interface{} `json:"ctx,omitempty"`
}
// # EMBEDDING TYPES
// HuggingFaceEmbeddingRequest represents the request format for HuggingFace embeddings API
// Based on the HuggingFace Router API specification
// Per the field comments, the text goes in Inputs for the hf-inference provider
// and in Input for all others. ExtraParams is excluded from JSON and exposed
// through GetExtraParams.
type HuggingFaceEmbeddingRequest struct {
Input *InputsCustomType `json:"input,omitempty"` // string or []string used by all inference providers other than hf-inference
Inputs *InputsCustomType `json:"inputs,omitempty"` // string or []string used by hf-inference provider
Provider *string `json:"provider,omitempty"` // used by all inference providers other than hf-inference
Model *string `json:"model,omitempty"` // used by all inference providers other than hf-inference
Normalize *bool `json:"normalize,omitempty"`
PromptName *string `json:"prompt_name,omitempty"`
Truncate *bool `json:"truncate,omitempty"`
TruncationDirection *string `json:"truncation_direction,omitempty"` // "left" or "right"
EncodingFormat *EncodingType `json:"encoding_format,omitempty"`
Dimensions *int `json:"dimensions,omitempty"`
ExtraParams map[string]interface{} `json:"-"`
}
// GetExtraParams returns the request's ExtraParams map, which is not part of
// the struct's own JSON encoding. The map is returned as-is, not copied.
func (req *HuggingFaceEmbeddingRequest) GetExtraParams() map[string]interface{} {
return req.ExtraParams
}
// InputsCustomType is an embedding input that is either a single string (Text)
// or a list of strings (Texts). Its MarshalJSON emits the list when Texts is
// non-empty, otherwise the single string, otherwise null.
type InputsCustomType struct {
Texts []string `json:"texts,omitempty"`
Text *string `json:"text,omitempty"`
}
// UnmarshalJSON accepts three encodings, tried in this order: a JSON string
// (stored in Text), a JSON array of strings (stored in Texts), or an object
// with `text` / `texts` keys. Empty input and JSON null leave the value
// untouched. Anything else is an error.
func (i *InputsCustomType) UnmarshalJSON(data []byte) error {
	if len(data) == 0 || bytes.Equal(bytes.TrimSpace(data), []byte("null")) {
		return nil
	}

	// Single string.
	var single string
	if sonic.Unmarshal(data, &single) == nil {
		i.Text = &single
		return nil
	}

	// List of strings.
	var many []string
	if sonic.Unmarshal(data, &many) == nil {
		i.Texts = many
		return nil
	}

	// Object form. The method-less local type keeps this call from recursing
	// back into UnmarshalJSON.
	type plain InputsCustomType
	var decoded plain
	if sonic.Unmarshal(data, &decoded) != nil {
		return fmt.Errorf("failed to unmarshal InputsCustomType: expected string, array, or object")
	}
	*i = InputsCustomType(decoded)
	return nil
}
// MarshalJSON emits the input in its wire form: a JSON array when Texts is
// non-empty, otherwise a JSON string when Text is set, otherwise JSON null.
func (i InputsCustomType) MarshalJSON() ([]byte, error) {
	switch {
	case len(i.Texts) > 0:
		return providerUtils.MarshalSorted(i.Texts)
	case i.Text != nil:
		return providerUtils.MarshalSorted(*i.Text)
	default:
		return []byte("null"), nil
	}
}
// EncodingType is the value of an embedding request's `encoding_format` field.
type EncodingType string
// Encoding format values defined for EncodingType.
const (
EncodingTypeFloat EncodingType = "float"
EncodingTypeBase64 EncodingType = "base64"
)
// # SPEECH TYPES
// Speech request represents the inputs for Text To Speech inference.
// Text, Provider and Model are always serialised; Parameters is omitted when
// nil. ExtraParams is excluded from JSON and exposed through GetExtraParams.
type HuggingFaceSpeechRequest struct {
Text string `json:"text"`
Provider string `json:"provider" validate:"required"`
Model string `json:"model" validate:"required"`
Parameters *HuggingFaceSpeechParameters `json:"parameters,omitempty"`
ExtraParams map[string]interface{} `json:"-"`
}
// GetExtraParams returns the request's ExtraParams map, which is not part of
// the struct's own JSON encoding. The map is returned as-is, not copied.
func (req *HuggingFaceSpeechRequest) GetExtraParams() map[string]interface{} {
return req.ExtraParams
}
// Speech parameters are additional inference parameters for Text To Speech
// The generation parameters reuse the transcription generation parameter type.
type HuggingFaceSpeechParameters struct {
GenerationParameters *HuggingFaceTranscriptionGenerationParameters `json:"generation_parameters,omitempty"`
}
// Speech response represents the outputs of inference for the Text To Speech task.
// It carries only the `audio` object describing the generated file.
type HuggingFaceSpeechResponse struct {
Audio HuggingFaceSpeechAudio `json:"audio"`
}
// HuggingFaceSpeechAudio represents the audio object in the speech response
// It references the audio by URL and describes it by content type, file name
// and file size.
// NOTE(review): FileSize is presumably in bytes; confirm against the API.
type HuggingFaceSpeechAudio struct {
URL string `json:"url"`
ContentType string `json:"content_type"`
FileName string `json:"file_name"`
FileSize int `json:"file_size"`
}
// # TRANSCRIPT TYPES
// HuggingFaceTranscriptionRequest represents the request for Automatic Speech Recognition inference
// ToHuggingFaceTranscriptionRequest fills AudioURL (a base64 data URL) for the
// fal-ai provider, and Inputs, Model and Provider for every other provider.
// ExtraParams is excluded from JSON and exposed through GetExtraParams.
type HuggingFaceTranscriptionRequest struct {
Inputs []byte `json:"inputs,omitempty"` // raw audio bytes
AudioURL string `json:"audio_url,omitempty"` // URL to audio file only needed for fal ai
Provider *string `json:"provider,omitempty"`
Model *string `json:"model,omitempty"`
Parameters *HuggingFaceTranscriptionRequestParameters `json:"parameters,omitempty"`
ExtraParams map[string]interface{} `json:"-"`
}
// GetExtraParams returns the request's ExtraParams map, which is not part of
// the struct's own JSON encoding. The map is returned as-is, not copied.
func (req *HuggingFaceTranscriptionRequest) GetExtraParams() map[string]interface{} {
return req.ExtraParams
}
// HuggingFaceTranscriptionRequestParameters contains additional inference parameters for Automatic Speech Recognition
// Both fields are optional and omitted from JSON when nil.
type HuggingFaceTranscriptionRequestParameters struct {
GenerationParameters *HuggingFaceTranscriptionGenerationParameters `json:"generation_parameters,omitempty"`
ReturnTimestamps *bool `json:"return_timestamps,omitempty"`
}
// HuggingFaceTranscriptionGenerationParameters contains parametrization of the text generation process
// Every field is an optional pointer and is omitted from JSON when nil. The
// same type is used for speech requests via HuggingFaceSpeechParameters.
type HuggingFaceTranscriptionGenerationParameters struct {
DoSample *bool `json:"do_sample,omitempty"`
EarlyStopping *HuggingFaceTranscriptionEarlyStopping `json:"early_stopping,omitempty"`
EpsilonCutoff *float64 `json:"epsilon_cutoff,omitempty"`
EtaCutoff *float64 `json:"eta_cutoff,omitempty"`
MaxLength *int `json:"max_length,omitempty"`
MaxNewTokens *int `json:"max_new_tokens,omitempty"`
MinLength *int `json:"min_length,omitempty"`
MinNewTokens *int `json:"min_new_tokens,omitempty"`
NumBeamGroups *int `json:"num_beam_groups,omitempty"`
NumBeams *int `json:"num_beams,omitempty"`
PenaltyAlpha *float64 `json:"penalty_alpha,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopK *int `json:"top_k,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TypicalP *float64 `json:"typical_p,omitempty"`
UseCache *bool `json:"use_cache,omitempty"`
}
// HuggingFaceTranscriptionEarlyStopping controls the stopping condition for beam-based methods
// Can be a boolean or the string "never"
// When both fields are set, MarshalJSON emits BoolValue.
type HuggingFaceTranscriptionEarlyStopping struct {
BoolValue *bool
StringValue *string
}
// MarshalJSON encodes early_stopping as a JSON boolean when BoolValue is set,
// as a JSON string when only StringValue is set, and as JSON null otherwise.
func (e HuggingFaceTranscriptionEarlyStopping) MarshalJSON() ([]byte, error) {
	switch {
	case e.BoolValue != nil:
		return providerUtils.MarshalSorted(*e.BoolValue)
	case e.StringValue != nil:
		return providerUtils.MarshalSorted(*e.StringValue)
	default:
		return []byte("null"), nil
	}
}
// UnmarshalJSON decodes early_stopping from either a JSON boolean (tried first,
// stored in BoolValue) or a JSON string (stored in StringValue). Any other
// value is rejected with an error that quotes the raw input.
func (e *HuggingFaceTranscriptionEarlyStopping) UnmarshalJSON(data []byte) error {
	var asBool bool
	if sonic.Unmarshal(data, &asBool) == nil {
		e.BoolValue = &asBool
		return nil
	}

	var asString string
	if sonic.Unmarshal(data, &asString) != nil {
		return fmt.Errorf("early_stopping must be a boolean or string, got: %s", string(data))
	}
	e.StringValue = &asString
	return nil
}
// HuggingFaceTranscriptionResponse represents the output of Automatic Speech Recognition inference.
type HuggingFaceTranscriptionResponse struct {
// Text is read from the "text" key.
Text string `json:"text"`
// Chunks is optional; "chunks" is omitted from encoded JSON when the slice is empty.
Chunks []HuggingFaceTranscriptionResponseChunk `json:"chunks,omitempty"`
}
// HuggingFaceTranscriptionResponseChunk represents an audio chunk identified by the model.
type HuggingFaceTranscriptionResponseChunk struct {
// Text is read from the chunk's "text" key.
Text string `json:"text"`
// Timestamp is read from the "timestamp" array.
// NOTE(review): presumably a [start, end] pair in seconds — confirm against the HF ASR output spec.
Timestamp []float64 `json:"timestamp"`
}
// HuggingFaceGenerationParameters is an alias (identical type) of HuggingFaceTranscriptionGenerationParameters.
type HuggingFaceGenerationParameters = HuggingFaceTranscriptionGenerationParameters
// HuggingFaceEarlyStoppingUnion is an alias (identical type) of HuggingFaceTranscriptionEarlyStopping.
type HuggingFaceEarlyStoppingUnion = HuggingFaceTranscriptionEarlyStopping
// # IMAGE GENERATION TYPES
// HuggingFaceHFInferenceImageGenerationRequest for hf-inference image generation.
type HuggingFaceHFInferenceImageGenerationRequest struct {
	// Inputs is always serialized under the "inputs" key.
	Inputs string `json:"inputs"`
	// ExtraParams is excluded from struct JSON encoding (json:"-") and is only
	// reachable through GetExtraParams.
	ExtraParams map[string]any `json:"-"`
}

// GetExtraParams returns the request's ExtraParams map.
// It is safe to call on a nil receiver, in which case it returns nil.
func (req *HuggingFaceHFInferenceImageGenerationRequest) GetExtraParams() map[string]any {
	if req == nil {
		return nil
	}
	return req.ExtraParams
}
// HuggingFaceFalAIImageGenerationRequest for fal-ai image generation.
// Prompt is always serialized; every other wire field is an optional pointer
// tagged omitempty and is left out of the JSON body when nil.
type HuggingFaceFalAIImageGenerationRequest struct {
	Prompt                string                `json:"prompt"`
	NumImages             *int                  `json:"num_images,omitempty"`
	ResponseFormat        *string               `json:"response_format,omitempty"`
	ImageSize             *HuggingFaceFalAISize `json:"image_size,omitempty"`
	NegativePrompt        *string               `json:"negative_prompt,omitempty"`
	GuidanceScale         *float64              `json:"guidance_scale,omitempty"`
	NumInferenceSteps     *int                  `json:"num_inference_steps,omitempty"`
	Seed                  *int                  `json:"seed,omitempty"`
	OutputFormat          *string               `json:"output_format,omitempty"`
	SyncMode              *bool                 `json:"sync_mode,omitempty"`
	EnableSafetyChecker   *bool                 `json:"enable_safety_checker,omitempty"`
	Acceleration          *string               `json:"acceleration,omitempty"`
	EnablePromptExpansion *bool                 `json:"enable_prompt_expansion,omitempty"`
	// ExtraParams is excluded from struct JSON encoding (json:"-") and is only
	// reachable through GetExtraParams.
	ExtraParams map[string]any `json:"-"`
}

// GetExtraParams returns the request's ExtraParams map.
// It is safe to call on a nil receiver, in which case it returns nil.
func (req *HuggingFaceFalAIImageGenerationRequest) GetExtraParams() map[string]any {
	if req == nil {
		return nil
	}
	return req.ExtraParams
}
// HuggingFaceFalAISize is the explicit width/height pair used for the "image_size"
// field of the fal-ai request types. Both dimensions are always serialized (no omitempty).
// NOTE(review): units are presumably pixels — confirm against the fal-ai API.
type HuggingFaceFalAISize struct {
Width int `json:"width"`
Height int `json:"height"`
}
// HuggingFaceFalAIImageGenerationResponse for fal-ai image generation.
// Matches the API envelope structure: top-level metadata plus a nested "data" object.
// The same payload fields are also declared at the top level so that the older,
// flattened response shape decodes into this type as well.
type HuggingFaceFalAIImageGenerationResponse struct {
RequestID string `json:"request_id,omitempty"`
Status string `json:"status,omitempty"`
CreatedAt *int64 `json:"created_at,omitempty"`
// Data is the enveloped payload; nil when the "data" key is absent.
Data *FalAIImageData `json:"data,omitempty"`
// Legacy flattened fields for backward compatibility (same keys as FalAIImageData)
Images []FalAIImage `json:"images,omitempty"`
Timings *FalAITimings `json:"timings,omitempty"`
Seed *int64 `json:"seed,omitempty"`
HasNSFWConcepts []bool `json:"has_nsfw_concepts,omitempty"`
Prompt string `json:"prompt,omitempty"`
}
// FalAIImageData wraps the image data in the API envelope.
// It is the payload referenced by the "data" key of the fal-ai response types.
type FalAIImageData struct {
Images []FalAIImage `json:"images,omitempty"`
Timings *FalAITimings `json:"timings,omitempty"`
Seed *int64 `json:"seed,omitempty"`
// HasNSFWConcepts is read from "has_nsfw_concepts".
// NOTE(review): presumably one flag per entry in Images — confirm against the fal-ai API.
HasNSFWConcepts []bool `json:"has_nsfw_concepts,omitempty"`
Prompt string `json:"prompt,omitempty"`
}
// FalAIImage describes a single image entry in a fal-ai response.
// All fields are optional on the wire (omitempty).
// NOTE(review): presumably either URL or B64JSON is populated depending on the
// requested response format — confirm against the fal-ai API.
type FalAIImage struct {
URL string `json:"url,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
Width int `json:"width,omitempty"`
Height int `json:"height,omitempty"`
ContentType string `json:"content_type,omitempty"`
}
// FalAITimings carries the "timings" object of a fal-ai response.
type FalAITimings struct {
// Inference is read from "inference". NOTE(review): unit presumably seconds — confirm.
Inference float64 `json:"inference"`
}
// HuggingFaceTogetherImageGenerationRequest for together image generation.
// Prompt and Model are always serialized; the remaining wire fields are optional
// pointers tagged omitempty and are left out of the JSON body when nil.
type HuggingFaceTogetherImageGenerationRequest struct {
	Prompt         string  `json:"prompt"`
	Model          string  `json:"model"`
	ResponseFormat *string `json:"response_format,omitempty"`
	Size           *string `json:"size,omitempty"`
	Width          *int    `json:"width,omitempty"`
	Height         *int    `json:"height,omitempty"`
	N              *int    `json:"n,omitempty"`
	Steps          *int    `json:"steps,omitempty"`
	// ExtraParams is excluded from struct JSON encoding (json:"-") and is only
	// reachable through GetExtraParams.
	ExtraParams map[string]any `json:"-"`
}

// GetExtraParams returns the request's ExtraParams map.
// It is safe to call on a nil receiver, in which case it returns nil.
func (req *HuggingFaceTogetherImageGenerationRequest) GetExtraParams() map[string]any {
	if req == nil {
		return nil
	}
	return req.ExtraParams
}
// HuggingFaceTogetherImageGenerationResponse for together image generation.
// All four keys are always present when this type is encoded (no omitempty).
type HuggingFaceTogetherImageGenerationResponse struct {
ID string `json:"id"`
Model string `json:"model"`
Object string `json:"object"`
// Data holds the generated image entries.
Data []HuggingFaceTogetherImageData `json:"data"`
}
// HuggingFaceTogetherImageData is one entry of the "data" array in a together
// image generation response.
// NOTE(review): presumably either B64JSON or URL is populated depending on the
// requested response format — confirm against the together API.
type HuggingFaceTogetherImageData struct {
B64JSON string `json:"b64_json,omitempty"`
URL string `json:"url,omitempty"`
// Index is read from "index" and is always serialized (no omitempty).
Index int `json:"index"`
Timings *HuggingFaceTogetherTimings `json:"timings,omitempty"`
}
// HuggingFaceTogetherTimings carries the optional "timings" object of a together image entry.
type HuggingFaceTogetherTimings struct {
// Inference is read from "inference". NOTE(review): unit presumably seconds — confirm.
Inference float64 `json:"inference"`
}
// HuggingFaceFalAIImageStreamRequest for fal-ai image generation streaming.
// Prompt is always serialized; every other wire field is an optional pointer
// tagged omitempty and is left out of the JSON body when nil.
type HuggingFaceFalAIImageStreamRequest struct {
	Prompt                string                `json:"prompt"`
	ResponseFormat        *string               `json:"response_format,omitempty"`
	NumImages             *int                  `json:"num_images,omitempty"`
	ImageSize             *HuggingFaceFalAISize `json:"image_size,omitempty"`
	GuidanceScale         *float64              `json:"guidance_scale,omitempty"`
	Seed                  *int                  `json:"seed,omitempty"`
	NumInferenceSteps     *int                  `json:"num_inference_steps,omitempty"`
	Acceleration          *string               `json:"acceleration,omitempty"`
	EnablePromptExpansion *bool                 `json:"enable_prompt_expansion,omitempty"`
	SyncMode              *bool                 `json:"sync_mode,omitempty"`
	EnableSafetyChecker   *bool                 `json:"enable_safety_checker,omitempty"`
	OutputFormat          *string               `json:"output_format,omitempty"`
	// ExtraParams is excluded from struct JSON encoding (json:"-") and is only
	// reachable through GetExtraParams.
	ExtraParams map[string]any `json:"-"`
}

// GetExtraParams returns the request's ExtraParams map.
// It is safe to call on a nil receiver, in which case it returns nil.
func (req *HuggingFaceFalAIImageStreamRequest) GetExtraParams() map[string]any {
	if req == nil {
		return nil
	}
	return req.ExtraParams
}
// HuggingFaceFalAIImageStreamResponse for fal-ai SSE events.
// Images may arrive either nested under "data" or directly under a top-level
// "images" key; both locations are declared so either event shape decodes.
type HuggingFaceFalAIImageStreamResponse struct {
Data *FalAIImageData `json:"data,omitempty"`
Images []FalAIImage `json:"images,omitempty"`
}
// HuggingFaceFalAIImageEditRequest for fal-ai image edit.
// Prompt is always serialized; every other wire field is optional and is left
// out of the JSON body when nil/empty (omitempty).
type HuggingFaceFalAIImageEditRequest struct {
	Prompt              string                `json:"prompt"`
	ImageURL            *string               `json:"image_url,omitempty"`  // For single image models
	ImageURLs           []string              `json:"image_urls,omitempty"` // For multi-image models
	NumImages           *int                  `json:"num_images,omitempty"`
	ImageSize           *HuggingFaceFalAISize `json:"image_size,omitempty"`
	GuidanceScale       *float64              `json:"guidance_scale,omitempty"`
	NumInferenceSteps   *int                  `json:"num_inference_steps,omitempty"`
	SyncMode            *bool                 `json:"sync_mode,omitempty"`
	Seed                *int                  `json:"seed,omitempty"`
	OutputFormat        *string               `json:"output_format,omitempty"`
	EnableSafetyChecker *bool                 `json:"enable_safety_checker,omitempty"`
	Acceleration        *string               `json:"acceleration,omitempty"`
	// ExtraParams is excluded from struct JSON encoding (json:"-") and is only
	// reachable through GetExtraParams.
	ExtraParams map[string]any `json:"-"`
}

// GetExtraParams returns the request's ExtraParams map.
// It is safe to call on a nil receiver, in which case it returns nil.
func (req *HuggingFaceFalAIImageEditRequest) GetExtraParams() map[string]any {
	if req == nil {
		return nil
	}
	return req.ExtraParams
}

View File

@@ -0,0 +1,354 @@
package huggingface
import (
"context"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/bytedance/sonic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
const (
// defaultInferenceBaseURL is the host of the Hugging Face inference router.
// According to https://huggingface.co/docs/inference-providers/en/tasks/chat-completion the
// OpenAI-compatible router lives under the /v1 prefix. That prefix is NOT part of this
// constant, so it has to come from the request path.
// NOTE(review): the provider route paths built in this file carry their own version
// segment (e.g. "/nebius/v1/embeddings"); confirm the chat-completion path does the same.
defaultInferenceBaseURL = "https://router.huggingface.co"
// modelHubBaseURL is the Hub host used for the /api/models listing and per-model lookups.
modelHubBaseURL = "https://huggingface.co"
// For custom deployments, HF offers dedicated inference endpoints under a separate
// base URL. The constant is kept commented out and is currently unused:
// inferenceBaseEndpointsEndpointBaseURL = "https://api.endpoints.huggingface.cloud/v2"
)
// inferenceProvider is the string identifier of a Hugging Face inference provider
// (or routing policy). Its value is used verbatim as the inference_provider query
// parameter and as the key when looking up a model's provider mapping.
type inferenceProvider string
// Known inference provider identifiers. The string values are the provider slugs
// sent to / received from the Hugging Face APIs and must not be changed casually.
const (
cerebras inferenceProvider = "cerebras"
cohere inferenceProvider = "cohere"
falAI inferenceProvider = "fal-ai"
featherlessAI inferenceProvider = "featherless-ai"
fireworksAI inferenceProvider = "fireworks-ai"
groq inferenceProvider = "groq"
hfInference inferenceProvider = "hf-inference"
hyperbolic inferenceProvider = "hyperbolic"
nebius inferenceProvider = "nebius"
novita inferenceProvider = "novita"
nscale inferenceProvider = "nscale"
ovhcloud inferenceProvider = "ovhcloud"
publicai inferenceProvider = "publicai"
replicate inferenceProvider = "replicate"
sambanova inferenceProvider = "sambanova"
scaleway inferenceProvider = "scaleway"
together inferenceProvider = "together"
wavespeed inferenceProvider = "wavespeed"
zaiOrg inferenceProvider = "zai-org"
// auto is the special "auto" policy rather than a concrete provider; it is not
// part of INFERENCE_PROVIDERS and only appears in PROVIDERS_OR_POLICIES.
auto inferenceProvider = "auto"
)
// INFERENCE_PROVIDERS is the list of supported inference providers (kept in sync with HF docs/JS SDK).
// It contains every concrete provider constant and deliberately excludes the "auto" policy.
// Being listed here does not imply that getInferenceProviderRouteURL can build a route
// for the provider: that function handles only a subset and rejects the rest.
var INFERENCE_PROVIDERS = []inferenceProvider{
cerebras,
cohere,
falAI,
featherlessAI,
fireworksAI,
groq,
hfInference,
hyperbolic,
nebius,
novita,
nscale,
ovhcloud,
publicai,
replicate,
sambanova,
scaleway,
together,
wavespeed,
zaiOrg,
}
// PROVIDERS_OR_POLICIES is INFERENCE_PROVIDERS followed by the special "auto" policy.
// It is built once at package initialization into its own backing array, so later
// writes to either slice do not affect the other.
var PROVIDERS_OR_POLICIES = func() []inferenceProvider {
	providerCount := len(INFERENCE_PROVIDERS)
	// Reserve exactly one extra slot for the policy entry.
	combined := make([]inferenceProvider, providerCount, providerCount+1)
	copy(combined, INFERENCE_PROVIDERS)
	return append(combined, auto)
}()
// buildModelHubURL builds the Hub "/api/models" listing URL for one inference provider.
//
// The query always carries limit, full=1, sort=likes, direction=-1 and
// inference_provider (see https://huggingface.co/docs/inference-providers/hub-api).
// limit comes from request.PageSize: non-positive values fall back to
// defaultModelFetchLimit and the result is capped at maxModelFetchLimit.
//
// Entries of request.ExtraParams are then applied on top, so an extra param with the
// same key replaces one of the defaults. Empty string values are skipped; all other
// values are stringified according to their dynamic type.
func (provider *HuggingFaceProvider) buildModelHubURL(request *schemas.BifrostListModelsRequest, inferenceProvider inferenceProvider) string {
	pageSize := request.PageSize
	if pageSize <= 0 {
		pageSize = defaultModelFetchLimit
	}
	if pageSize > maxModelFetchLimit {
		pageSize = maxModelFetchLimit
	}

	query := url.Values{}
	query.Set("limit", strconv.Itoa(pageSize))
	query.Set("full", "1")
	query.Set("sort", "likes")
	query.Set("direction", "-1")
	// Restrict the listing to models served by the given inference provider.
	query.Set("inference_provider", string(inferenceProvider))

	for name, raw := range request.ExtraParams {
		var encoded string
		switch v := raw.(type) {
		case string:
			if v == "" {
				// Empty strings are dropped rather than sent as "key=".
				continue
			}
			encoded = v
		case fmt.Stringer:
			encoded = v.String()
		case int:
			encoded = strconv.Itoa(v)
		case float64:
			encoded = strconv.FormatFloat(v, 'f', -1, 64)
		case bool:
			encoded = strconv.FormatBool(v)
		default:
			encoded = fmt.Sprintf("%v", v)
		}
		query.Set(name, encoded)
	}

	return modelHubBaseURL + "/api/models?" + query.Encode()
}
// buildModelInferenceProviderURL builds the Hub model-info URL for modelName,
// asking the API to expand both the pipeline_tag and inferenceProviderMapping fields.
//
// modelName is inserted into the path as-is (it normally contains a "/" between
// owner and model, which must stay unescaped).
func (provider *HuggingFaceProvider) buildModelInferenceProviderURL(modelName string) string {
	values := url.Values{}
	// Add, not Set: both entries share the "expand[]" key, and Set would replace
	// the first value with the second.
	values.Add("expand[]", "pipeline_tag")
	values.Add("expand[]", "inferenceProviderMapping")
	return fmt.Sprintf("%s/api/models/%s?%s", modelHubBaseURL, modelName, values.Encode())
}
// splitIntoModelProvider splits a Bifrost model name into an inference provider and
// a Hugging Face model name, based on how many "/" separators it contains:
//
//   - none: the name is rejected with an error;
//   - exactly one ("owner/model"): no provider is given, so the provider is empty
//     and the whole input is the model name;
//   - two or more ("provider/owner/model"): everything before the first "/" is the
//     provider and the remainder is the model name.
func splitIntoModelProvider(bifrostModelName string) (inferenceProvider, string, error) {
	switch strings.Count(bifrostModelName, "/") {
	case 0:
		return "", "", fmt.Errorf("invalid model name format: %s", bifrostModelName)
	case 1:
		return "", bifrostModelName, nil
	default:
		prefix, remainder, _ := strings.Cut(bifrostModelName, "/")
		return inferenceProvider(prefix), remainder, nil
	}
}
// getInferenceProviderRouteURL resolves the request URL for a provider/model/request-type
// combination. It picks a provider-specific route path and hands it to
// provider.buildRequestURL together with ctx and requestType.
//
// Routes follow the tasks listed at https://huggingface.co/docs/inference-providers/en/index
// and the makeURL logic at
// https://github.com/huggingface/huggingface.js/blob/c02dd89eff24593b304d72715247f7eef79b3b73/packages/inference/src/providers/providerHelper.ts#L111
//
// An error is returned when the provider has no route defined here, or when the
// provider is known but does not support the given request type.
func (provider *HuggingFaceProvider) getInferenceProviderRouteURL(ctx *schemas.BifrostContext, inferenceProvider inferenceProvider, modelName string, requestType schemas.RequestType) (string, error) {
	var routePath string

	switch inferenceProvider {
	case falAI:
		routePath = "/fal-ai/" + modelName

	case hfInference:
		modelPath := "/hf-inference/models/" + modelName
		switch requestType {
		case schemas.ImageGenerationRequest, schemas.TranscriptionRequest:
			// These hit the model route directly, without a pipeline suffix.
			routePath = modelPath
		case schemas.EmbeddingRequest:
			routePath = modelPath + "/pipeline/feature-extraction"
		case schemas.SpeechRequest:
			routePath = modelPath + "/pipeline/text-to-speech"
		default:
			routePath = modelPath + "/pipeline/chat-completion"
		}

	case nebius:
		switch requestType {
		case schemas.EmbeddingRequest:
			routePath = "/nebius/v1/embeddings"
		case schemas.ImageGenerationRequest:
			routePath = "/nebius/v1/images/generations"
		default:
			return "", fmt.Errorf("nebius provider only supports embedding and image generation requests")
		}

	case replicate:
		// NOTE(review): singular "prediction" — the HF JS SDK referenced above appears
		// to use ".../predictions"; confirm this route is correct.
		routePath = "/replicate/v1/prediction"

	case together:
		if requestType != schemas.ImageGenerationRequest {
			return "", fmt.Errorf("together provider only supports image generation requests")
		}
		routePath = "/together/v1/images/generations"

	case sambanova:
		if requestType != schemas.EmbeddingRequest {
			return "", fmt.Errorf("sambanova provider only supports embedding requests")
		}
		routePath = "/sambanova/v1/embeddings"

	case scaleway:
		if requestType != schemas.EmbeddingRequest {
			return "", fmt.Errorf("scaleway provider only supports embedding requests")
		}
		routePath = "/scaleway/v1/embeddings"

	default:
		return "", fmt.Errorf("unsupported inference provider: %s for action: %s", inferenceProvider, requestType)
	}

	return provider.buildRequestURL(ctx, routePath, requestType), nil
}
// convertToInferenceProviderMappings re-keys the Hub's inference provider mapping by
// inferenceProvider and keeps only the task and the provider-specific model id of
// each entry.
//
// It returns nil when resp is nil or carries no mapping at all; a present but empty
// mapping yields an empty, non-nil map.
func convertToInferenceProviderMappings(resp *HuggingFaceInferenceProviderMappingResponse) map[inferenceProvider]HuggingFaceInferenceProviderMapping {
	if resp == nil {
		return nil
	}
	raw := resp.InferenceProviderMapping
	if raw == nil {
		return nil
	}

	converted := make(map[inferenceProvider]HuggingFaceInferenceProviderMapping, len(raw))
	for key, entry := range raw {
		converted[inferenceProvider(key)] = HuggingFaceInferenceProviderMapping{
			ProviderModelID: entry.ProviderModelID,
			ProviderTask:    entry.Task,
		}
	}
	return converted
}
// getModelInferenceProviderMapping returns the inference provider mapping for a
// Hugging Face model, keyed by inferenceProvider.
//
// Results are served from provider.modelProviderMappingCache when present. On a miss
// the mapping is fetched with a GET to the URL from buildModelInferenceProviderURL,
// decoded, converted with convertToInferenceProviderMappings and, if non-nil, stored
// in the cache. Failed lookups and nil mappings are not cached.
//
// The map stored in the cache is the same map that is returned, so callers share it
// and should treat it as read-only.
//
// Returns a BifrostError when the request fails, the Hub answers with a non-200
// status, or the body cannot be decoded.
func (provider *HuggingFaceProvider) getModelInferenceProviderMapping(ctx context.Context, huggingfaceModelName string) (map[inferenceProvider]HuggingFaceInferenceProviderMapping, *schemas.BifrostError) {
// Check cache first
if cached, ok := provider.modelProviderMappingCache.Load(huggingfaceModelName); ok {
// A cached value of an unexpected type is ignored and the mapping is re-fetched.
if mappings, ok := cached.(map[inferenceProvider]HuggingFaceInferenceProviderMapping); ok {
return mappings, nil
}
}
// Request/response objects are acquired from fasthttp and released when this function returns.
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
req.SetRequestURI(provider.buildModelInferenceProviderURL(huggingfaceModelName))
req.Header.SetMethod(http.MethodGet)
req.Header.SetContentType("application/json")
// NOTE(review): no Authorization header is set on this request — confirm that
// unauthenticated lookups are intended (e.g. for gated or private models).
_, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
// wait is deferred after the releases above, so it runs before req/resp are released.
// NOTE(review): presumably it blocks until the in-flight request no longer uses
// req/resp — confirm against MakeRequestWithContext, and that it is never nil.
defer wait()
if bifrostErr != nil {
return nil, bifrostErr
}
if resp.StatusCode() != fasthttp.StatusOK {
var errorResp HuggingFaceHubError
// This bifrostErr intentionally shadows the outer one (which is nil here).
bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)
if bifrostErr.Error == nil {
bifrostErr.Error = &schemas.ErrorField{}
}
// Prefer the Hub's own error message when it provided a non-blank one.
if strings.TrimSpace(errorResp.Message) != "" {
bifrostErr.Error.Message = errorResp.Message
}
return nil, bifrostErr
}
body, err := providerUtils.CheckAndDecodeBody(resp)
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
}
var mappingResp HuggingFaceInferenceProviderMappingResponse
if err := sonic.Unmarshal(body, &mappingResp); err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
}
mappings := convertToInferenceProviderMappings(&mappingResp)
// Store in cache (only when the response actually carried a mapping)
if mappings != nil {
provider.modelProviderMappingCache.Store(huggingfaceModelName, mappings)
}
return mappings, nil
}
// getValidatedProviderModelID looks up the inference provider mapping of
// huggingfaceModelName and checks that inferenceProvider serves the model for
// requiredTask.
//
// On success it returns the provider-specific model id. If the lookup itself fails,
// that error is returned unchanged. If the model has no mapping, no entry for the
// provider, an empty provider model id, or a different task, an unsupported-operation
// error for requestType is returned.
func (provider *HuggingFaceProvider) getValidatedProviderModelID(ctx context.Context, inferenceProvider inferenceProvider, huggingfaceModelName string, requiredTask string, requestType schemas.RequestType) (string, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()

	mappings, lookupErr := provider.getModelInferenceProviderMapping(ctx, huggingfaceModelName)
	if lookupErr != nil {
		return "", lookupErr
	}

	// Indexing a nil map reports "not found", so a missing mapping needs no separate check.
	entry, found := mappings[inferenceProvider]
	switch {
	case !found, entry.ProviderModelID == "", entry.ProviderTask != requiredTask:
		return "", providerUtils.NewUnsupportedOperationError(requestType, providerKey)
	}

	return entry.ProviderModelID, nil
}
// downloadAudioFromURL downloads audio data from a URL.
//
// It issues a GET for audioURL through provider.client and returns a private copy
// of the decoded response body. An error is returned when the request fails, the
// status is not 200, or the body cannot be decoded.
//
// NOTE(review): audioURL is requested as given; no scheme/host validation or size
// limit is visible here — confirm callers only pass trusted URLs.
func (provider *HuggingFaceProvider) downloadAudioFromURL(ctx context.Context, audioURL string) ([]byte, error) {
// Request/response objects are acquired from fasthttp and released when this function returns.
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
req.SetRequestURI(audioURL)
req.Header.SetMethod(http.MethodGet)
_, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
// wait is deferred after the releases above, so it runs before req/resp are released.
// NOTE(review): presumably it blocks until the request no longer uses req/resp — confirm.
defer wait()
if bifrostErr != nil {
// NOTE(review): bifrostErr is a *schemas.BifrostError formatted with %v; whether that
// prints a readable message depends on the type's formatting — confirm.
return nil, fmt.Errorf("failed to download audio: %v", bifrostErr)
}
if resp.StatusCode() != fasthttp.StatusOK {
return nil, fmt.Errorf("failed to download audio: status=%d", resp.StatusCode())
}
body, err := providerUtils.CheckAndDecodeBody(resp)
if err != nil {
return nil, fmt.Errorf("failed to read audio data: %w", err)
}
// Copy the body to avoid use-after-free: resp is released when this function
// returns, and body presumably aliases memory owned by resp.
audioCopy := append([]byte(nil), body...)
return audioCopy, nil
}
// getMimeTypeForAudioType normalises an audio type string into the MIME type to send.
//
// The input is stripped of any parameters (everything from the first ";"),
// trimmed and lower-cased. If the result starts with "audio/" it is returned as the
// MIME type, with "audio/mp3" rewritten to "audio/mpeg". Anything else, including
// an empty string or a bare format name such as "wav", falls back to "audio/mpeg".
func getMimeTypeForAudioType(audioType string) string {
	const fallback = "audio/mpeg"

	// Drop parameters if present (e.g. "audio/wav; rate=16000" -> "audio/wav") so
	// they cannot defeat the normalisation below.
	base, _, _ := strings.Cut(audioType, ";")
	t := strings.ToLower(strings.TrimSpace(base))

	if !strings.HasPrefix(t, "audio/") {
		return fallback
	}
	// Normalise known non-standard variants.
	if t == "audio/mp3" {
		return fallback
	}
	return t
}