355 lines
12 KiB
Go
355 lines
12 KiB
Go
package huggingface
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/bytedance/sonic"
|
|
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
|
schemas "github.com/maximhq/bifrost/core/schemas"
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
const (
	// modelHubBaseURL is the Hugging Face Hub host; the model metadata endpoints
	// used in this file live under <modelHubBaseURL>/api/models.
	modelHubBaseURL = "https://huggingface.co"

	// defaultInferenceBaseURL is the Hugging Face inference router host used as the
	// default base URL. According to
	// https://huggingface.co/docs/inference-providers/en/tasks/chat-completion the
	// OpenAI-compatible router is served under a /v1 prefix; that prefix is NOT part
	// of this constant.
	defaultInferenceBaseURL = "https://router.huggingface.co"

	// For custom deployments, HF offers inference endpoints under:
	// inferenceBaseEndpointsEndpointBaseURL = "https://api.endpoints.huggingface.cloud/v2"
)
|
|
|
|
// inferenceProvider is the identifier of a Hugging Face inference provider
// (or of the special "auto" policy) as it appears in Hub and router URLs.
type inferenceProvider string

// Known inference provider identifiers. The string values are the exact names
// sent over the wire.
const (
	cerebras      inferenceProvider = "cerebras"
	cohere        inferenceProvider = "cohere"
	falAI         inferenceProvider = "fal-ai"
	featherlessAI inferenceProvider = "featherless-ai"
	fireworksAI   inferenceProvider = "fireworks-ai"
	groq          inferenceProvider = "groq"
	hfInference   inferenceProvider = "hf-inference"
	hyperbolic    inferenceProvider = "hyperbolic"
	nebius        inferenceProvider = "nebius"
	novita        inferenceProvider = "novita"
	nscale        inferenceProvider = "nscale"
	ovhcloud      inferenceProvider = "ovhcloud"
	publicai      inferenceProvider = "publicai"
	replicate     inferenceProvider = "replicate"
	sambanova     inferenceProvider = "sambanova"
	scaleway      inferenceProvider = "scaleway"
	together      inferenceProvider = "together"
	wavespeed     inferenceProvider = "wavespeed"
	zaiOrg        inferenceProvider = "zai-org"

	// auto is not a concrete provider; it is the special "auto" policy value.
	auto inferenceProvider = "auto"
)
|
|
|
|
// List of supported inference providers (kept in sync with HF docs/JS SDK)
|
|
var INFERENCE_PROVIDERS = []inferenceProvider{
|
|
cerebras,
|
|
cohere,
|
|
falAI,
|
|
featherlessAI,
|
|
fireworksAI,
|
|
groq,
|
|
hfInference,
|
|
hyperbolic,
|
|
nebius,
|
|
novita,
|
|
nscale,
|
|
ovhcloud,
|
|
publicai,
|
|
replicate,
|
|
sambanova,
|
|
scaleway,
|
|
together,
|
|
wavespeed,
|
|
zaiOrg,
|
|
}
|
|
|
|
// PROVIDERS_OR_POLICIES is the above list plus the special "auto" policy
|
|
var PROVIDERS_OR_POLICIES = func() []inferenceProvider {
|
|
out := make([]inferenceProvider, 0, len(INFERENCE_PROVIDERS)+1)
|
|
out = append(out, INFERENCE_PROVIDERS...)
|
|
out = append(out, "auto")
|
|
return out
|
|
}()
|
|
|
|
func (provider *HuggingFaceProvider) buildModelHubURL(request *schemas.BifrostListModelsRequest, inferenceProvider inferenceProvider) string {
|
|
values := url.Values{}
|
|
|
|
// Add inference_provider parameter to filter models served by Hugging Face's inference provider
|
|
// According to https://huggingface.co/docs/inference-providers/hub-api
|
|
limit := request.PageSize
|
|
if limit <= 0 {
|
|
limit = defaultModelFetchLimit
|
|
}
|
|
if limit > maxModelFetchLimit {
|
|
limit = maxModelFetchLimit
|
|
}
|
|
values.Set("limit", strconv.Itoa(limit))
|
|
values.Set("full", "1")
|
|
values.Set("sort", "likes")
|
|
values.Set("direction", "-1")
|
|
values.Set("inference_provider", string(inferenceProvider))
|
|
|
|
for key, value := range request.ExtraParams {
|
|
switch typed := value.(type) {
|
|
case string:
|
|
if typed != "" {
|
|
values.Set(key, typed)
|
|
}
|
|
case fmt.Stringer:
|
|
values.Set(key, typed.String())
|
|
case int:
|
|
values.Set(key, strconv.Itoa(typed))
|
|
case float64:
|
|
values.Set(key, strconv.FormatFloat(typed, 'f', -1, 64))
|
|
case bool:
|
|
values.Set(key, strconv.FormatBool(typed))
|
|
default:
|
|
values.Set(key, fmt.Sprintf("%v", typed))
|
|
}
|
|
}
|
|
|
|
return fmt.Sprintf("%s/api/models?%s", modelHubBaseURL, values.Encode())
|
|
}
|
|
|
|
func (provider *HuggingFaceProvider) buildModelInferenceProviderURL(modelName string) string {
|
|
values := url.Values{}
|
|
values.Set("expand[]", "pipeline_tag")
|
|
values.Set("expand[]", "inferenceProviderMapping")
|
|
return fmt.Sprintf("%s/api/models/%s?%s", modelHubBaseURL, modelName, values.Encode())
|
|
}
|
|
|
|
func splitIntoModelProvider(bifrostModelName string) (inferenceProvider, string, error) {
|
|
// Extract provider and model name
|
|
t := strings.Count(bifrostModelName, "/")
|
|
if t == 0 {
|
|
return "", "", fmt.Errorf("invalid model name format: %s", bifrostModelName)
|
|
}
|
|
var prov inferenceProvider
|
|
var model string
|
|
if t > 1 {
|
|
before, after, _ := strings.Cut(bifrostModelName, "/")
|
|
prov = inferenceProvider(before)
|
|
model = after
|
|
} else if t == 1 {
|
|
prov = ""
|
|
model = bifrostModelName
|
|
}
|
|
return prov, model, nil
|
|
}
|
|
|
|
// Defined for tasks given by https://huggingface.co/docs/inference-providers/en/index and makeURL logic at https://github.com/huggingface/huggingface.js/blob/c02dd89eff24593b304d72715247f7eef79b3b73/packages/inference/src/providers/providerHelper.ts#L111
|
|
func (provider *HuggingFaceProvider) getInferenceProviderRouteURL(ctx *schemas.BifrostContext, inferenceProvider inferenceProvider, modelName string, requestType schemas.RequestType) (string, error) {
|
|
defaultPath := ""
|
|
switch inferenceProvider {
|
|
case falAI:
|
|
defaultPath = fmt.Sprintf("/fal-ai/%s", modelName)
|
|
case hfInference:
|
|
var pipeline string
|
|
switch requestType {
|
|
case schemas.EmbeddingRequest:
|
|
pipeline = "feature-extraction"
|
|
case schemas.SpeechRequest:
|
|
pipeline = "text-to-speech"
|
|
case schemas.ImageGenerationRequest:
|
|
return provider.buildRequestURL(ctx, fmt.Sprintf("/hf-inference/models/%s", modelName), requestType), nil
|
|
case schemas.TranscriptionRequest:
|
|
return provider.buildRequestURL(ctx, fmt.Sprintf("/hf-inference/models/%s", modelName), requestType), nil
|
|
default:
|
|
pipeline = "chat-completion"
|
|
}
|
|
defaultPath = fmt.Sprintf("/hf-inference/models/%s/pipeline/%s", modelName, pipeline)
|
|
case nebius:
|
|
if requestType == schemas.EmbeddingRequest {
|
|
defaultPath = "/nebius/v1/embeddings"
|
|
} else if requestType == schemas.ImageGenerationRequest {
|
|
defaultPath = "/nebius/v1/images/generations"
|
|
} else {
|
|
return "", fmt.Errorf("nebius provider only supports embedding and image generation requests")
|
|
}
|
|
case replicate:
|
|
defaultPath = "/replicate/v1/prediction"
|
|
case together:
|
|
if requestType == schemas.ImageGenerationRequest {
|
|
defaultPath = "/together/v1/images/generations"
|
|
} else {
|
|
return "", fmt.Errorf("together provider only supports image generation requests")
|
|
}
|
|
case sambanova:
|
|
if requestType == schemas.EmbeddingRequest {
|
|
defaultPath = "/sambanova/v1/embeddings"
|
|
} else {
|
|
return "", fmt.Errorf("sambanova provider only supports embedding requests")
|
|
}
|
|
case scaleway:
|
|
if requestType == schemas.EmbeddingRequest {
|
|
defaultPath = "/scaleway/v1/embeddings"
|
|
} else {
|
|
return "", fmt.Errorf("scaleway provider only supports embedding requests")
|
|
}
|
|
|
|
default:
|
|
return "", fmt.Errorf("unsupported inference provider: %s for action: %s", inferenceProvider, requestType)
|
|
}
|
|
return provider.buildRequestURL(ctx, defaultPath, requestType), nil
|
|
}
|
|
|
|
// convertToInferenceProviderMappings converts HuggingFaceInferenceProviderMappingResponse to a map of HuggingFaceInferenceProviderMapping with ProviderName as key
|
|
func convertToInferenceProviderMappings(resp *HuggingFaceInferenceProviderMappingResponse) map[inferenceProvider]HuggingFaceInferenceProviderMapping {
|
|
if resp == nil || resp.InferenceProviderMapping == nil {
|
|
return nil
|
|
}
|
|
|
|
mappings := make(map[inferenceProvider]HuggingFaceInferenceProviderMapping, len(resp.InferenceProviderMapping))
|
|
for providerKey, providerInfo := range resp.InferenceProviderMapping {
|
|
providerName := inferenceProvider(providerKey)
|
|
mappings[providerName] = HuggingFaceInferenceProviderMapping{
|
|
ProviderTask: providerInfo.Task,
|
|
ProviderModelID: providerInfo.ProviderModelID,
|
|
}
|
|
}
|
|
|
|
return mappings
|
|
}
|
|
|
|
func (provider *HuggingFaceProvider) getModelInferenceProviderMapping(ctx context.Context, huggingfaceModelName string) (map[inferenceProvider]HuggingFaceInferenceProviderMapping, *schemas.BifrostError) {
|
|
// Check cache first
|
|
if cached, ok := provider.modelProviderMappingCache.Load(huggingfaceModelName); ok {
|
|
if mappings, ok := cached.(map[inferenceProvider]HuggingFaceInferenceProviderMapping); ok {
|
|
|
|
return mappings, nil
|
|
}
|
|
}
|
|
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
req.SetRequestURI(provider.buildModelInferenceProviderURL(huggingfaceModelName))
|
|
req.Header.SetMethod(http.MethodGet)
|
|
req.Header.SetContentType("application/json")
|
|
_, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
var errorResp HuggingFaceHubError
|
|
bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)
|
|
if bifrostErr.Error == nil {
|
|
bifrostErr.Error = &schemas.ErrorField{}
|
|
}
|
|
if strings.TrimSpace(errorResp.Message) != "" {
|
|
bifrostErr.Error.Message = errorResp.Message
|
|
}
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
}
|
|
|
|
var mappingResp HuggingFaceInferenceProviderMappingResponse
|
|
if err := sonic.Unmarshal(body, &mappingResp); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
}
|
|
|
|
mappings := convertToInferenceProviderMappings(&mappingResp)
|
|
|
|
// Store in cache
|
|
if mappings != nil {
|
|
provider.modelProviderMappingCache.Store(huggingfaceModelName, mappings)
|
|
}
|
|
|
|
return mappings, nil
|
|
}
|
|
|
|
// getValidatedProviderModelID fetches the inference provider mapping for a model
|
|
// and validates that the given inferenceProvider has a mapping with the expected task.
|
|
// On success it returns the provider-specific model id. On failure it returns a
|
|
// BifrostError indicating the operation isn't supported for the requested
|
|
// request type or provider.
|
|
func (provider *HuggingFaceProvider) getValidatedProviderModelID(ctx context.Context, inferenceProvider inferenceProvider, huggingfaceModelName string, requiredTask string, requestType schemas.RequestType) (string, *schemas.BifrostError) {
|
|
providerName := provider.GetProviderKey()
|
|
|
|
providerMapping, bifrostErr := provider.getModelInferenceProviderMapping(ctx, huggingfaceModelName)
|
|
if bifrostErr != nil {
|
|
return "", bifrostErr
|
|
}
|
|
|
|
if providerMapping == nil {
|
|
return "", providerUtils.NewUnsupportedOperationError(requestType, providerName)
|
|
}
|
|
|
|
mapping, ok := providerMapping[inferenceProvider]
|
|
if !ok || mapping.ProviderModelID == "" || mapping.ProviderTask != requiredTask {
|
|
return "", providerUtils.NewUnsupportedOperationError(requestType, providerName)
|
|
}
|
|
|
|
return mapping.ProviderModelID, nil
|
|
}
|
|
|
|
// downloadAudioFromURL downloads audio data from a URL
|
|
func (provider *HuggingFaceProvider) downloadAudioFromURL(ctx context.Context, audioURL string) ([]byte, error) {
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
req.SetRequestURI(audioURL)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
|
|
_, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, fmt.Errorf("failed to download audio: %v", bifrostErr)
|
|
}
|
|
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
return nil, fmt.Errorf("failed to download audio: status=%d", resp.StatusCode())
|
|
}
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read audio data: %w", err)
|
|
}
|
|
|
|
// Copy the body to avoid use-after-free
|
|
audioCopy := append([]byte(nil), body...)
|
|
|
|
return audioCopy, nil
|
|
}
|
|
|
|
// getMimeTypeForAudioType maps a caller-supplied audio type to the MIME type to use.
//
// The input is lower-cased and stripped of surrounding whitespace. "audio/mp3" is
// normalised to "audio/mpeg"; any other value starting with "audio/" is returned
// in its normalised form (parameters such as ";codecs=..." are kept as-is).
// Everything else, including the empty string, falls back to "audio/mpeg".
func getMimeTypeForAudioType(audioType string) string {
	const fallback = "audio/mpeg"

	normalized := strings.ToLower(strings.TrimSpace(audioType))
	switch {
	case normalized == "audio/mp3":
		return fallback
	case strings.HasPrefix(normalized, "audio/"):
		return normalized
	default:
		return fallback
	}
}
|