first commit
This commit is contained in:
92
core/providers/cohere/models.go
Normal file
92
core/providers/cohere/models.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// CohereRerankRequest represents a Cohere rerank API request.
|
||||
type CohereRerankRequest struct {
|
||||
Model string `json:"model"`
|
||||
Query string `json:"query"`
|
||||
Documents []string `json:"documents"`
|
||||
TopN *int `json:"top_n,omitempty"`
|
||||
MaxTokensPerDoc *int `json:"max_tokens_per_doc,omitempty"`
|
||||
Priority *int `json:"priority,omitempty"`
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// GetExtraParams returns extra parameters for the rerank request.
|
||||
func (r *CohereRerankRequest) GetExtraParams() map[string]interface{} {
|
||||
return r.ExtraParams
|
||||
}
|
||||
|
||||
// CohereRerankResult represents a single result from Cohere rerank.
|
||||
type CohereRerankResult struct {
|
||||
Index int `json:"index"`
|
||||
RelevanceScore float64 `json:"relevance_score"`
|
||||
Document json.RawMessage `json:"document,omitempty"`
|
||||
}
|
||||
|
||||
// CohereRerankResponse represents a Cohere rerank API response.
|
||||
type CohereRerankResponse struct {
|
||||
ID string `json:"id"`
|
||||
Results []CohereRerankResult `json:"results"`
|
||||
Meta *CohereRerankMeta `json:"meta,omitempty"`
|
||||
}
|
||||
|
||||
// CohereRerankMeta represents metadata in Cohere rerank response.
|
||||
type CohereRerankMeta struct {
|
||||
APIVersion *CohereEmbeddingAPIVersion `json:"api_version,omitempty"`
|
||||
BilledUnits *CohereBilledUnits `json:"billed_units,omitempty"`
|
||||
Tokens *CohereTokenUsage `json:"tokens,omitempty"`
|
||||
}
|
||||
|
||||
func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostListModelsResponse{
|
||||
Data: make([]schemas.Model, 0, len(response.Models)),
|
||||
}
|
||||
|
||||
pipeline := &providerUtils.ListModelsPipeline{
|
||||
AllowedModels: allowedModels,
|
||||
BlacklistedModels: blacklistedModels,
|
||||
Aliases: aliases,
|
||||
Unfiltered: unfiltered,
|
||||
ProviderKey: providerKey,
|
||||
MatchFns: providerUtils.DefaultMatchFns(),
|
||||
}
|
||||
if pipeline.ShouldEarlyExit() {
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
included := make(map[string]bool)
|
||||
|
||||
for _, model := range response.Models {
|
||||
// Cohere uses model.Name as the model identifier
|
||||
for _, result := range pipeline.FilterModel(model.Name) {
|
||||
entry := schemas.Model{
|
||||
ID: string(providerKey) + "/" + result.ResolvedID,
|
||||
Name: schemas.Ptr(model.Name),
|
||||
ContextLength: schemas.Ptr(int(model.ContextLength)),
|
||||
SupportedMethods: model.Endpoints,
|
||||
}
|
||||
if result.AliasValue != "" {
|
||||
entry.Alias = schemas.Ptr(result.AliasValue)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, entry)
|
||||
included[strings.ToLower(result.ResolvedID)] = true
|
||||
}
|
||||
}
|
||||
|
||||
bifrostResponse.Data = append(bifrostResponse.Data,
|
||||
pipeline.BackfillModels(included)...)
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
Reference in New Issue
Block a user