first commit
This commit is contained in:
144
core/providers/huggingface/chat.go
Normal file
144
core/providers/huggingface/chat.go
Normal file
@@ -0,0 +1,144 @@
|
||||
// Package huggingface provides a HuggingFace chat provider.
|
||||
package huggingface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// sanitizeMessagesForHuggingFace removes unsupported ChatAssistantMessage fields
|
||||
// from chat messages. HuggingFace's OpenAI-compatible API doesn't support fields
|
||||
// like reasoning_details, reasoning, annotations, audio, and refusal.
|
||||
// Only ToolCalls is preserved from ChatAssistantMessage.
|
||||
func sanitizeMessagesForHuggingFace(messages []schemas.ChatMessage) []schemas.ChatMessage {
|
||||
sanitized := make([]schemas.ChatMessage, len(messages))
|
||||
for i, msg := range messages {
|
||||
sanitized[i] = schemas.ChatMessage{
|
||||
Name: msg.Name,
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
ChatToolMessage: msg.ChatToolMessage,
|
||||
}
|
||||
// Only preserve ToolCalls from ChatAssistantMessage
|
||||
if msg.ChatAssistantMessage != nil && len(msg.ChatAssistantMessage.ToolCalls) > 0 {
|
||||
sanitized[i].ChatAssistantMessage = &schemas.ChatAssistantMessage{
|
||||
ToolCalls: msg.ChatAssistantMessage.ToolCalls,
|
||||
}
|
||||
}
|
||||
}
|
||||
return sanitized
|
||||
}
|
||||
|
||||
func ToHuggingFaceChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) (*HuggingFaceChatRequest, error) {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Create the HuggingFace request
|
||||
// Sanitize messages to remove unsupported fields like reasoning_details
|
||||
hfReq := &HuggingFaceChatRequest{
|
||||
Messages: sanitizeMessagesForHuggingFace(bifrostReq.Input),
|
||||
Model: bifrostReq.Model,
|
||||
}
|
||||
|
||||
// Map parameters if present
|
||||
if bifrostReq.Params != nil {
|
||||
params := bifrostReq.Params
|
||||
|
||||
if params.FrequencyPenalty != nil {
|
||||
hfReq.FrequencyPenalty = params.FrequencyPenalty
|
||||
}
|
||||
if params.LogProbs != nil {
|
||||
hfReq.Logprobs = params.LogProbs
|
||||
}
|
||||
if params.MaxCompletionTokens != nil {
|
||||
hfReq.MaxTokens = params.MaxCompletionTokens
|
||||
}
|
||||
if params.PresencePenalty != nil {
|
||||
hfReq.PresencePenalty = params.PresencePenalty
|
||||
}
|
||||
if params.Seed != nil {
|
||||
hfReq.Seed = params.Seed
|
||||
}
|
||||
if len(params.Stop) > 0 {
|
||||
hfReq.Stop = params.Stop
|
||||
}
|
||||
if params.Temperature != nil {
|
||||
hfReq.Temperature = params.Temperature
|
||||
}
|
||||
if params.TopLogProbs != nil {
|
||||
hfReq.TopLogprobs = params.TopLogProbs
|
||||
}
|
||||
if params.TopP != nil {
|
||||
hfReq.TopP = params.TopP
|
||||
}
|
||||
|
||||
// Handle response format (direct type assertion to avoid marshal→unmarshal round-trip)
|
||||
if params.ResponseFormat != nil {
|
||||
var hfRF *HuggingFaceResponseFormat
|
||||
if rfMap, ok := (*params.ResponseFormat).(map[string]interface{}); ok {
|
||||
hfRF = &HuggingFaceResponseFormat{}
|
||||
if t, ok := rfMap["type"].(string); ok {
|
||||
hfRF.Type = t
|
||||
}
|
||||
if jsVal, ok := rfMap["json_schema"]; ok {
|
||||
jsBytes, err := providerUtils.MarshalSorted(jsVal)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal json_schema: %w", err)
|
||||
}
|
||||
var hfSchema HuggingFaceJSONSchema
|
||||
if err := sonic.Unmarshal(jsBytes, &hfSchema); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal json_schema: %w", err)
|
||||
}
|
||||
hfRF.JSONSchema = &hfSchema
|
||||
}
|
||||
} else if converted, err := schemas.ConvertViaJSON[HuggingFaceResponseFormat](*params.ResponseFormat); err == nil {
|
||||
hfRF = &converted
|
||||
}
|
||||
hfReq.ResponseFormat = hfRF
|
||||
}
|
||||
|
||||
// Handle stream options
|
||||
if params.StreamOptions != nil {
|
||||
hfReq.StreamOptions = &schemas.ChatStreamOptions{
|
||||
IncludeUsage: params.StreamOptions.IncludeUsage,
|
||||
}
|
||||
}
|
||||
|
||||
hfReq.Tools = params.Tools
|
||||
|
||||
// Handle tool choice
|
||||
if params.ToolChoice != nil {
|
||||
hfToolChoice := &HuggingFaceToolChoice{}
|
||||
if params.ToolChoice.ChatToolChoiceStr != nil {
|
||||
switch *params.ToolChoice.ChatToolChoiceStr {
|
||||
case "auto":
|
||||
auto := EnumStringTypeAuto
|
||||
hfToolChoice.EnumValue = &auto
|
||||
case "none":
|
||||
none := EnumStringTypeNone
|
||||
hfToolChoice.EnumValue = &none
|
||||
case "required":
|
||||
required := EnumStringTypeRequired
|
||||
hfToolChoice.EnumValue = &required
|
||||
}
|
||||
} else if params.ToolChoice.ChatToolChoiceStruct != nil {
|
||||
if params.ToolChoice.ChatToolChoiceStruct.Type == schemas.ChatToolChoiceTypeFunction && params.ToolChoice.ChatToolChoiceStruct.Function != nil {
|
||||
hfToolChoice.Function = &schemas.ChatToolChoiceFunction{
|
||||
Name: params.ToolChoice.ChatToolChoiceStruct.Function.Name,
|
||||
}
|
||||
}
|
||||
}
|
||||
if hfToolChoice.EnumValue != nil || hfToolChoice.Function != nil {
|
||||
hfReq.ToolChoice = hfToolChoice
|
||||
}
|
||||
}
|
||||
hfReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
|
||||
return hfReq, nil
|
||||
}
|
||||
132
core/providers/huggingface/chat_test.go
Normal file
132
core/providers/huggingface/chat_test.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package huggingface
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestToHuggingFaceChatCompletionRequest_ResponseFormat(t *testing.T) {
|
||||
makeReq := func(rf *interface{}) *schemas.BifrostChatRequest {
|
||||
return &schemas.BifrostChatRequest{
|
||||
Model: "test-model",
|
||||
Input: []schemas.ChatMessage{{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hello")}}},
|
||||
Params: &schemas.ChatParameters{
|
||||
ResponseFormat: rf,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
responseFormat *interface{}
|
||||
wantErr bool
|
||||
validate func(t *testing.T, result *HuggingFaceChatRequest)
|
||||
}{
|
||||
{
|
||||
name: "nil_response_format",
|
||||
responseFormat: nil,
|
||||
validate: func(t *testing.T, result *HuggingFaceChatRequest) {
|
||||
assert.Nil(t, result.ResponseFormat)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "map_type_only",
|
||||
responseFormat: func() *interface{} {
|
||||
var rf interface{} = map[string]interface{}{"type": "json_object"}
|
||||
return &rf
|
||||
}(),
|
||||
validate: func(t *testing.T, result *HuggingFaceChatRequest) {
|
||||
require.NotNil(t, result.ResponseFormat)
|
||||
assert.Equal(t, "json_object", result.ResponseFormat.Type)
|
||||
assert.Nil(t, result.ResponseFormat.JSONSchema)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "map_with_json_schema",
|
||||
responseFormat: func() *interface{} {
|
||||
var rf interface{} = map[string]interface{}{
|
||||
"type": "json_schema",
|
||||
"json_schema": map[string]interface{}{
|
||||
"name": "my_schema",
|
||||
"description": "A test schema",
|
||||
"strict": true,
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"answer": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return &rf
|
||||
}(),
|
||||
validate: func(t *testing.T, result *HuggingFaceChatRequest) {
|
||||
require.NotNil(t, result.ResponseFormat)
|
||||
assert.Equal(t, "json_schema", result.ResponseFormat.Type)
|
||||
require.NotNil(t, result.ResponseFormat.JSONSchema)
|
||||
assert.Equal(t, "my_schema", result.ResponseFormat.JSONSchema.Name)
|
||||
assert.Equal(t, "A test schema", result.ResponseFormat.JSONSchema.Description)
|
||||
require.NotNil(t, result.ResponseFormat.JSONSchema.Strict)
|
||||
assert.True(t, *result.ResponseFormat.JSONSchema.Strict)
|
||||
require.NotNil(t, result.ResponseFormat.JSONSchema.Schema)
|
||||
// Verify schema content round-tripped correctly
|
||||
var schemaMap map[string]interface{}
|
||||
err := json.Unmarshal(result.ResponseFormat.JSONSchema.Schema, &schemaMap)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "object", schemaMap["type"])
|
||||
props, ok := schemaMap["properties"].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Contains(t, props, "answer")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "struct_fallback_via_convert",
|
||||
responseFormat: func() *interface{} {
|
||||
var rf interface{} = HuggingFaceResponseFormat{
|
||||
Type: "json_schema",
|
||||
JSONSchema: &HuggingFaceJSONSchema{
|
||||
Name: "fallback_schema",
|
||||
Strict: schemas.Ptr(true),
|
||||
},
|
||||
}
|
||||
return &rf
|
||||
}(),
|
||||
validate: func(t *testing.T, result *HuggingFaceChatRequest) {
|
||||
require.NotNil(t, result.ResponseFormat, "ResponseFormat should not be nil — ConvertViaJSON fallback must handle struct values")
|
||||
assert.Equal(t, "json_schema", result.ResponseFormat.Type)
|
||||
require.NotNil(t, result.ResponseFormat.JSONSchema)
|
||||
assert.Equal(t, "fallback_schema", result.ResponseFormat.JSONSchema.Name)
|
||||
require.NotNil(t, result.ResponseFormat.JSONSchema.Strict)
|
||||
assert.True(t, *result.ResponseFormat.JSONSchema.Strict)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "inconvertible_value_graceful_nil",
|
||||
responseFormat: func() *interface{} {
|
||||
var rf interface{} = 42
|
||||
return &rf
|
||||
}(),
|
||||
validate: func(t *testing.T, result *HuggingFaceChatRequest) {
|
||||
assert.Nil(t, result.ResponseFormat, "inconvertible value should gracefully result in nil ResponseFormat")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := makeReq(tt.responseFormat)
|
||||
result, err := ToHuggingFaceChatCompletionRequest(req)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
tt.validate(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
165
core/providers/huggingface/embedding.go
Normal file
165
core/providers/huggingface/embedding.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package huggingface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToHuggingFaceEmbeddingRequest converts a Bifrost embedding request to HuggingFace format
|
||||
func ToHuggingFaceEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*HuggingFaceEmbeddingRequest, error) {
|
||||
if bifrostReq == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
inferenceProvider, modelName, nameErr := splitIntoModelProvider(bifrostReq.Model)
|
||||
if nameErr != nil {
|
||||
return nil, nameErr
|
||||
}
|
||||
|
||||
var hfReq *HuggingFaceEmbeddingRequest
|
||||
if inferenceProvider != hfInference {
|
||||
hfReq = &HuggingFaceEmbeddingRequest{
|
||||
Model: schemas.Ptr(modelName),
|
||||
Provider: schemas.Ptr(string(inferenceProvider)),
|
||||
}
|
||||
} else {
|
||||
hfReq = &HuggingFaceEmbeddingRequest{}
|
||||
}
|
||||
|
||||
// Convert input
|
||||
if bifrostReq.Input != nil {
|
||||
var input InputsCustomType
|
||||
if bifrostReq.Input.Text != nil {
|
||||
input = InputsCustomType{Text: bifrostReq.Input.Text}
|
||||
|
||||
} else if bifrostReq.Input.Texts != nil {
|
||||
input = InputsCustomType{Texts: bifrostReq.Input.Texts}
|
||||
}
|
||||
if inferenceProvider == hfInference {
|
||||
hfReq.Inputs = &input
|
||||
} else {
|
||||
hfReq.Input = &input
|
||||
}
|
||||
}
|
||||
|
||||
// Map parameters
|
||||
if bifrostReq.Params != nil {
|
||||
params := bifrostReq.Params
|
||||
|
||||
// Map standard parameters
|
||||
if params.EncodingFormat != nil {
|
||||
encodingType := EncodingType(*params.EncodingFormat)
|
||||
hfReq.EncodingFormat = &encodingType
|
||||
}
|
||||
if params.Dimensions != nil {
|
||||
hfReq.Dimensions = params.Dimensions
|
||||
}
|
||||
|
||||
// Check for HuggingFace-specific parameters in ExtraParams
|
||||
if params.ExtraParams != nil {
|
||||
if normalize, ok := params.ExtraParams["normalize"].(bool); ok {
|
||||
delete(params.ExtraParams, "normalize")
|
||||
hfReq.Normalize = &normalize
|
||||
}
|
||||
if promptName, ok := params.ExtraParams["prompt_name"].(string); ok {
|
||||
delete(params.ExtraParams, "prompt_name")
|
||||
hfReq.PromptName = &promptName
|
||||
}
|
||||
if truncate, ok := params.ExtraParams["truncate"].(bool); ok {
|
||||
delete(params.ExtraParams, "truncate")
|
||||
hfReq.Truncate = &truncate
|
||||
}
|
||||
if truncationDirection, ok := params.ExtraParams["truncation_direction"].(string); ok {
|
||||
delete(params.ExtraParams, "truncation_direction")
|
||||
hfReq.TruncationDirection = &truncationDirection
|
||||
}
|
||||
}
|
||||
hfReq.ExtraParams = params.ExtraParams
|
||||
}
|
||||
|
||||
return hfReq, nil
|
||||
}
|
||||
|
||||
// UnmarshalHuggingFaceEmbeddingResponse unmarshals HuggingFace API response directly into BifrostEmbeddingResponse
|
||||
// Handles multiple formats: standard object, 2D array, or 1D array
|
||||
func UnmarshalHuggingFaceEmbeddingResponse(data []byte, model string) (*schemas.BifrostEmbeddingResponse, error) {
|
||||
if data == nil {
|
||||
return nil, fmt.Errorf("response data is nil")
|
||||
}
|
||||
|
||||
// Try standard object format first
|
||||
type tempResponse struct {
|
||||
Data []schemas.EmbeddingData `json:"data,omitempty"`
|
||||
Model *string `json:"model,omitempty"`
|
||||
Usage *schemas.BifrostLLMUsage `json:"usage,omitempty"`
|
||||
}
|
||||
var obj tempResponse
|
||||
if err := sonic.Unmarshal(data, &obj); err == nil {
|
||||
if obj.Data != nil || obj.Model != nil || obj.Usage != nil {
|
||||
bifrostResponse := &schemas.BifrostEmbeddingResponse{
|
||||
Data: obj.Data,
|
||||
Model: model,
|
||||
Object: "list",
|
||||
}
|
||||
if obj.Model != nil {
|
||||
bifrostResponse.Model = *obj.Model
|
||||
}
|
||||
if obj.Usage != nil {
|
||||
bifrostResponse.Usage = obj.Usage
|
||||
} else {
|
||||
bifrostResponse.Usage = &schemas.BifrostLLMUsage{
|
||||
PromptTokens: 0,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: 0,
|
||||
}
|
||||
}
|
||||
return bifrostResponse, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Try 2D array: [[num, ...], ...]
|
||||
var arr2D [][]float64
|
||||
if err := sonic.Unmarshal(data, &arr2D); err == nil {
|
||||
embeddings := make([]schemas.EmbeddingData, len(arr2D))
|
||||
for idx, embedding := range arr2D {
|
||||
embeddings[idx] = schemas.EmbeddingData{
|
||||
Embedding: schemas.EmbeddingStruct{EmbeddingArray: append([]float64(nil), embedding...)},
|
||||
Index: idx,
|
||||
Object: "embedding",
|
||||
}
|
||||
}
|
||||
return &schemas.BifrostEmbeddingResponse{
|
||||
Data: embeddings,
|
||||
Model: model,
|
||||
Object: "list",
|
||||
Usage: &schemas.BifrostLLMUsage{
|
||||
PromptTokens: 0,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: 0,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Try 1D array: [num, ...]
|
||||
var arr1D []float64
|
||||
if err := sonic.Unmarshal(data, &arr1D); err == nil {
|
||||
return &schemas.BifrostEmbeddingResponse{
|
||||
Data: []schemas.EmbeddingData{{
|
||||
Embedding: schemas.EmbeddingStruct{EmbeddingArray: append([]float64(nil), arr1D...)},
|
||||
Index: 0,
|
||||
Object: "embedding",
|
||||
}},
|
||||
Model: model,
|
||||
Object: "list",
|
||||
Usage: &schemas.BifrostLLMUsage{
|
||||
PromptTokens: 0,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: 0,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to unmarshal HuggingFace embedding response: unexpected structure")
|
||||
}
|
||||
57
core/providers/huggingface/errors.go
Normal file
57
core/providers/huggingface/errors.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package huggingface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// parseHuggingFaceImageError parses HuggingFace error responses
|
||||
func parseHuggingFaceImageError(resp *fasthttp.Response) *schemas.BifrostError {
|
||||
var errorResp HuggingFaceResponseError
|
||||
bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)
|
||||
|
||||
if strings.TrimSpace(errorResp.Type) != "" {
|
||||
typeCopy := errorResp.Type
|
||||
bifrostErr.Type = &typeCopy
|
||||
}
|
||||
|
||||
if bifrostErr.Error == nil {
|
||||
bifrostErr.Error = &schemas.ErrorField{}
|
||||
}
|
||||
|
||||
// Handle FastAPI validation errors
|
||||
if len(errorResp.Detail) > 0 {
|
||||
var errorMessages []string
|
||||
for _, detail := range errorResp.Detail {
|
||||
msg := detail.Msg
|
||||
if len(detail.Loc) > 0 {
|
||||
// Build location string from loc array
|
||||
var locParts []string
|
||||
for _, locPart := range detail.Loc {
|
||||
if locStr, ok := locPart.(string); ok {
|
||||
locParts = append(locParts, locStr)
|
||||
} else if locNum, ok := locPart.(float64); ok {
|
||||
locParts = append(locParts, fmt.Sprintf("%.0f", locNum))
|
||||
}
|
||||
}
|
||||
if len(locParts) > 0 {
|
||||
msg = fmt.Sprintf("%s at %s", msg, strings.Join(locParts, "."))
|
||||
}
|
||||
}
|
||||
errorMessages = append(errorMessages, msg)
|
||||
}
|
||||
if len(errorMessages) > 0 {
|
||||
bifrostErr.Error.Message = strings.Join(errorMessages, "; ")
|
||||
}
|
||||
} else if strings.TrimSpace(errorResp.Message) != "" {
|
||||
bifrostErr.Error.Message = errorResp.Message
|
||||
} else if strings.TrimSpace(errorResp.Error) != "" {
|
||||
bifrostErr.Error.Message = errorResp.Error
|
||||
}
|
||||
|
||||
return bifrostErr
|
||||
}
|
||||
1801
core/providers/huggingface/huggingface.go
Normal file
1801
core/providers/huggingface/huggingface.go
Normal file
File diff suppressed because it is too large
Load Diff
132
core/providers/huggingface/huggingface_test.go
Normal file
132
core/providers/huggingface/huggingface_test.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package huggingface_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/internal/llmtests"
|
||||
"github.com/maximhq/bifrost/core/providers/huggingface"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestHuggingface(t *testing.T) {
|
||||
t.Parallel()
|
||||
if os.Getenv("HUGGING_FACE_API_KEY") == "" {
|
||||
t.Skip("Skipping HuggingFace tests because HUGGING_FACE_API_KEY is not set")
|
||||
}
|
||||
|
||||
client, ctx, cancel, err := llmtests.SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
testConfig := llmtests.ComprehensiveTestConfig{
|
||||
Provider: schemas.HuggingFace,
|
||||
ChatModel: "groq/meta-llama/Llama-3.3-70B-Instruct",
|
||||
VisionModel: "cohere/CohereLabs/aya-vision-32b",
|
||||
EmbeddingModel: "sambanova/intfloat/e5-mistral-7b-instruct",
|
||||
TranscriptionModel: "fal-ai/openai/whisper-large-v3",
|
||||
SpeechSynthesisModel: "fal-ai/hexgrad/Kokoro-82M",
|
||||
SpeechSynthesisFallbacks: []schemas.Fallback{
|
||||
{Provider: schemas.HuggingFace, Model: "fal-ai/ResembleAI/chatterbox"},
|
||||
},
|
||||
ReasoningModel: "groq/openai/gpt-oss-120b",
|
||||
ImageGenerationModel: "fal-ai/fal-ai/flux/dev",
|
||||
ImageEditModel: "fal-ai/fal-ai/flux-2/edit",
|
||||
Scenarios: llmtests.TestScenarios{
|
||||
TextCompletion: false,
|
||||
TextCompletionStream: false,
|
||||
SimpleChat: true,
|
||||
CompletionStream: true,
|
||||
MultiTurnConversation: true,
|
||||
ToolCalls: true,
|
||||
ToolCallsStreaming: true,
|
||||
MultipleToolCalls: false,
|
||||
End2EndToolCalling: true,
|
||||
AutomaticFunctionCall: true,
|
||||
ImageURL: true,
|
||||
ImageBase64: true,
|
||||
MultipleImages: true,
|
||||
CompleteEnd2End: true,
|
||||
Embedding: false,
|
||||
Transcription: true,
|
||||
TranscriptionStream: false,
|
||||
SpeechSynthesis: true,
|
||||
SpeechSynthesisStream: false,
|
||||
Reasoning: true,
|
||||
ListModels: true,
|
||||
BatchCreate: false,
|
||||
BatchList: false,
|
||||
BatchRetrieve: false,
|
||||
BatchCancel: false,
|
||||
BatchResults: false,
|
||||
FileUpload: false,
|
||||
FileList: false,
|
||||
FileRetrieve: false,
|
||||
FileDelete: false,
|
||||
FileContent: false,
|
||||
FileBatchInput: false,
|
||||
ImageGeneration: true,
|
||||
ImageGenerationStream: true,
|
||||
ImageEdit: true,
|
||||
ImageEditStream: true,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("HuggingFaceTests", func(t *testing.T) {
|
||||
llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalHuggingFaceEmbeddingResponsePreservesPrecision(t *testing.T) {
|
||||
const want = 0.12345678901234568
|
||||
|
||||
resp, err := huggingface.UnmarshalHuggingFaceEmbeddingResponse([]byte(`[[0.12345678901234568]]`), "test-model")
|
||||
if err != nil {
|
||||
t.Fatalf("UnmarshalHuggingFaceEmbeddingResponse failed: %v", err)
|
||||
}
|
||||
|
||||
if resp == nil || len(resp.Data) != 1 {
|
||||
t.Fatalf("expected single embedding response, got %#v", resp)
|
||||
}
|
||||
if len(resp.Data[0].Embedding.EmbeddingArray) != 1 {
|
||||
t.Fatalf("expected single embedding value, got %#v", resp.Data[0].Embedding.EmbeddingArray)
|
||||
}
|
||||
|
||||
got := resp.Data[0].Embedding.EmbeddingArray[0]
|
||||
if got != want {
|
||||
t.Fatalf("expected %0.18f, got %0.18f", want, got)
|
||||
}
|
||||
|
||||
if got == float64(float32(want)) {
|
||||
t.Fatalf("expected preserved precision, got float32-rounded value %0.18f", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalHuggingFaceEmbeddingResponse1DPreservesPrecision(t *testing.T) {
|
||||
const want = 0.12345678901234568
|
||||
|
||||
resp, err := huggingface.UnmarshalHuggingFaceEmbeddingResponse([]byte(`[0.12345678901234568]`), "test-model")
|
||||
if err != nil {
|
||||
t.Fatalf("UnmarshalHuggingFaceEmbeddingResponse failed: %v", err)
|
||||
}
|
||||
|
||||
if resp == nil || len(resp.Data) != 1 {
|
||||
t.Fatalf("expected single embedding response, got %#v", resp)
|
||||
}
|
||||
if len(resp.Data[0].Embedding.EmbeddingArray) != 1 {
|
||||
t.Fatalf("expected single embedding value, got %#v", resp.Data[0].Embedding.EmbeddingArray)
|
||||
}
|
||||
|
||||
got := resp.Data[0].Embedding.EmbeddingArray[0]
|
||||
if got != want {
|
||||
t.Fatalf("expected %0.18f, got %0.18f", want, got)
|
||||
}
|
||||
|
||||
if got == float64(float32(want)) {
|
||||
t.Fatalf("expected preserved precision, got float32-rounded value %0.18f", got)
|
||||
}
|
||||
}
|
||||
603
core/providers/huggingface/images.go
Normal file
603
core/providers/huggingface/images.go
Normal file
@@ -0,0 +1,603 @@
|
||||
package huggingface
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
nebiusProvider "github.com/maximhq/bifrost/core/providers/nebius"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// Models that support multiple images (image_urls)
|
||||
var falAIMultiImageEditModels = map[string]bool{
|
||||
"fal-ai/flux-2/edit": true,
|
||||
"fal-ai/flux-2-pro/edit": true,
|
||||
}
|
||||
|
||||
// Models that only support single image (image_url)
|
||||
var falAISingleImageEditModels = map[string]bool{
|
||||
"fal-ai/flux-pro/kontext": true,
|
||||
"fal-ai/flux/dev/image-to-image": true,
|
||||
}
|
||||
|
||||
// ToHuggingFaceImageGenerationRequest converts a Bifrost image generation request to provider-specific format
|
||||
func ToHuggingFaceImageGenerationRequest(bifrostReq *schemas.BifrostImageGenerationRequest) (providerUtils.RequestBodyWithExtraParams, error) {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil {
|
||||
return nil, fmt.Errorf("bifrost request is nil or input is nil")
|
||||
}
|
||||
|
||||
inferenceProvider, model, nameErr := splitIntoModelProvider(bifrostReq.Model)
|
||||
if nameErr != nil {
|
||||
return nil, nameErr
|
||||
}
|
||||
|
||||
switch inferenceProvider {
|
||||
case nebius:
|
||||
req := &nebiusProvider.NebiusImageGenerationRequest{
|
||||
Model: &model,
|
||||
Prompt: &bifrostReq.Input.Prompt,
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
if bifrostReq.Params.ResponseFormat != nil {
|
||||
req.ResponseFormat = bifrostReq.Params.ResponseFormat
|
||||
}
|
||||
|
||||
if bifrostReq.Params.Size != nil && strings.ToLower(*bifrostReq.Params.Size) != "auto" {
|
||||
size := strings.Split(strings.ToLower(*bifrostReq.Params.Size), "x")
|
||||
if len(size) != 2 {
|
||||
return nil, fmt.Errorf("invalid size format: expected 'WIDTHxHEIGHT', got %q", *bifrostReq.Params.Size)
|
||||
}
|
||||
|
||||
width, err := strconv.Atoi(size[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid width in size %q: %w", *bifrostReq.Params.Size, err)
|
||||
}
|
||||
|
||||
height, err := strconv.Atoi(size[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid height in size %q: %w", *bifrostReq.Params.Size, err)
|
||||
}
|
||||
|
||||
req.Width = &width
|
||||
req.Height = &height
|
||||
}
|
||||
if bifrostReq.Params.OutputFormat != nil {
|
||||
req.ResponseExtension = bifrostReq.Params.OutputFormat
|
||||
}
|
||||
|
||||
// Handle nebius inconsistency - normalize ResponseExtension case-insensitively
|
||||
if req.ResponseExtension != nil && strings.ToLower(*req.ResponseExtension) == "jpeg" {
|
||||
req.ResponseExtension = schemas.Ptr("jpg")
|
||||
}
|
||||
|
||||
// Map seed from direct field
|
||||
if bifrostReq.Params.Seed != nil {
|
||||
req.Seed = bifrostReq.Params.Seed
|
||||
}
|
||||
|
||||
// Map negative_prompt from direct field
|
||||
if bifrostReq.Params.NegativePrompt != nil {
|
||||
req.NegativePrompt = bifrostReq.Params.NegativePrompt
|
||||
}
|
||||
|
||||
// Handle extra params for nebius
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
req.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
// Map num_inference_steps
|
||||
if v, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["num_inference_steps"]); ok {
|
||||
delete(req.ExtraParams, "num_inference_steps")
|
||||
req.NumInferenceSteps = v
|
||||
}
|
||||
|
||||
// Map guidance_scale
|
||||
if v, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["guidance_scale"]); ok {
|
||||
delete(req.ExtraParams, "guidance_scale")
|
||||
req.GuidanceScale = v
|
||||
}
|
||||
|
||||
// Map loras
|
||||
if lorasValue, exists := bifrostReq.Params.ExtraParams["loras"]; exists && lorasValue != nil {
|
||||
delete(req.ExtraParams, "loras")
|
||||
if lorasArray, ok := lorasValue.([]interface{}); ok {
|
||||
for _, item := range lorasArray {
|
||||
if loraMap, ok := item.(map[string]interface{}); ok {
|
||||
if url, ok := schemas.SafeExtractString(loraMap["url"]); ok {
|
||||
if scale, ok := schemas.SafeExtractInt(loraMap["scale"]); ok {
|
||||
req.Loras = append(req.Loras, nebiusProvider.NebiusLora{URL: url, Scale: scale})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return req, nil
|
||||
|
||||
case hfInference:
|
||||
req := &HuggingFaceHFInferenceImageGenerationRequest{
|
||||
Inputs: bifrostReq.Input.Prompt,
|
||||
}
|
||||
if bifrostReq.Params != nil {
|
||||
req.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
return req, nil
|
||||
|
||||
case falAI:
|
||||
req := &HuggingFaceFalAIImageGenerationRequest{
|
||||
Prompt: bifrostReq.Input.Prompt,
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
// Map n to num_images for fal-ai
|
||||
if bifrostReq.Params.N != nil {
|
||||
req.NumImages = bifrostReq.Params.N
|
||||
}
|
||||
|
||||
// Pass through response_format
|
||||
if bifrostReq.Params.ResponseFormat != nil {
|
||||
req.ResponseFormat = bifrostReq.Params.ResponseFormat
|
||||
}
|
||||
|
||||
// Pass through output_format
|
||||
if bifrostReq.Params.OutputFormat != nil {
|
||||
if strings.ToLower(*bifrostReq.Params.OutputFormat) == "jpg" {
|
||||
req.OutputFormat = schemas.Ptr("jpeg")
|
||||
} else {
|
||||
req.OutputFormat = bifrostReq.Params.OutputFormat
|
||||
}
|
||||
}
|
||||
|
||||
// Convert size from "WxH" format to fal-ai's image_size object
|
||||
if bifrostReq.Params.Size != nil && strings.ToLower(*bifrostReq.Params.Size) != "auto" {
|
||||
size := strings.Split(*bifrostReq.Params.Size, "x")
|
||||
if len(size) == 2 {
|
||||
width, err := strconv.Atoi(size[0])
|
||||
if err == nil {
|
||||
height, err := strconv.Atoi(size[1])
|
||||
if err == nil {
|
||||
req.ImageSize = &HuggingFaceFalAISize{
|
||||
Width: width,
|
||||
Height: height,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if bifrostReq.Params.ResponseFormat != nil && *bifrostReq.Params.ResponseFormat == "b64_json" {
|
||||
req.SyncMode = schemas.Ptr(true)
|
||||
}
|
||||
|
||||
if bifrostReq.Params.Moderation != nil && *bifrostReq.Params.Moderation == "low" {
|
||||
req.EnableSafetyChecker = schemas.Ptr(false)
|
||||
}
|
||||
|
||||
// Map seed from direct field
|
||||
if bifrostReq.Params.Seed != nil {
|
||||
req.Seed = bifrostReq.Params.Seed
|
||||
}
|
||||
|
||||
// Map negative_prompt from direct field
|
||||
if bifrostReq.Params.NegativePrompt != nil {
|
||||
req.NegativePrompt = bifrostReq.Params.NegativePrompt
|
||||
}
|
||||
|
||||
// Map num_inference_steps from direct field
|
||||
if bifrostReq.Params.NumInferenceSteps != nil {
|
||||
req.NumInferenceSteps = bifrostReq.Params.NumInferenceSteps
|
||||
}
|
||||
|
||||
// Parse fal-ai specific params from ExtraParams
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
req.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
// Map guidance_scale
|
||||
if v, ok := schemas.SafeExtractFloat64Pointer(bifrostReq.Params.ExtraParams["guidance_scale"]); ok {
|
||||
delete(req.ExtraParams, "guidance_scale")
|
||||
req.GuidanceScale = v
|
||||
}
|
||||
|
||||
// Map acceleration
|
||||
if v, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["acceleration"]); ok {
|
||||
delete(req.ExtraParams, "acceleration")
|
||||
req.Acceleration = v
|
||||
}
|
||||
|
||||
// Map enable_prompt_expansion
|
||||
if v, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["enable_prompt_expansion"]); ok {
|
||||
delete(req.ExtraParams, "enable_prompt_expansion")
|
||||
req.EnablePromptExpansion = v
|
||||
}
|
||||
|
||||
// Map enable_safety_checker
|
||||
if v, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["enable_safety_checker"]); ok {
|
||||
delete(req.ExtraParams, "enable_safety_checker")
|
||||
req.EnableSafetyChecker = v
|
||||
}
|
||||
}
|
||||
}
|
||||
return req, nil
|
||||
|
||||
case together:
|
||||
req := &HuggingFaceTogetherImageGenerationRequest{
|
||||
Prompt: bifrostReq.Input.Prompt,
|
||||
Model: model,
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
req.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
if bifrostReq.Params.ResponseFormat != nil {
|
||||
req.ResponseFormat = bifrostReq.Params.ResponseFormat
|
||||
}
|
||||
|
||||
if bifrostReq.Params.Size != nil {
|
||||
req.Size = bifrostReq.Params.Size
|
||||
}
|
||||
|
||||
if bifrostReq.Params.N != nil {
|
||||
req.N = bifrostReq.Params.N
|
||||
}
|
||||
if bifrostReq.Params.ResponseFormat != nil && *bifrostReq.Params.ResponseFormat == "b64_json" {
|
||||
req.ResponseFormat = schemas.Ptr("base64")
|
||||
}
|
||||
if bifrostReq.Params.NumInferenceSteps != nil {
|
||||
req.Steps = bifrostReq.Params.NumInferenceSteps
|
||||
}
|
||||
}
|
||||
return req, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported inference provider for image generation: %s", inferenceProvider)
|
||||
}
|
||||
}
|
||||
|
||||
// ToHuggingFaceImageStreamRequest converts a Bifrost image generation request to fal-ai streaming format
|
||||
func ToHuggingFaceImageStreamRequest(bifrostReq *schemas.BifrostImageGenerationRequest) (*HuggingFaceFalAIImageStreamRequest, error) {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil {
|
||||
return nil, fmt.Errorf("bifrost request is nil or input is nil")
|
||||
}
|
||||
|
||||
req := &HuggingFaceFalAIImageStreamRequest{
|
||||
Prompt: bifrostReq.Input.Prompt,
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
req.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
// Map n to num_images for fal-ai
|
||||
if bifrostReq.Params.N != nil {
|
||||
req.NumImages = bifrostReq.Params.N
|
||||
}
|
||||
|
||||
// Pass through response_format
|
||||
if bifrostReq.Params.ResponseFormat != nil {
|
||||
req.ResponseFormat = bifrostReq.Params.ResponseFormat
|
||||
}
|
||||
|
||||
// Pass through output_format
|
||||
// Convert "jpg" to "jpeg" for fal-ai (fal-ai only accepts "jpeg", "png", "webp")
|
||||
if bifrostReq.Params.OutputFormat != nil {
|
||||
if strings.ToLower(*bifrostReq.Params.OutputFormat) == "jpg" {
|
||||
req.OutputFormat = schemas.Ptr("jpeg")
|
||||
} else {
|
||||
req.OutputFormat = bifrostReq.Params.OutputFormat
|
||||
}
|
||||
}
|
||||
|
||||
// Convert size from "WxH" format to fal-ai's image_size object
|
||||
if bifrostReq.Params.Size != nil && strings.ToLower(*bifrostReq.Params.Size) != "auto" {
|
||||
size := strings.Split(*bifrostReq.Params.Size, "x")
|
||||
if len(size) == 2 {
|
||||
width, err := strconv.Atoi(size[0])
|
||||
if err == nil {
|
||||
height, err := strconv.Atoi(size[1])
|
||||
if err == nil {
|
||||
req.ImageSize = &HuggingFaceFalAISize{
|
||||
Width: width,
|
||||
Height: height,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if bifrostReq.Params.Seed != nil {
|
||||
req.Seed = bifrostReq.Params.Seed
|
||||
}
|
||||
if bifrostReq.Params.NumInferenceSteps != nil {
|
||||
req.NumInferenceSteps = bifrostReq.Params.NumInferenceSteps
|
||||
}
|
||||
if bifrostReq.Params.ResponseFormat != nil && *bifrostReq.Params.ResponseFormat == "b64_json" {
|
||||
req.SyncMode = schemas.Ptr(true)
|
||||
}
|
||||
if bifrostReq.Params.Moderation != nil && *bifrostReq.Params.Moderation == "low" {
|
||||
req.EnableSafetyChecker = schemas.Ptr(false)
|
||||
}
|
||||
|
||||
// Parse fal-ai specific params from ExtraParams
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
if v, ok := schemas.SafeExtractFloat64Pointer(bifrostReq.Params.ExtraParams["guidance_scale"]); ok {
|
||||
delete(req.ExtraParams, "guidance_scale")
|
||||
req.GuidanceScale = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["acceleration"]); ok {
|
||||
delete(req.ExtraParams, "acceleration")
|
||||
req.Acceleration = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["enable_prompt_expansion"]); ok {
|
||||
delete(req.ExtraParams, "enable_prompt_expansion")
|
||||
req.EnablePromptExpansion = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["enable_safety_checker"]); ok {
|
||||
delete(req.ExtraParams, "enable_safety_checker")
|
||||
req.EnableSafetyChecker = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// UnmarshalHuggingFaceImageGenerationResponse unmarshals HuggingFace image generation response to Bifrost format
|
||||
func UnmarshalHuggingFaceImageGenerationResponse(data []byte, model string) (*schemas.BifrostImageGenerationResponse, error) {
|
||||
if data == nil {
|
||||
return nil, fmt.Errorf("response data is nil")
|
||||
}
|
||||
inferenceProvider, _, err := splitIntoModelProvider(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch inferenceProvider {
|
||||
case nebius:
|
||||
// Unmarshal into Nebius response format
|
||||
var nebiusResponse nebiusProvider.NebiusImageGenerationResponse
|
||||
if err := sonic.Unmarshal(data, &nebiusResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal Nebius response: %w", err)
|
||||
}
|
||||
|
||||
// Convert to Bifrost format using Nebius converter
|
||||
bifrostResponse := nebiusProvider.ToBifrostImageResponse(&nebiusResponse)
|
||||
if bifrostResponse == nil {
|
||||
return nil, fmt.Errorf("failed to convert Nebius response to Bifrost format")
|
||||
}
|
||||
|
||||
// Set model field (Nebius converter doesn't set it, similar to embeddings pattern)
|
||||
if bifrostResponse.Model == "" {
|
||||
bifrostResponse.Model = model
|
||||
}
|
||||
|
||||
return bifrostResponse, nil
|
||||
|
||||
case hfInference:
|
||||
// Handle raw byte data - encode to base64
|
||||
b64Data := base64.StdEncoding.EncodeToString(data)
|
||||
return &schemas.BifrostImageGenerationResponse{
|
||||
Model: model,
|
||||
Data: []schemas.ImageData{
|
||||
{
|
||||
B64JSON: b64Data,
|
||||
Index: 0,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
|
||||
case falAI:
|
||||
// Handle fal-ai JSON response
|
||||
var falResponse HuggingFaceFalAIImageGenerationResponse
|
||||
if err := sonic.Unmarshal(data, &falResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal fal-ai response: %w", err)
|
||||
}
|
||||
|
||||
imageData := make([]schemas.ImageData, len(falResponse.Images))
|
||||
for i, img := range falResponse.Images {
|
||||
// Handle both URL and base64 responses
|
||||
imageData[i] = schemas.ImageData{
|
||||
URL: img.URL,
|
||||
B64JSON: img.B64JSON,
|
||||
Index: i,
|
||||
}
|
||||
}
|
||||
|
||||
return &schemas.BifrostImageGenerationResponse{
|
||||
Model: model,
|
||||
Data: imageData,
|
||||
}, nil
|
||||
|
||||
case together:
|
||||
// Handle together JSON response
|
||||
var togetherResponse HuggingFaceTogetherImageGenerationResponse
|
||||
if err := sonic.Unmarshal(data, &togetherResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal together response: %w", err)
|
||||
}
|
||||
|
||||
imageData := make([]schemas.ImageData, len(togetherResponse.Data))
|
||||
for i, img := range togetherResponse.Data {
|
||||
imageData[i] = schemas.ImageData{
|
||||
B64JSON: img.B64JSON,
|
||||
URL: img.URL,
|
||||
Index: i,
|
||||
}
|
||||
}
|
||||
|
||||
return &schemas.BifrostImageGenerationResponse{
|
||||
Model: model,
|
||||
Data: imageData,
|
||||
}, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported inference provider: %s", inferenceProvider)
|
||||
}
|
||||
}
|
||||
|
||||
// imageBytesToBase64DataURL converts raw image bytes to base64 data URL format
|
||||
func imageBytesToBase64DataURL(imageBytes []byte) string {
|
||||
mimeType := http.DetectContentType(imageBytes)
|
||||
b64Data := base64.StdEncoding.EncodeToString(imageBytes)
|
||||
return fmt.Sprintf("data:%s;base64,%s", mimeType, b64Data)
|
||||
}
|
||||
|
||||
// mapFalAIImageEditParams maps common parameters from Bifrost request to fal-ai request
|
||||
func mapFalAIImageEditParams(bifrostReq *schemas.BifrostImageEditRequest, req *HuggingFaceFalAIImageEditRequest) {
|
||||
if bifrostReq.Params == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Map n to num_images for fal-ai
|
||||
if bifrostReq.Params.N != nil {
|
||||
req.NumImages = bifrostReq.Params.N
|
||||
}
|
||||
|
||||
// Pass through output_format
|
||||
if bifrostReq.Params.OutputFormat != nil {
|
||||
if strings.ToLower(*bifrostReq.Params.OutputFormat) == "jpg" {
|
||||
req.OutputFormat = schemas.Ptr("jpeg")
|
||||
} else {
|
||||
req.OutputFormat = bifrostReq.Params.OutputFormat
|
||||
}
|
||||
}
|
||||
|
||||
if bifrostReq.Params.ResponseFormat != nil && *bifrostReq.Params.ResponseFormat == "b64_json" {
|
||||
req.SyncMode = schemas.Ptr(true)
|
||||
}
|
||||
|
||||
// Convert size from "WxH" format to fal-ai's image_size object
|
||||
if bifrostReq.Params.Size != nil && strings.ToLower(*bifrostReq.Params.Size) != "auto" {
|
||||
size := strings.Split(*bifrostReq.Params.Size, "x")
|
||||
if len(size) == 2 {
|
||||
width, err := strconv.Atoi(size[0])
|
||||
if err == nil {
|
||||
height, err := strconv.Atoi(size[1])
|
||||
if err == nil {
|
||||
req.ImageSize = &HuggingFaceFalAISize{
|
||||
Width: width,
|
||||
Height: height,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pass-through num_inference_steps
|
||||
if bifrostReq.Params.NumInferenceSteps != nil {
|
||||
req.NumInferenceSteps = bifrostReq.Params.NumInferenceSteps
|
||||
}
|
||||
|
||||
// Pass-through seed
|
||||
if bifrostReq.Params.Seed != nil {
|
||||
req.Seed = bifrostReq.Params.Seed
|
||||
}
|
||||
|
||||
// Parse fal-ai specific params from ExtraParams
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
// Map guidance_scale
|
||||
if v, ok := schemas.SafeExtractFloat64Pointer(bifrostReq.Params.ExtraParams["guidance_scale"]); ok {
|
||||
delete(req.ExtraParams, "guidance_scale")
|
||||
req.GuidanceScale = v
|
||||
}
|
||||
|
||||
// Map acceleration
|
||||
if v, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["acceleration"]); ok {
|
||||
delete(req.ExtraParams, "acceleration")
|
||||
req.Acceleration = v
|
||||
}
|
||||
|
||||
// Map enable_safety_checker
|
||||
if v, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["enable_safety_checker"]); ok {
|
||||
delete(req.ExtraParams, "enable_safety_checker")
|
||||
req.EnableSafetyChecker = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ToHuggingFaceImageEditRequest converts a Bifrost image edit request to fal-ai format
|
||||
func ToHuggingFaceImageEditRequest(bifrostReq *schemas.BifrostImageEditRequest) (*HuggingFaceFalAIImageEditRequest, error) {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil {
|
||||
return nil, fmt.Errorf("bifrost request is nil or input is nil")
|
||||
}
|
||||
|
||||
if len(bifrostReq.Input.Images) == 0 {
|
||||
return nil, fmt.Errorf("at least one image is required")
|
||||
}
|
||||
|
||||
// Convert images to base64 data URLs
|
||||
imageURLs := make([]string, 0, len(bifrostReq.Input.Images))
|
||||
for _, img := range bifrostReq.Input.Images {
|
||||
if len(img.Image) == 0 {
|
||||
continue
|
||||
}
|
||||
imageURLs = append(imageURLs, imageBytesToBase64DataURL(img.Image))
|
||||
}
|
||||
|
||||
if len(imageURLs) == 0 {
|
||||
return nil, fmt.Errorf("no valid images found")
|
||||
}
|
||||
|
||||
// Extract model name to determine image field strategy
|
||||
_, modelName, err := splitIntoModelProvider(bifrostReq.Model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to split model name: %w", err)
|
||||
}
|
||||
|
||||
req := &HuggingFaceFalAIImageEditRequest{
|
||||
Prompt: bifrostReq.Input.Prompt,
|
||||
}
|
||||
|
||||
// Check for explicit override in ExtraParams
|
||||
var useMultiImage *bool
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil {
|
||||
req.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
if v, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["use_image_urls"]); ok {
|
||||
delete(req.ExtraParams, "use_image_urls")
|
||||
useMultiImage = v
|
||||
}
|
||||
}
|
||||
|
||||
// Determine which image field to use based on model capabilities
|
||||
if useMultiImage != nil {
|
||||
// Explicit override from user
|
||||
if *useMultiImage {
|
||||
req.ImageURLs = imageURLs
|
||||
} else if len(imageURLs) == 1 {
|
||||
req.ImageURL = &imageURLs[0]
|
||||
} else {
|
||||
return nil, fmt.Errorf("use_image_urls is false but multiple images provided (%d images)", len(imageURLs))
|
||||
}
|
||||
} else if falAIMultiImageEditModels[modelName] {
|
||||
// Model supports multiple images - always use image_urls
|
||||
req.ImageURLs = imageURLs
|
||||
} else if falAISingleImageEditModels[modelName] {
|
||||
// Model only supports single image - validate and use image_url
|
||||
if len(imageURLs) == 1 {
|
||||
req.ImageURL = &imageURLs[0]
|
||||
} else {
|
||||
return nil, fmt.Errorf("model %s only supports single image, got %d images", modelName, len(imageURLs))
|
||||
}
|
||||
} else {
|
||||
// Unknown model - fallback to count-based logic
|
||||
if len(imageURLs) == 1 {
|
||||
req.ImageURL = &imageURLs[0]
|
||||
} else {
|
||||
req.ImageURLs = imageURLs
|
||||
}
|
||||
}
|
||||
|
||||
// Map common parameters
|
||||
mapFalAIImageEditParams(bifrostReq, req)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// extractImagesFromStreamResponse extracts images from a fal-ai streaming response.
|
||||
// Handles both API envelope structure (Data.Images) and legacy flattened format (top-level Images).
|
||||
func extractImagesFromStreamResponse(response *HuggingFaceFalAIImageStreamResponse) []FalAIImage {
|
||||
// Prefer Data.Images if available (API envelope structure)
|
||||
if response.Data != nil && len(response.Data.Images) > 0 {
|
||||
return response.Data.Images
|
||||
}
|
||||
// Fall back to top-level Images (legacy format)
|
||||
return response.Images
|
||||
}
|
||||
140
core/providers/huggingface/models.go
Normal file
140
core/providers/huggingface/models.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package huggingface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultModelFetchLimit = 200
|
||||
maxModelFetchLimit = 1000
|
||||
)
|
||||
|
||||
func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, inferenceProvider inferenceProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostListModelsResponse{
|
||||
Data: make([]schemas.Model, 0, len(response.Models)),
|
||||
}
|
||||
|
||||
pipeline := &providerUtils.ListModelsPipeline{
|
||||
AllowedModels: allowedModels,
|
||||
BlacklistedModels: blacklistedModels,
|
||||
Aliases: aliases,
|
||||
Unfiltered: unfiltered,
|
||||
ProviderKey: providerKey,
|
||||
MatchFns: providerUtils.DefaultMatchFns(),
|
||||
}
|
||||
if pipeline.ShouldEarlyExit() {
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
included := make(map[string]bool)
|
||||
|
||||
for _, model := range response.Models {
|
||||
if model.ModelID == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
supported := deriveSupportedMethods(model.PipelineTag, model.Tags)
|
||||
if len(supported) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Aliases apply at the model level (model.ModelID), not at the compound
|
||||
// "{providerKey}/{inferenceProvider}/{modelID}" level.
|
||||
for _, result := range pipeline.FilterModel(model.ModelID) {
|
||||
newModel := schemas.Model{
|
||||
// inferenceProvider stays in the compound ID; aliases rename only the model segment
|
||||
ID: fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, result.ResolvedID),
|
||||
Name: schemas.Ptr(model.ModelID),
|
||||
SupportedMethods: supported,
|
||||
HuggingFaceID: schemas.Ptr(model.ID),
|
||||
}
|
||||
if result.AliasValue != "" {
|
||||
newModel.Alias = schemas.Ptr(result.AliasValue)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, newModel)
|
||||
included[strings.ToLower(result.ResolvedID)] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Backfill: use standard pipeline. Note that backfilled HF entries use a simplified
|
||||
// compound ID since we don't know which inferenceProvider to assign them to.
|
||||
for _, m := range pipeline.BackfillModels(included) {
|
||||
// Re-wrap the backfill ID to include the inferenceProvider segment
|
||||
rawID := strings.TrimPrefix(m.ID, string(providerKey)+"/")
|
||||
m.ID = fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, rawID)
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, m)
|
||||
}
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
func deriveSupportedMethods(pipeline string, tags []string) []string {
|
||||
normalized := strings.TrimSpace(strings.ToLower(pipeline))
|
||||
|
||||
methodsSet := map[schemas.RequestType]struct{}{}
|
||||
|
||||
addMethods := func(methods ...schemas.RequestType) {
|
||||
for _, method := range methods {
|
||||
methodsSet[method] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
switch normalized {
|
||||
case "conversational", "chat-completion":
|
||||
addMethods(schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest,
|
||||
schemas.ResponsesRequest, schemas.ResponsesStreamRequest)
|
||||
case "feature-extraction":
|
||||
addMethods(schemas.EmbeddingRequest)
|
||||
case "text-to-speech":
|
||||
addMethods(schemas.SpeechRequest)
|
||||
case "automatic-speech-recognition":
|
||||
addMethods(schemas.TranscriptionRequest)
|
||||
case "text-to-image":
|
||||
addMethods(schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest)
|
||||
}
|
||||
|
||||
for _, tag := range tags {
|
||||
tagLower := strings.ToLower(tag)
|
||||
switch {
|
||||
case tagLower == "text-embedding" || tagLower == "sentence-similarity" ||
|
||||
tagLower == "feature-extraction" || tagLower == "embeddings" ||
|
||||
tagLower == "sentence-transformers" || strings.Contains(tagLower, "embedding"):
|
||||
addMethods(schemas.EmbeddingRequest)
|
||||
case tagLower == "text-generation" || tagLower == "summarization" ||
|
||||
tagLower == "conversational" || tagLower == "chat-completion" ||
|
||||
tagLower == "text2text-generation" || tagLower == "question-answering" ||
|
||||
strings.Contains(tagLower, "chat") || strings.Contains(tagLower, "completion"):
|
||||
addMethods(schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest,
|
||||
schemas.ResponsesRequest, schemas.ResponsesStreamRequest)
|
||||
case tagLower == "text-to-speech" || tagLower == "tts" ||
|
||||
strings.Contains(tagLower, "text-to-speech"):
|
||||
addMethods(schemas.SpeechRequest)
|
||||
case tagLower == "automatic-speech-recognition" ||
|
||||
tagLower == "speech-to-text" || strings.Contains(tagLower, "speech-recognition"):
|
||||
addMethods(schemas.TranscriptionRequest)
|
||||
case tagLower == "text-to-image" || strings.Contains(tagLower, "image-generation"):
|
||||
addMethods(schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest)
|
||||
}
|
||||
}
|
||||
|
||||
if len(methodsSet) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
methods := make([]string, 0, len(methodsSet))
|
||||
for method := range methodsSet {
|
||||
methods = append(methods, string(method))
|
||||
}
|
||||
|
||||
slices.Sort(methods)
|
||||
return methods
|
||||
}
|
||||
51
core/providers/huggingface/payload_ordering_test.go
Normal file
51
core/providers/huggingface/payload_ordering_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package huggingface
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPayloadOrdering_HuggingFaceChatRequest(t *testing.T) {
|
||||
req := &HuggingFaceChatRequest{
|
||||
Model: "meta-llama/Llama-3-70B-Instruct",
|
||||
Messages: []schemas.ChatMessage{
|
||||
{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hello")}},
|
||||
},
|
||||
Temperature: schemas.Ptr(0.7),
|
||||
Stream: schemas.Ptr(true),
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: schemas.Ptr("Get weather"),
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("location", map[string]interface{}{"type": "string"}),
|
||||
),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := providerUtils.MarshalSorted(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
golden := `{"messages":[{"role":"user","content":"hello"}],"model":"meta-llama/Llama-3-70B-Instruct","stream":true,"temperature":0.7,"tools":[{"type":"function","function":{"name":"get_weather","description":"Get weather","parameters":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}}]}`
|
||||
|
||||
assert.Equal(t, golden, string(result), "payload field ordering changed — if intentional, update the golden string")
|
||||
|
||||
// Determinism: 100 iterations must produce identical bytes
|
||||
for i := 0; i < 100; i++ {
|
||||
iter, err := providerUtils.MarshalSorted(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(result), string(iter), "non-deterministic marshal output on iteration %d", i)
|
||||
}
|
||||
}
|
||||
49
core/providers/huggingface/responses.go
Normal file
49
core/providers/huggingface/responses.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package huggingface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToHuggingFaceResponsesRequest converts a Bifrost Responses request into the Hugging Face
|
||||
// chat-completions payload that the provider already understands.
|
||||
func ToHuggingFaceResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (*HuggingFaceChatRequest, error) {
|
||||
if bifrostReq == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
chatReq := bifrostReq.ToChatRequest()
|
||||
if chatReq == nil {
|
||||
return nil, fmt.Errorf("failed to convert responses request to chat request")
|
||||
}
|
||||
|
||||
hfReq, err := ToHuggingFaceChatCompletionRequest(chatReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if hfReq == nil {
|
||||
return nil, fmt.Errorf("failed to convert chat request to Hugging Face request")
|
||||
}
|
||||
|
||||
return hfReq, nil
|
||||
}
|
||||
|
||||
// ToBifrostResponsesResponseFromHuggingFace converts a Bifrost chat response into the
|
||||
// Bifrost Responses response shape, preserving provider metadata.
|
||||
func ToBifrostResponsesResponseFromHuggingFace(resp *schemas.BifrostChatResponse, requestedModel string) (*schemas.BifrostResponsesResponse, error) {
|
||||
if resp == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Ensure model is set
|
||||
if resp.Model == "" {
|
||||
resp.Model = requestedModel
|
||||
}
|
||||
|
||||
responsesResp := resp.ToBifrostResponsesResponse()
|
||||
if responsesResp != nil {
|
||||
}
|
||||
|
||||
return responsesResp, nil
|
||||
}
|
||||
134
core/providers/huggingface/speech.go
Normal file
134
core/providers/huggingface/speech.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package huggingface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func ToHuggingFaceSpeechRequest(request *schemas.BifrostSpeechRequest) (*HuggingFaceSpeechRequest, error) {
|
||||
if request == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if request.Input == nil {
|
||||
return nil, fmt.Errorf("speech request input cannot be nil")
|
||||
}
|
||||
|
||||
inferenceProvider, modelName, nameErr := splitIntoModelProvider(request.Model)
|
||||
if nameErr != nil {
|
||||
return nil, nameErr
|
||||
}
|
||||
|
||||
// HuggingFace expects text in the Text field (for TTS - Text To Speech)
|
||||
hfRequest := &HuggingFaceSpeechRequest{
|
||||
Text: request.Input.Input,
|
||||
Model: modelName,
|
||||
Provider: string(inferenceProvider),
|
||||
}
|
||||
|
||||
// Map parameters if present
|
||||
if request.Params != nil {
|
||||
hfRequest.Parameters = &HuggingFaceSpeechParameters{}
|
||||
|
||||
// Map generation parameters from ExtraParams if available
|
||||
if request.Params.ExtraParams != nil {
|
||||
genParams := &HuggingFaceTranscriptionGenerationParameters{}
|
||||
|
||||
if val, ok := request.Params.ExtraParams["do_sample"].(bool); ok {
|
||||
delete(request.Params.ExtraParams, "do_sample")
|
||||
genParams.DoSample = &val
|
||||
}
|
||||
if v, ok := schemas.SafeExtractIntPointer(request.Params.ExtraParams["max_new_tokens"]); ok {
|
||||
delete(request.Params.ExtraParams, "max_new_tokens")
|
||||
genParams.MaxNewTokens = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractIntPointer(request.Params.ExtraParams["max_length"]); ok {
|
||||
delete(request.Params.ExtraParams, "max_length")
|
||||
genParams.MaxLength = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractIntPointer(request.Params.ExtraParams["min_length"]); ok {
|
||||
delete(request.Params.ExtraParams, "min_length")
|
||||
genParams.MinLength = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractIntPointer(request.Params.ExtraParams["min_new_tokens"]); ok {
|
||||
delete(request.Params.ExtraParams, "min_new_tokens")
|
||||
genParams.MinNewTokens = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractIntPointer(request.Params.ExtraParams["num_beams"]); ok {
|
||||
delete(request.Params.ExtraParams, "num_beams")
|
||||
genParams.NumBeams = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractIntPointer(request.Params.ExtraParams["num_beam_groups"]); ok {
|
||||
delete(request.Params.ExtraParams, "num_beam_groups")
|
||||
genParams.NumBeamGroups = v
|
||||
}
|
||||
if val, ok := request.Params.ExtraParams["penalty_alpha"].(float64); ok {
|
||||
delete(request.Params.ExtraParams, "penalty_alpha")
|
||||
genParams.PenaltyAlpha = &val
|
||||
}
|
||||
if val, ok := request.Params.ExtraParams["temperature"].(float64); ok {
|
||||
delete(request.Params.ExtraParams, "temperature")
|
||||
genParams.Temperature = &val
|
||||
}
|
||||
if v, ok := schemas.SafeExtractIntPointer(request.Params.ExtraParams["top_k"]); ok {
|
||||
delete(request.Params.ExtraParams, "top_k")
|
||||
genParams.TopK = v
|
||||
}
|
||||
if val, ok := request.Params.ExtraParams["top_p"].(float64); ok {
|
||||
delete(request.Params.ExtraParams, "top_p")
|
||||
genParams.TopP = &val
|
||||
}
|
||||
if val, ok := request.Params.ExtraParams["typical_p"].(float64); ok {
|
||||
delete(request.Params.ExtraParams, "typical_p")
|
||||
genParams.TypicalP = &val
|
||||
}
|
||||
if val, ok := request.Params.ExtraParams["use_cache"].(bool); ok {
|
||||
delete(request.Params.ExtraParams, "use_cache")
|
||||
genParams.UseCache = &val
|
||||
}
|
||||
if val, ok := request.Params.ExtraParams["epsilon_cutoff"].(float64); ok {
|
||||
delete(request.Params.ExtraParams, "epsilon_cutoff")
|
||||
genParams.EpsilonCutoff = &val
|
||||
}
|
||||
if val, ok := request.Params.ExtraParams["eta_cutoff"].(float64); ok {
|
||||
delete(request.Params.ExtraParams, "eta_cutoff")
|
||||
genParams.EtaCutoff = &val
|
||||
}
|
||||
|
||||
// Handle early_stopping (can be bool or string "never")
|
||||
if val, ok := request.Params.ExtraParams["early_stopping"].(bool); ok {
|
||||
delete(request.Params.ExtraParams, "early_stopping")
|
||||
genParams.EarlyStopping = &HuggingFaceTranscriptionEarlyStopping{BoolValue: &val}
|
||||
} else if val, ok := request.Params.ExtraParams["early_stopping"].(string); ok {
|
||||
delete(request.Params.ExtraParams, "early_stopping")
|
||||
genParams.EarlyStopping = &HuggingFaceTranscriptionEarlyStopping{StringValue: &val}
|
||||
}
|
||||
|
||||
hfRequest.Parameters.GenerationParameters = genParams
|
||||
}
|
||||
}
|
||||
hfRequest.ExtraParams = request.Params.ExtraParams
|
||||
|
||||
return hfRequest, nil
|
||||
}
|
||||
|
||||
func (response *HuggingFaceSpeechResponse) ToBifrostSpeechResponse(requestedModel string, audioData []byte) (*schemas.BifrostSpeechResponse, error) {
|
||||
if response == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if requestedModel == "" {
|
||||
return nil, fmt.Errorf("model name cannot be empty")
|
||||
}
|
||||
|
||||
// Create the base Bifrost response with the downloaded audio data
|
||||
bifrostResponse := &schemas.BifrostSpeechResponse{
|
||||
Audio: audioData,
|
||||
}
|
||||
|
||||
// Note: HuggingFace TTS API typically doesn't return usage information
|
||||
// or alignment data, so we leave those fields as nil
|
||||
|
||||
return bifrostResponse, nil
|
||||
}
|
||||
170
core/providers/huggingface/transcription.go
Normal file
170
core/providers/huggingface/transcription.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package huggingface
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func ToHuggingFaceTranscriptionRequest(request *schemas.BifrostTranscriptionRequest) (*HuggingFaceTranscriptionRequest, error) {
|
||||
if request == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if request.Input == nil {
|
||||
return nil, fmt.Errorf("transcription request input cannot be nil")
|
||||
}
|
||||
|
||||
if len(request.Input.File) == 0 {
|
||||
return nil, fmt.Errorf("transcription request audio file cannot be empty")
|
||||
}
|
||||
|
||||
inferenceProvider, modelName, nameErr := splitIntoModelProvider(request.Model)
|
||||
if nameErr != nil {
|
||||
return nil, nameErr
|
||||
}
|
||||
|
||||
var hfRequest *HuggingFaceTranscriptionRequest
|
||||
// HuggingFace expects audio data in the Inputs field (for ASR - Automatic Speech Recognition)
|
||||
if inferenceProvider != falAI {
|
||||
hfRequest = &HuggingFaceTranscriptionRequest{
|
||||
Inputs: request.Input.File,
|
||||
Model: schemas.Ptr(modelName),
|
||||
Provider: schemas.Ptr(string(inferenceProvider)),
|
||||
}
|
||||
} else {
|
||||
encoded := base64.StdEncoding.EncodeToString(request.Input.File)
|
||||
mimeType := getMimeTypeForAudioType(utils.DetectAudioMimeType(request.Input.File))
|
||||
if mimeType == "audio/wav" {
|
||||
return nil, fmt.Errorf("fal-ai provider does not support audio/wav format; please use a different format like mp3 or ogg")
|
||||
}
|
||||
encoded = fmt.Sprintf("data:%s;base64,%s", mimeType, encoded)
|
||||
hfRequest = &HuggingFaceTranscriptionRequest{
|
||||
AudioURL: encoded,
|
||||
}
|
||||
}
|
||||
// Map parameters if present
|
||||
if request.Params != nil {
|
||||
hfRequest.Parameters = &HuggingFaceTranscriptionRequestParameters{}
|
||||
genParams := &HuggingFaceTranscriptionGenerationParameters{}
|
||||
|
||||
if v, ok := schemas.SafeExtractIntPointer(request.Params.MaxNewTokens); ok {
|
||||
genParams.MaxNewTokens = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractIntPointer(request.Params.MaxLength); ok {
|
||||
genParams.MaxLength = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractIntPointer(request.Params.MinLength); ok {
|
||||
genParams.MinLength = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractIntPointer(request.Params.MinNewTokens); ok {
|
||||
genParams.MinNewTokens = v
|
||||
}
|
||||
|
||||
if request.Params.ExtraParams != nil {
|
||||
extra := request.Params.ExtraParams
|
||||
if val, ok := extra["do_sample"].(bool); ok {
|
||||
delete(extra, "do_sample")
|
||||
genParams.DoSample = &val
|
||||
}
|
||||
if v, ok := schemas.SafeExtractIntPointer(extra["num_beams"]); ok {
|
||||
delete(extra, "num_beams")
|
||||
genParams.NumBeams = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractIntPointer(extra["num_beam_groups"]); ok {
|
||||
delete(extra, "num_beam_groups")
|
||||
genParams.NumBeamGroups = v
|
||||
}
|
||||
if val, ok := extra["penalty_alpha"].(float64); ok {
|
||||
delete(extra, "penalty_alpha")
|
||||
genParams.PenaltyAlpha = &val
|
||||
}
|
||||
if val, ok := extra["temperature"].(float64); ok {
|
||||
delete(extra, "temperature")
|
||||
genParams.Temperature = &val
|
||||
}
|
||||
if v, ok := schemas.SafeExtractIntPointer(extra["top_k"]); ok {
|
||||
delete(extra, "top_k")
|
||||
genParams.TopK = v
|
||||
}
|
||||
if val, ok := extra["top_p"].(float64); ok {
|
||||
delete(extra, "top_p")
|
||||
genParams.TopP = &val
|
||||
}
|
||||
if val, ok := extra["typical_p"].(float64); ok {
|
||||
delete(extra, "typical_p")
|
||||
genParams.TypicalP = &val
|
||||
}
|
||||
if val, ok := extra["use_cache"].(bool); ok {
|
||||
delete(extra, "use_cache")
|
||||
genParams.UseCache = &val
|
||||
}
|
||||
if val, ok := extra["epsilon_cutoff"].(float64); ok {
|
||||
delete(extra, "epsilon_cutoff")
|
||||
genParams.EpsilonCutoff = &val
|
||||
}
|
||||
if val, ok := extra["eta_cutoff"].(float64); ok {
|
||||
delete(extra, "eta_cutoff")
|
||||
genParams.EtaCutoff = &val
|
||||
}
|
||||
|
||||
// Handle early_stopping (can be bool or string "never")
|
||||
if val, ok := extra["early_stopping"].(bool); ok {
|
||||
delete(extra, "early_stopping")
|
||||
genParams.EarlyStopping = &HuggingFaceTranscriptionEarlyStopping{BoolValue: &val}
|
||||
} else if val, ok := extra["early_stopping"].(string); ok {
|
||||
delete(extra, "early_stopping")
|
||||
genParams.EarlyStopping = &HuggingFaceTranscriptionEarlyStopping{StringValue: &val}
|
||||
}
|
||||
|
||||
// Handle return_timestamps
|
||||
if val, ok := extra["return_timestamps"].(bool); ok {
|
||||
delete(extra, "return_timestamps")
|
||||
hfRequest.Parameters.ReturnTimestamps = &val
|
||||
}
|
||||
}
|
||||
hfRequest.ExtraParams = request.Params.ExtraParams
|
||||
hfRequest.Parameters.GenerationParameters = genParams
|
||||
}
|
||||
|
||||
return hfRequest, nil
|
||||
}
|
||||
|
||||
func (response *HuggingFaceTranscriptionResponse) ToBifrostTranscriptionResponse(requestedModel string) (*schemas.BifrostTranscriptionResponse, error) {
|
||||
if response == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if requestedModel == "" {
|
||||
return nil, fmt.Errorf("model name cannot be empty")
|
||||
}
|
||||
|
||||
// Create the base Bifrost response
|
||||
bifrostResponse := &schemas.BifrostTranscriptionResponse{
|
||||
Text: response.Text,
|
||||
}
|
||||
|
||||
// Map chunks to segments if available
|
||||
if len(response.Chunks) > 0 {
|
||||
segments := make([]schemas.TranscriptionSegment, len(response.Chunks))
|
||||
for i, chunk := range response.Chunks {
|
||||
var start, end float64
|
||||
if len(chunk.Timestamp) >= 2 {
|
||||
start = chunk.Timestamp[0]
|
||||
end = chunk.Timestamp[1]
|
||||
}
|
||||
|
||||
segments[i] = schemas.TranscriptionSegment{
|
||||
ID: i,
|
||||
Start: start,
|
||||
End: end,
|
||||
Text: chunk.Text,
|
||||
}
|
||||
}
|
||||
bifrostResponse.Segments = segments
|
||||
}
|
||||
|
||||
return bifrostResponse, nil
|
||||
}
|
||||
529
core/providers/huggingface/types.go
Normal file
529
core/providers/huggingface/types.go
Normal file
@@ -0,0 +1,529 @@
|
||||
package huggingface
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// # MODELS TYPES
|
||||
|
||||
// refered from https://huggingface.co/api/models
|
||||
type HuggingFaceModel struct {
|
||||
ID string `json:"_id"`
|
||||
ModelID string `json:"modelId"`
|
||||
Likes int `json:"likes"`
|
||||
TrendingScore int `json:"trendingScore"`
|
||||
Private bool `json:"private"`
|
||||
Downloads int `json:"downloads"`
|
||||
Tags []string `json:"tags"`
|
||||
PipelineTag string `json:"pipeline_tag"`
|
||||
LibraryName string `json:"library_name"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
}
|
||||
|
||||
type HuggingFaceListModelsResponse struct {
|
||||
Models []HuggingFaceModel `json:"models"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON supports both the older object form `{"models": [...]}`
|
||||
// and the current API which returns a top-level JSON array `[...]`.
|
||||
func (r *HuggingFaceListModelsResponse) UnmarshalJSON(data []byte) error {
|
||||
// Try unmarshaling as an array first (most common for /api/models)
|
||||
var arr []HuggingFaceModel
|
||||
if err := sonic.Unmarshal(data, &arr); err == nil {
|
||||
r.Models = arr
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fallback: try object with a `models` field
|
||||
var obj struct {
|
||||
Models []HuggingFaceModel `json:"models"`
|
||||
}
|
||||
if err := sonic.Unmarshal(data, &obj); err == nil {
|
||||
r.Models = obj.Models
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to unmarshal HuggingFaceListModelsResponse: unexpected JSON structure")
|
||||
}
|
||||
|
||||
type HuggingFaceInferenceProviderMappingResponse struct {
|
||||
ID string `json:"_id"`
|
||||
ModelID string `json:"id"`
|
||||
PipelineTag string `json:"pipeline_tag"`
|
||||
InferenceProviderMapping map[string]HuggingFaceInferenceProviderInfo `json:"inferenceProviderMapping"`
|
||||
}
|
||||
|
||||
type HuggingFaceInferenceProviderInfo struct {
|
||||
Status string `json:"status"`
|
||||
ProviderModelID string `json:"providerId"`
|
||||
Task string `json:"task"`
|
||||
IsModelAuthor bool `json:"isModelAuthor"`
|
||||
}
|
||||
|
||||
type HuggingFaceInferenceProviderMapping struct {
|
||||
ProviderTask string
|
||||
ProviderModelID string
|
||||
}
|
||||
|
||||
// # CHAT TYPES
|
||||
|
||||
// Flexible/chat request types for HuggingFace-like chat completion payloads.
|
||||
type HuggingFaceChatRequest struct {
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
Logprobs *bool `json:"logprobs,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
Messages []schemas.ChatMessage `json:"messages"`
|
||||
Model string `json:"model" validate:"required"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat *HuggingFaceResponseFormat `json:"response_format,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
StreamOptions *schemas.ChatStreamOptions `json:"stream_options,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
ToolChoice *HuggingFaceToolChoice `json:"tool_choice,omitempty"`
|
||||
ToolPrompt *string `json:"tool_prompt,omitempty"`
|
||||
Tools []schemas.ChatTool `json:"tools,omitempty"`
|
||||
TopLogprobs *int `json:"top_logprobs,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
func (req *HuggingFaceChatRequest) GetExtraParams() map[string]interface{} {
|
||||
return req.ExtraParams
|
||||
}
|
||||
|
||||
// HuggingFaceToolChoice represents the flexible `tool_choice` field which
|
||||
// can be either one of the enum strings: "auto", "none", "required",
|
||||
// or an object with a `function` sub-object containing a required `name`.
|
||||
type HuggingFaceToolChoice struct {
|
||||
EnumValue *EnumStringType
|
||||
|
||||
// Function holds the function object when the field is a JSON object.
|
||||
Function *schemas.ChatToolChoiceFunction
|
||||
}
|
||||
|
||||
type EnumStringType string
|
||||
|
||||
const (
|
||||
EnumStringTypeAuto EnumStringType = "auto"
|
||||
EnumStringTypeNone EnumStringType = "none"
|
||||
EnumStringTypeRequired EnumStringType = "required"
|
||||
)
|
||||
|
||||
// MarshalJSON will emit either a JSON string for enum values or an object
|
||||
// containing the `function` key and `type` field.
|
||||
func (t HuggingFaceToolChoice) MarshalJSON() ([]byte, error) {
|
||||
if t.EnumValue != nil {
|
||||
return providerUtils.MarshalSorted(*t.EnumValue)
|
||||
}
|
||||
if t.Function != nil {
|
||||
return providerUtils.MarshalSorted(struct {
|
||||
Type string `json:"type"`
|
||||
Function *schemas.ChatToolChoiceFunction `json:"function"`
|
||||
}{
|
||||
Type: "function",
|
||||
Function: t.Function,
|
||||
})
|
||||
}
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
type HuggingFaceResponseFormat struct {
|
||||
Type string `json:"type"`
|
||||
JSONSchema *HuggingFaceJSONSchema `json:"json_schema,omitempty"`
|
||||
}
|
||||
|
||||
type HuggingFaceJSONSchema struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Schema json.RawMessage `json:"schema,omitempty"`
|
||||
Strict *bool `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
// # RESPONSE TYPES
|
||||
|
||||
type HuggingFaceHubError struct {
|
||||
Error string `json:"error"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type HuggingFaceResponseError struct {
|
||||
Error string `json:"error"`
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
Detail []HuggingFaceErrorDetail `json:"detail,omitempty"` // FastAPI validation errors
|
||||
}
|
||||
|
||||
type HuggingFaceErrorDetail struct {
|
||||
Loc []interface{} `json:"loc"`
|
||||
Msg string `json:"msg"`
|
||||
Type string `json:"type"`
|
||||
Ctx map[string]interface{} `json:"ctx,omitempty"`
|
||||
}
|
||||
|
||||
// # EMBEDDING TYPES
|
||||
|
||||
// HuggingFaceEmbeddingRequest represents the request format for HuggingFace embeddings API
|
||||
// Based on the HuggingFace Router API specification
|
||||
type HuggingFaceEmbeddingRequest struct {
|
||||
Input *InputsCustomType `json:"input,omitempty"` // string or []string used by all inference providers other than hf-inference
|
||||
Inputs *InputsCustomType `json:"inputs,omitempty"` // string or []string used by hf-inference provider
|
||||
Provider *string `json:"provider,omitempty"` // used by all inference providers other than hf-inference
|
||||
Model *string `json:"model,omitempty"` // used by all inference providers other than hf-inference
|
||||
Normalize *bool `json:"normalize,omitempty"`
|
||||
PromptName *string `json:"prompt_name,omitempty"`
|
||||
Truncate *bool `json:"truncate,omitempty"`
|
||||
TruncationDirection *string `json:"truncation_direction,omitempty"` // "left" or "right"
|
||||
EncodingFormat *EncodingType `json:"encoding_format,omitempty"`
|
||||
Dimensions *int `json:"dimensions,omitempty"`
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
func (req *HuggingFaceEmbeddingRequest) GetExtraParams() map[string]interface{} {
|
||||
return req.ExtraParams
|
||||
}
|
||||
|
||||
type InputsCustomType struct {
|
||||
Texts []string `json:"texts,omitempty"`
|
||||
Text *string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
func (i *InputsCustomType) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 || bytes.Equal(bytes.TrimSpace(data), []byte("null")) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try string
|
||||
var text string
|
||||
if err := sonic.Unmarshal(data, &text); err == nil {
|
||||
i.Text = &text
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try array
|
||||
var texts []string
|
||||
if err := sonic.Unmarshal(data, &texts); err == nil {
|
||||
i.Texts = texts
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try object
|
||||
type alias InputsCustomType
|
||||
var obj alias
|
||||
if err := sonic.Unmarshal(data, &obj); err == nil {
|
||||
*i = InputsCustomType(obj)
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to unmarshal InputsCustomType: expected string, array, or object")
|
||||
}
|
||||
|
||||
func (i InputsCustomType) MarshalJSON() ([]byte, error) {
|
||||
if len(i.Texts) > 0 {
|
||||
return providerUtils.MarshalSorted(i.Texts)
|
||||
}
|
||||
if i.Text != nil {
|
||||
return providerUtils.MarshalSorted(*i.Text)
|
||||
}
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
type EncodingType string
|
||||
|
||||
const (
|
||||
EncodingTypeFloat EncodingType = "float"
|
||||
EncodingTypeBase64 EncodingType = "base64"
|
||||
)
|
||||
|
||||
// # SPEECH TYPES
|
||||
|
||||
// Speech request represents the inputs for Text To Speech inference.
|
||||
type HuggingFaceSpeechRequest struct {
|
||||
Text string `json:"text"`
|
||||
Provider string `json:"provider" validate:"required"`
|
||||
Model string `json:"model" validate:"required"`
|
||||
Parameters *HuggingFaceSpeechParameters `json:"parameters,omitempty"`
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
func (req *HuggingFaceSpeechRequest) GetExtraParams() map[string]interface{} {
|
||||
return req.ExtraParams
|
||||
}
|
||||
|
||||
// Speech parameters are additional inference parameters for Text To Speech
|
||||
type HuggingFaceSpeechParameters struct {
|
||||
GenerationParameters *HuggingFaceTranscriptionGenerationParameters `json:"generation_parameters,omitempty"`
|
||||
}
|
||||
|
||||
// Speech response represents the outputs of inference for the Text To Speech task.
|
||||
type HuggingFaceSpeechResponse struct {
|
||||
Audio HuggingFaceSpeechAudio `json:"audio"`
|
||||
}
|
||||
|
||||
// HuggingFaceSpeechAudio represents the audio object in the speech response
|
||||
type HuggingFaceSpeechAudio struct {
|
||||
URL string `json:"url"`
|
||||
ContentType string `json:"content_type"`
|
||||
FileName string `json:"file_name"`
|
||||
FileSize int `json:"file_size"`
|
||||
}
|
||||
|
||||
// # TRANSCRIPT TYPES
|
||||
|
||||
// HuggingFaceTranscriptionRequest represents the request for Automatic Speech Recognition inference
|
||||
type HuggingFaceTranscriptionRequest struct {
|
||||
Inputs []byte `json:"inputs,omitempty"` // raw audio bytes
|
||||
AudioURL string `json:"audio_url,omitempty"` // URL to audio file only needed for fal ai
|
||||
Provider *string `json:"provider,omitempty"`
|
||||
Model *string `json:"model,omitempty"`
|
||||
Parameters *HuggingFaceTranscriptionRequestParameters `json:"parameters,omitempty"`
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
func (req *HuggingFaceTranscriptionRequest) GetExtraParams() map[string]interface{} {
|
||||
return req.ExtraParams
|
||||
}
|
||||
|
||||
// HuggingFaceTranscriptionRequestParameters contains additional inference parameters for Automatic Speech Recognition
|
||||
type HuggingFaceTranscriptionRequestParameters struct {
|
||||
GenerationParameters *HuggingFaceTranscriptionGenerationParameters `json:"generation_parameters,omitempty"`
|
||||
ReturnTimestamps *bool `json:"return_timestamps,omitempty"`
|
||||
}
|
||||
|
||||
// HuggingFaceTranscriptionGenerationParameters contains parametrization of the text generation process
|
||||
type HuggingFaceTranscriptionGenerationParameters struct {
|
||||
DoSample *bool `json:"do_sample,omitempty"`
|
||||
EarlyStopping *HuggingFaceTranscriptionEarlyStopping `json:"early_stopping,omitempty"`
|
||||
EpsilonCutoff *float64 `json:"epsilon_cutoff,omitempty"`
|
||||
EtaCutoff *float64 `json:"eta_cutoff,omitempty"`
|
||||
MaxLength *int `json:"max_length,omitempty"`
|
||||
MaxNewTokens *int `json:"max_new_tokens,omitempty"`
|
||||
MinLength *int `json:"min_length,omitempty"`
|
||||
MinNewTokens *int `json:"min_new_tokens,omitempty"`
|
||||
NumBeamGroups *int `json:"num_beam_groups,omitempty"`
|
||||
NumBeams *int `json:"num_beams,omitempty"`
|
||||
PenaltyAlpha *float64 `json:"penalty_alpha,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TypicalP *float64 `json:"typical_p,omitempty"`
|
||||
UseCache *bool `json:"use_cache,omitempty"`
|
||||
}
|
||||
|
||||
// HuggingFaceTranscriptionEarlyStopping controls the stopping condition for beam-based methods
|
||||
// Can be a boolean or the string "never"
|
||||
type HuggingFaceTranscriptionEarlyStopping struct {
|
||||
BoolValue *bool
|
||||
StringValue *string
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshaling for HuggingFaceTranscriptionEarlyStopping
|
||||
func (e HuggingFaceTranscriptionEarlyStopping) MarshalJSON() ([]byte, error) {
|
||||
if e.BoolValue != nil {
|
||||
return providerUtils.MarshalSorted(*e.BoolValue)
|
||||
}
|
||||
if e.StringValue != nil {
|
||||
return providerUtils.MarshalSorted(*e.StringValue)
|
||||
}
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling for HuggingFaceTranscriptionEarlyStopping
|
||||
func (e *HuggingFaceTranscriptionEarlyStopping) UnmarshalJSON(data []byte) error {
|
||||
// Try boolean first
|
||||
var boolVal bool
|
||||
if err := sonic.Unmarshal(data, &boolVal); err == nil {
|
||||
e.BoolValue = &boolVal
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try string
|
||||
var stringVal string
|
||||
if err := sonic.Unmarshal(data, &stringVal); err == nil {
|
||||
e.StringValue = &stringVal
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("early_stopping must be a boolean or string, got: %s", string(data))
|
||||
}
|
||||
|
||||
// HuggingFaceTranscriptionResponse represents the output of Automatic Speech Recognition inference
|
||||
type HuggingFaceTranscriptionResponse struct {
|
||||
Text string `json:"text"`
|
||||
Chunks []HuggingFaceTranscriptionResponseChunk `json:"chunks,omitempty"`
|
||||
}
|
||||
|
||||
// HuggingFaceTranscriptionResponseChunk represents an audio chunk identified by the model
|
||||
type HuggingFaceTranscriptionResponseChunk struct {
|
||||
Text string `json:"text"`
|
||||
Timestamp []float64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
type HuggingFaceGenerationParameters = HuggingFaceTranscriptionGenerationParameters
|
||||
type HuggingFaceEarlyStoppingUnion = HuggingFaceTranscriptionEarlyStopping
|
||||
|
||||
// # IMAGE GENERATION TYPES
|
||||
|
||||
// HuggingFaceHFInferenceImageGenerationRequest for hf-inference image generation
|
||||
type HuggingFaceHFInferenceImageGenerationRequest struct {
|
||||
Inputs string `json:"inputs"`
|
||||
ExtraParams map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
func (req *HuggingFaceHFInferenceImageGenerationRequest) GetExtraParams() map[string]any {
|
||||
return req.ExtraParams
|
||||
}
|
||||
|
||||
// HuggingFaceFalAIImageGenerationRequest for fal-ai image generation
|
||||
type HuggingFaceFalAIImageGenerationRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
NumImages *int `json:"num_images,omitempty"`
|
||||
ResponseFormat *string `json:"response_format,omitempty"`
|
||||
ImageSize *HuggingFaceFalAISize `json:"image_size,omitempty"`
|
||||
NegativePrompt *string `json:"negative_prompt,omitempty"`
|
||||
GuidanceScale *float64 `json:"guidance_scale,omitempty"`
|
||||
NumInferenceSteps *int `json:"num_inference_steps,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
OutputFormat *string `json:"output_format,omitempty"`
|
||||
SyncMode *bool `json:"sync_mode,omitempty"`
|
||||
EnableSafetyChecker *bool `json:"enable_safety_checker,omitempty"`
|
||||
Acceleration *string `json:"acceleration,omitempty"`
|
||||
EnablePromptExpansion *bool `json:"enable_prompt_expansion,omitempty"`
|
||||
ExtraParams map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
func (req *HuggingFaceFalAIImageGenerationRequest) GetExtraParams() map[string]any {
|
||||
return req.ExtraParams
|
||||
}
|
||||
|
||||
type HuggingFaceFalAISize struct {
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
}
|
||||
|
||||
// HuggingFaceFalAIImageGenerationResponse for fal-ai image generation
|
||||
// Matches the API envelope structure with top-level metadata and data array
|
||||
type HuggingFaceFalAIImageGenerationResponse struct {
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
CreatedAt *int64 `json:"created_at,omitempty"`
|
||||
Data *FalAIImageData `json:"data,omitempty"`
|
||||
// Legacy flattened fields for backward compatibility
|
||||
Images []FalAIImage `json:"images,omitempty"`
|
||||
Timings *FalAITimings `json:"timings,omitempty"`
|
||||
Seed *int64 `json:"seed,omitempty"`
|
||||
HasNSFWConcepts []bool `json:"has_nsfw_concepts,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
}
|
||||
|
||||
// FalAIImageData wraps the image data in the API envelope
|
||||
type FalAIImageData struct {
|
||||
Images []FalAIImage `json:"images,omitempty"`
|
||||
Timings *FalAITimings `json:"timings,omitempty"`
|
||||
Seed *int64 `json:"seed,omitempty"`
|
||||
HasNSFWConcepts []bool `json:"has_nsfw_concepts,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
}
|
||||
|
||||
type FalAIImage struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
Width int `json:"width,omitempty"`
|
||||
Height int `json:"height,omitempty"`
|
||||
ContentType string `json:"content_type,omitempty"`
|
||||
}
|
||||
|
||||
type FalAITimings struct {
|
||||
Inference float64 `json:"inference"`
|
||||
}
|
||||
|
||||
// HuggingFaceTogetherImageGenerationRequest for together image generation
|
||||
type HuggingFaceTogetherImageGenerationRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model"`
|
||||
ResponseFormat *string `json:"response_format,omitempty"`
|
||||
Size *string `json:"size,omitempty"`
|
||||
Width *int `json:"width,omitempty"`
|
||||
Height *int `json:"height,omitempty"`
|
||||
N *int `json:"n,omitempty"`
|
||||
Steps *int `json:"steps,omitempty"`
|
||||
ExtraParams map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
func (req *HuggingFaceTogetherImageGenerationRequest) GetExtraParams() map[string]any {
|
||||
return req.ExtraParams
|
||||
}
|
||||
|
||||
// HuggingFaceTogetherImageGenerationResponse for together image generation
|
||||
type HuggingFaceTogetherImageGenerationResponse struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Object string `json:"object"`
|
||||
Data []HuggingFaceTogetherImageData `json:"data"`
|
||||
}
|
||||
|
||||
type HuggingFaceTogetherImageData struct {
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Index int `json:"index"`
|
||||
Timings *HuggingFaceTogetherTimings `json:"timings,omitempty"`
|
||||
}
|
||||
|
||||
type HuggingFaceTogetherTimings struct {
|
||||
Inference float64 `json:"inference"`
|
||||
}
|
||||
|
||||
// HuggingFaceFalAIImageStreamRequest for fal-ai image generation streaming
|
||||
type HuggingFaceFalAIImageStreamRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
ResponseFormat *string `json:"response_format,omitempty"`
|
||||
NumImages *int `json:"num_images,omitempty"`
|
||||
ImageSize *HuggingFaceFalAISize `json:"image_size,omitempty"`
|
||||
GuidanceScale *float64 `json:"guidance_scale,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
NumInferenceSteps *int `json:"num_inference_steps,omitempty"`
|
||||
Acceleration *string `json:"acceleration,omitempty"`
|
||||
EnablePromptExpansion *bool `json:"enable_prompt_expansion,omitempty"`
|
||||
SyncMode *bool `json:"sync_mode,omitempty"`
|
||||
EnableSafetyChecker *bool `json:"enable_safety_checker,omitempty"`
|
||||
OutputFormat *string `json:"output_format,omitempty"`
|
||||
ExtraParams map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
func (req *HuggingFaceFalAIImageStreamRequest) GetExtraParams() map[string]any {
|
||||
return req.ExtraParams
|
||||
}
|
||||
|
||||
// HuggingFaceFalAIImageStreamResponse for fal-ai SSE events
|
||||
type HuggingFaceFalAIImageStreamResponse struct {
|
||||
Data *FalAIImageData `json:"data,omitempty"`
|
||||
Images []FalAIImage `json:"images,omitempty"`
|
||||
}
|
||||
|
||||
// HuggingFaceFalAIImageEditRequest for fal-ai image edit
|
||||
type HuggingFaceFalAIImageEditRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
ImageURL *string `json:"image_url,omitempty"` // For single image models
|
||||
ImageURLs []string `json:"image_urls,omitempty"` // For multi-image models
|
||||
NumImages *int `json:"num_images,omitempty"`
|
||||
ImageSize *HuggingFaceFalAISize `json:"image_size,omitempty"`
|
||||
GuidanceScale *float64 `json:"guidance_scale,omitempty"`
|
||||
NumInferenceSteps *int `json:"num_inference_steps,omitempty"`
|
||||
SyncMode *bool `json:"sync_mode,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
OutputFormat *string `json:"output_format,omitempty"`
|
||||
EnableSafetyChecker *bool `json:"enable_safety_checker,omitempty"`
|
||||
Acceleration *string `json:"acceleration,omitempty"`
|
||||
ExtraParams map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
func (req *HuggingFaceFalAIImageEditRequest) GetExtraParams() map[string]any {
|
||||
return req.ExtraParams
|
||||
}
|
||||
354
core/providers/huggingface/utils.go
Normal file
354
core/providers/huggingface/utils.go
Normal file
@@ -0,0 +1,354 @@
|
||||
package huggingface
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
const (
|
||||
// According to https://huggingface.co/docs/inference-providers/en/tasks/chat-completion the
|
||||
// OpenAI-compatible router lives under the /v1 prefix, so we wire that in as the default base URL.
|
||||
defaultInferenceBaseURL = "https://router.huggingface.co"
|
||||
modelHubBaseURL = "https://huggingface.co"
|
||||
|
||||
//For custom deployments, HF offers inference endpoints under
|
||||
// inferenceBaseEndpointsEndpointBaseURL = "https://api.endpoints.huggingface.cloud/v2"
|
||||
)
|
||||
|
||||
type inferenceProvider string
|
||||
|
||||
const (
|
||||
cerebras inferenceProvider = "cerebras"
|
||||
cohere inferenceProvider = "cohere"
|
||||
falAI inferenceProvider = "fal-ai"
|
||||
featherlessAI inferenceProvider = "featherless-ai"
|
||||
fireworksAI inferenceProvider = "fireworks-ai"
|
||||
groq inferenceProvider = "groq"
|
||||
hfInference inferenceProvider = "hf-inference"
|
||||
hyperbolic inferenceProvider = "hyperbolic"
|
||||
nebius inferenceProvider = "nebius"
|
||||
novita inferenceProvider = "novita"
|
||||
nscale inferenceProvider = "nscale"
|
||||
ovhcloud inferenceProvider = "ovhcloud"
|
||||
publicai inferenceProvider = "publicai"
|
||||
replicate inferenceProvider = "replicate"
|
||||
sambanova inferenceProvider = "sambanova"
|
||||
scaleway inferenceProvider = "scaleway"
|
||||
together inferenceProvider = "together"
|
||||
wavespeed inferenceProvider = "wavespeed"
|
||||
zaiOrg inferenceProvider = "zai-org"
|
||||
auto inferenceProvider = "auto"
|
||||
)
|
||||
|
||||
// List of supported inference providers (kept in sync with HF docs/JS SDK)
|
||||
var INFERENCE_PROVIDERS = []inferenceProvider{
|
||||
cerebras,
|
||||
cohere,
|
||||
falAI,
|
||||
featherlessAI,
|
||||
fireworksAI,
|
||||
groq,
|
||||
hfInference,
|
||||
hyperbolic,
|
||||
nebius,
|
||||
novita,
|
||||
nscale,
|
||||
ovhcloud,
|
||||
publicai,
|
||||
replicate,
|
||||
sambanova,
|
||||
scaleway,
|
||||
together,
|
||||
wavespeed,
|
||||
zaiOrg,
|
||||
}
|
||||
|
||||
// PROVIDERS_OR_POLICIES is the above list plus the special "auto" policy
|
||||
var PROVIDERS_OR_POLICIES = func() []inferenceProvider {
|
||||
out := make([]inferenceProvider, 0, len(INFERENCE_PROVIDERS)+1)
|
||||
out = append(out, INFERENCE_PROVIDERS...)
|
||||
out = append(out, "auto")
|
||||
return out
|
||||
}()
|
||||
|
||||
func (provider *HuggingFaceProvider) buildModelHubURL(request *schemas.BifrostListModelsRequest, inferenceProvider inferenceProvider) string {
|
||||
values := url.Values{}
|
||||
|
||||
// Add inference_provider parameter to filter models served by Hugging Face's inference provider
|
||||
// According to https://huggingface.co/docs/inference-providers/hub-api
|
||||
limit := request.PageSize
|
||||
if limit <= 0 {
|
||||
limit = defaultModelFetchLimit
|
||||
}
|
||||
if limit > maxModelFetchLimit {
|
||||
limit = maxModelFetchLimit
|
||||
}
|
||||
values.Set("limit", strconv.Itoa(limit))
|
||||
values.Set("full", "1")
|
||||
values.Set("sort", "likes")
|
||||
values.Set("direction", "-1")
|
||||
values.Set("inference_provider", string(inferenceProvider))
|
||||
|
||||
for key, value := range request.ExtraParams {
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
if typed != "" {
|
||||
values.Set(key, typed)
|
||||
}
|
||||
case fmt.Stringer:
|
||||
values.Set(key, typed.String())
|
||||
case int:
|
||||
values.Set(key, strconv.Itoa(typed))
|
||||
case float64:
|
||||
values.Set(key, strconv.FormatFloat(typed, 'f', -1, 64))
|
||||
case bool:
|
||||
values.Set(key, strconv.FormatBool(typed))
|
||||
default:
|
||||
values.Set(key, fmt.Sprintf("%v", typed))
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/api/models?%s", modelHubBaseURL, values.Encode())
|
||||
}
|
||||
|
||||
func (provider *HuggingFaceProvider) buildModelInferenceProviderURL(modelName string) string {
|
||||
values := url.Values{}
|
||||
values.Set("expand[]", "pipeline_tag")
|
||||
values.Set("expand[]", "inferenceProviderMapping")
|
||||
return fmt.Sprintf("%s/api/models/%s?%s", modelHubBaseURL, modelName, values.Encode())
|
||||
}
|
||||
|
||||
func splitIntoModelProvider(bifrostModelName string) (inferenceProvider, string, error) {
|
||||
// Extract provider and model name
|
||||
t := strings.Count(bifrostModelName, "/")
|
||||
if t == 0 {
|
||||
return "", "", fmt.Errorf("invalid model name format: %s", bifrostModelName)
|
||||
}
|
||||
var prov inferenceProvider
|
||||
var model string
|
||||
if t > 1 {
|
||||
before, after, _ := strings.Cut(bifrostModelName, "/")
|
||||
prov = inferenceProvider(before)
|
||||
model = after
|
||||
} else if t == 1 {
|
||||
prov = ""
|
||||
model = bifrostModelName
|
||||
}
|
||||
return prov, model, nil
|
||||
}
|
||||
|
||||
// Defined for tasks given by https://huggingface.co/docs/inference-providers/en/index and makeURL logic at https://github.com/huggingface/huggingface.js/blob/c02dd89eff24593b304d72715247f7eef79b3b73/packages/inference/src/providers/providerHelper.ts#L111
|
||||
func (provider *HuggingFaceProvider) getInferenceProviderRouteURL(ctx *schemas.BifrostContext, inferenceProvider inferenceProvider, modelName string, requestType schemas.RequestType) (string, error) {
|
||||
defaultPath := ""
|
||||
switch inferenceProvider {
|
||||
case falAI:
|
||||
defaultPath = fmt.Sprintf("/fal-ai/%s", modelName)
|
||||
case hfInference:
|
||||
var pipeline string
|
||||
switch requestType {
|
||||
case schemas.EmbeddingRequest:
|
||||
pipeline = "feature-extraction"
|
||||
case schemas.SpeechRequest:
|
||||
pipeline = "text-to-speech"
|
||||
case schemas.ImageGenerationRequest:
|
||||
return provider.buildRequestURL(ctx, fmt.Sprintf("/hf-inference/models/%s", modelName), requestType), nil
|
||||
case schemas.TranscriptionRequest:
|
||||
return provider.buildRequestURL(ctx, fmt.Sprintf("/hf-inference/models/%s", modelName), requestType), nil
|
||||
default:
|
||||
pipeline = "chat-completion"
|
||||
}
|
||||
defaultPath = fmt.Sprintf("/hf-inference/models/%s/pipeline/%s", modelName, pipeline)
|
||||
case nebius:
|
||||
if requestType == schemas.EmbeddingRequest {
|
||||
defaultPath = "/nebius/v1/embeddings"
|
||||
} else if requestType == schemas.ImageGenerationRequest {
|
||||
defaultPath = "/nebius/v1/images/generations"
|
||||
} else {
|
||||
return "", fmt.Errorf("nebius provider only supports embedding and image generation requests")
|
||||
}
|
||||
case replicate:
|
||||
defaultPath = "/replicate/v1/prediction"
|
||||
case together:
|
||||
if requestType == schemas.ImageGenerationRequest {
|
||||
defaultPath = "/together/v1/images/generations"
|
||||
} else {
|
||||
return "", fmt.Errorf("together provider only supports image generation requests")
|
||||
}
|
||||
case sambanova:
|
||||
if requestType == schemas.EmbeddingRequest {
|
||||
defaultPath = "/sambanova/v1/embeddings"
|
||||
} else {
|
||||
return "", fmt.Errorf("sambanova provider only supports embedding requests")
|
||||
}
|
||||
case scaleway:
|
||||
if requestType == schemas.EmbeddingRequest {
|
||||
defaultPath = "/scaleway/v1/embeddings"
|
||||
} else {
|
||||
return "", fmt.Errorf("scaleway provider only supports embedding requests")
|
||||
}
|
||||
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported inference provider: %s for action: %s", inferenceProvider, requestType)
|
||||
}
|
||||
return provider.buildRequestURL(ctx, defaultPath, requestType), nil
|
||||
}
|
||||
|
||||
// convertToInferenceProviderMappings converts HuggingFaceInferenceProviderMappingResponse to a map of HuggingFaceInferenceProviderMapping with ProviderName as key
|
||||
func convertToInferenceProviderMappings(resp *HuggingFaceInferenceProviderMappingResponse) map[inferenceProvider]HuggingFaceInferenceProviderMapping {
|
||||
if resp == nil || resp.InferenceProviderMapping == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mappings := make(map[inferenceProvider]HuggingFaceInferenceProviderMapping, len(resp.InferenceProviderMapping))
|
||||
for providerKey, providerInfo := range resp.InferenceProviderMapping {
|
||||
providerName := inferenceProvider(providerKey)
|
||||
mappings[providerName] = HuggingFaceInferenceProviderMapping{
|
||||
ProviderTask: providerInfo.Task,
|
||||
ProviderModelID: providerInfo.ProviderModelID,
|
||||
}
|
||||
}
|
||||
|
||||
return mappings
|
||||
}
|
||||
|
||||
func (provider *HuggingFaceProvider) getModelInferenceProviderMapping(ctx context.Context, huggingfaceModelName string) (map[inferenceProvider]HuggingFaceInferenceProviderMapping, *schemas.BifrostError) {
|
||||
// Check cache first
|
||||
if cached, ok := provider.modelProviderMappingCache.Load(huggingfaceModelName); ok {
|
||||
if mappings, ok := cached.(map[inferenceProvider]HuggingFaceInferenceProviderMapping); ok {
|
||||
|
||||
return mappings, nil
|
||||
}
|
||||
}
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(provider.buildModelInferenceProviderURL(huggingfaceModelName))
|
||||
req.Header.SetMethod(http.MethodGet)
|
||||
req.Header.SetContentType("application/json")
|
||||
_, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
||||
defer wait()
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
if resp.StatusCode() != fasthttp.StatusOK {
|
||||
var errorResp HuggingFaceHubError
|
||||
bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)
|
||||
if bifrostErr.Error == nil {
|
||||
bifrostErr.Error = &schemas.ErrorField{}
|
||||
}
|
||||
if strings.TrimSpace(errorResp.Message) != "" {
|
||||
bifrostErr.Error.Message = errorResp.Message
|
||||
}
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
body, err := providerUtils.CheckAndDecodeBody(resp)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
||||
}
|
||||
|
||||
var mappingResp HuggingFaceInferenceProviderMappingResponse
|
||||
if err := sonic.Unmarshal(body, &mappingResp); err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
||||
}
|
||||
|
||||
mappings := convertToInferenceProviderMappings(&mappingResp)
|
||||
|
||||
// Store in cache
|
||||
if mappings != nil {
|
||||
provider.modelProviderMappingCache.Store(huggingfaceModelName, mappings)
|
||||
}
|
||||
|
||||
return mappings, nil
|
||||
}
|
||||
|
||||
// getValidatedProviderModelID fetches the inference provider mapping for a model
|
||||
// and validates that the given inferenceProvider has a mapping with the expected task.
|
||||
// On success it returns the provider-specific model id. On failure it returns a
|
||||
// BifrostError indicating the operation isn't supported for the requested
|
||||
// request type or provider.
|
||||
func (provider *HuggingFaceProvider) getValidatedProviderModelID(ctx context.Context, inferenceProvider inferenceProvider, huggingfaceModelName string, requiredTask string, requestType schemas.RequestType) (string, *schemas.BifrostError) {
|
||||
providerName := provider.GetProviderKey()
|
||||
|
||||
providerMapping, bifrostErr := provider.getModelInferenceProviderMapping(ctx, huggingfaceModelName)
|
||||
if bifrostErr != nil {
|
||||
return "", bifrostErr
|
||||
}
|
||||
|
||||
if providerMapping == nil {
|
||||
return "", providerUtils.NewUnsupportedOperationError(requestType, providerName)
|
||||
}
|
||||
|
||||
mapping, ok := providerMapping[inferenceProvider]
|
||||
if !ok || mapping.ProviderModelID == "" || mapping.ProviderTask != requiredTask {
|
||||
return "", providerUtils.NewUnsupportedOperationError(requestType, providerName)
|
||||
}
|
||||
|
||||
return mapping.ProviderModelID, nil
|
||||
}
|
||||
|
||||
// downloadAudioFromURL downloads audio data from a URL
|
||||
func (provider *HuggingFaceProvider) downloadAudioFromURL(ctx context.Context, audioURL string) ([]byte, error) {
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(audioURL)
|
||||
req.Header.SetMethod(http.MethodGet)
|
||||
|
||||
_, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
||||
defer wait()
|
||||
if bifrostErr != nil {
|
||||
return nil, fmt.Errorf("failed to download audio: %v", bifrostErr)
|
||||
}
|
||||
|
||||
if resp.StatusCode() != fasthttp.StatusOK {
|
||||
return nil, fmt.Errorf("failed to download audio: status=%d", resp.StatusCode())
|
||||
}
|
||||
|
||||
body, err := providerUtils.CheckAndDecodeBody(resp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read audio data: %w", err)
|
||||
}
|
||||
|
||||
// Copy the body to avoid use-after-free
|
||||
audioCopy := append([]byte(nil), body...)
|
||||
|
||||
return audioCopy, nil
|
||||
}
|
||||
|
||||
func getMimeTypeForAudioType(audioType string) string {
|
||||
if audioType == "" {
|
||||
return "audio/mpeg"
|
||||
}
|
||||
|
||||
// Lowercase for comparison and trim parameters if present (e.g.);
|
||||
t := strings.ToLower(strings.TrimSpace(audioType))
|
||||
|
||||
// If it already starts with "audio/", normalise some known variants
|
||||
if strings.HasPrefix(t, "audio/") {
|
||||
switch t {
|
||||
case "audio/mp3":
|
||||
return "audio/mpeg"
|
||||
default:
|
||||
return t
|
||||
}
|
||||
}
|
||||
|
||||
return "audio/mpeg"
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user