Files
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

141 lines
4.8 KiB
Go

package huggingface
import (
"fmt"
"slices"
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// Model-listing fetch limits. Neither constant is referenced in the visible
// part of this file.
// NOTE(review): confirm the exact semantics against the code that builds the
// list-models request.
const (
	// defaultModelFetchLimit is presumably the number of models requested when
	// the caller does not specify a limit — TODO confirm.
	defaultModelFetchLimit = 200
	// maxModelFetchLimit is presumably the upper bound applied to a
	// caller-supplied limit — TODO confirm.
	maxModelFetchLimit = 1000
)
// ToBifrostListModelsResponse converts a Hugging Face model listing into a
// Bifrost list-models response.
//
// A nil receiver yields nil. If the filtering pipeline built from the given
// allow list, block list, aliases and unfiltered flag reports an early exit, an
// empty (non-nil) response is returned.
//
// Models with an empty ModelID, or for which deriveSupportedMethods finds no
// supported request type, are skipped. Every result produced by
// pipeline.FilterModel for a remaining model becomes one entry whose ID is
// "{providerKey}/{inferenceProvider}/{resolvedID}". Entries returned by
// pipeline.BackfillModels are appended afterwards with their IDs rewritten into
// the same three-segment form.
func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, inferenceProvider inferenceProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
	if response == nil {
		return nil
	}

	pipeline := &providerUtils.ListModelsPipeline{
		AllowedModels:     allowedModels,
		BlacklistedModels: blacklistedModels,
		Aliases:           aliases,
		Unfiltered:        unfiltered,
		ProviderKey:       providerKey,
		MatchFns:          providerUtils.DefaultMatchFns(),
	}
	out := &schemas.BifrostListModelsResponse{
		Data: make([]schemas.Model, 0, len(response.Models)),
	}
	if pipeline.ShouldEarlyExit() {
		return out
	}

	// Every emitted ID shares the "{providerKey}/{inferenceProvider}/" prefix;
	// only the trailing model segment varies (and is what aliases rename).
	idPrefix := fmt.Sprintf("%s/%s/", providerKey, inferenceProvider)

	// Lower-cased resolved IDs already emitted, handed to the backfill step.
	emitted := make(map[string]bool)

	for _, model := range response.Models {
		if model.ModelID == "" {
			continue
		}
		methods := deriveSupportedMethods(model.PipelineTag, model.Tags)
		if len(methods) == 0 {
			continue
		}

		// Filtering and alias resolution operate on the bare model ID, not on
		// the compound ID.
		for _, match := range pipeline.FilterModel(model.ModelID) {
			entry := schemas.Model{
				ID:               idPrefix + match.ResolvedID,
				Name:             schemas.Ptr(model.ModelID),
				SupportedMethods: methods,
				HuggingFaceID:    schemas.Ptr(model.ID),
			}
			if match.AliasValue != "" {
				entry.Alias = schemas.Ptr(match.AliasValue)
			}
			out.Data = append(out.Data, entry)
			emitted[strings.ToLower(match.ResolvedID)] = true
		}
	}

	// Backfilled entries are re-wrapped so they carry the inferenceProvider
	// segment like every other entry: a leading "{providerKey}/" is stripped
	// (when present) before the full prefix is applied.
	providerPrefix := string(providerKey) + "/"
	for _, extra := range pipeline.BackfillModels(emitted) {
		extra.ID = idPrefix + strings.TrimPrefix(extra.ID, providerPrefix)
		out.Data = append(out.Data, extra)
	}

	return out
}
// deriveSupportedMethods maps a Hugging Face pipeline tag and model tags to the
// request types a model supports.
//
// The pipeline tag is matched exactly after trimming whitespace and
// lower-casing. Each tag is lower-cased (not trimmed) and assigned to the first
// category it matches, tested in this order: embeddings, chat, text-to-speech,
// transcription, image generation. A single tag therefore contributes to at
// most one category.
//
// The result is the de-duplicated, sorted list of request type names, or nil
// when neither the pipeline tag nor any tag matched.
func deriveSupportedMethods(pipeline string, tags []string) []string {
	chatMethods := []schemas.RequestType{
		schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest,
		schemas.ResponsesRequest, schemas.ResponsesStreamRequest,
	}
	imageMethods := []schemas.RequestType{
		schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest,
	}

	seen := make(map[schemas.RequestType]struct{})
	mark := func(requestTypes ...schemas.RequestType) {
		for _, rt := range requestTypes {
			seen[rt] = struct{}{}
		}
	}

	// isOneOf reports whether s is exactly one of names.
	isOneOf := func(s string, names ...string) bool {
		return slices.Contains(names, s)
	}
	// containsAny reports whether s contains at least one of fragments.
	containsAny := func(s string, fragments ...string) bool {
		for _, fragment := range fragments {
			if strings.Contains(s, fragment) {
				return true
			}
		}
		return false
	}

	switch strings.TrimSpace(strings.ToLower(pipeline)) {
	case "conversational", "chat-completion":
		mark(chatMethods...)
	case "feature-extraction":
		mark(schemas.EmbeddingRequest)
	case "text-to-speech":
		mark(schemas.SpeechRequest)
	case "automatic-speech-recognition":
		mark(schemas.TranscriptionRequest)
	case "text-to-image":
		mark(imageMethods...)
	}

	for _, rawTag := range tags {
		tag := strings.ToLower(rawTag)
		// Case order matters: only the first matching category is applied.
		switch {
		case isOneOf(tag, "text-embedding", "sentence-similarity", "feature-extraction",
			"embeddings", "sentence-transformers") || containsAny(tag, "embedding"):
			mark(schemas.EmbeddingRequest)
		case isOneOf(tag, "text-generation", "summarization", "conversational",
			"chat-completion", "text2text-generation", "question-answering") ||
			containsAny(tag, "chat", "completion"):
			mark(chatMethods...)
		case isOneOf(tag, "text-to-speech", "tts") || containsAny(tag, "text-to-speech"):
			mark(schemas.SpeechRequest)
		case isOneOf(tag, "automatic-speech-recognition", "speech-to-text") ||
			containsAny(tag, "speech-recognition"):
			mark(schemas.TranscriptionRequest)
		case isOneOf(tag, "text-to-image") || containsAny(tag, "image-generation"):
			mark(imageMethods...)
		}
	}

	if len(seen) == 0 {
		return nil
	}

	// Map iteration order is random; sort so the output is deterministic.
	names := make([]string, 0, len(seen))
	for rt := range seen {
		names = append(names, string(rt))
	}
	slices.Sort(names)
	return names
}