Files
bifrost/core/schemas/models.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

270 lines
8.6 KiB
Go

package schemas
import (
"encoding/base64"
"encoding/json"
"fmt"
)
// DefaultPageSize is the default page size for listing models
const DefaultPageSize = 1000
// MaxPaginationRequests is the maximum number of pagination requests to make
const MaxPaginationRequests = 20
// Structure to collect results from goroutines
type ListModelsByKeyResult struct {
Response *BifrostListModelsResponse
Err *BifrostError
KeyID string
}
// KeyStatus represents the status of model listing for a specific key
type KeyStatus struct {
KeyID string `json:"key_id"` // Empty for keyless providers
Status KeyStatusType `json:"status"` // "success", "failed"
Provider ModelProvider `json:"provider"` // Always populated
Error *BifrostError `json:"error,omitempty"`
}
// MarshalJSON implements custom JSON marshaling for KeyStatus to prevent
// circular reference: KeyStatus.Error → BifrostError.ExtraFields.KeyStatuses → KeyStatus.
func (k KeyStatus) MarshalJSON() ([]byte, error) {
type Alias KeyStatus
alias := Alias(k)
if alias.Error != nil {
errCopy := *alias.Error
errCopy.ExtraFields.KeyStatuses = nil
alias.Error = &errCopy
}
return MarshalSorted(alias)
}
type BifrostListModelsRequest struct {
Provider ModelProvider `json:"provider"`
PageSize int `json:"page_size"`
// PageToken: Token received from previous request to retrieve next page
PageToken string `json:"page_token"`
// Unfiltered: If true, the response will include all models for the provider, regardless of the allowed models (internal bifrost use only, not sent to the provider)
Unfiltered bool `json:"-"`
// ExtraParams: Additional provider-specific query parameters
// This allows for flexibility to pass any custom parameters that specific providers might support
ExtraParams map[string]interface{} `json:"-"`
}
type BifrostListModelsResponse struct {
Data []Model `json:"data"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
NextPageToken string `json:"next_page_token,omitempty"` // Token to retrieve next page
// Key-level status tracking for multi-key providers
KeyStatuses []KeyStatus `json:"key_statuses,omitempty"`
// Anthropic specific fields
FirstID *string `json:"-"`
LastID *string `json:"-"`
HasMore *bool `json:"-"`
}
// ApplyPagination applies offset-based pagination to a BifrostListModelsResponse.
// Uses opaque tokens with LastID validation to ensure cursor integrity.
// Returns the paginated response with properly set NextPageToken.
func (response *BifrostListModelsResponse) ApplyPagination(pageSize int, pageToken string) *BifrostListModelsResponse {
if response == nil {
return nil
}
totalItems := len(response.Data)
if pageSize <= 0 {
return response
}
cursor := decodePaginationCursor(pageToken)
offset := cursor.Offset
// Validate cursor integrity if LastID is present
if cursor.LastID != "" && !validatePaginationCursor(cursor, response.Data) {
// Invalid cursor: reset to beginning
offset = 0
}
if offset >= totalItems {
// Return empty page, no next token
return &BifrostListModelsResponse{
Data: []Model{},
ExtraFields: response.ExtraFields,
NextPageToken: "",
KeyStatuses: response.KeyStatuses,
}
}
endIndex := offset + pageSize
if endIndex > totalItems {
endIndex = totalItems
}
paginatedData := response.Data[offset:endIndex]
paginatedResponse := &BifrostListModelsResponse{
Data: paginatedData,
ExtraFields: response.ExtraFields,
KeyStatuses: response.KeyStatuses,
}
if endIndex < totalItems {
// Get the last item ID for cursor validation
var lastID string
if len(paginatedData) > 0 {
lastID = paginatedData[len(paginatedData)-1].ID
}
nextToken, err := encodePaginationCursor(endIndex, lastID)
if err == nil {
paginatedResponse.NextPageToken = nextToken
}
} else {
paginatedResponse.NextPageToken = ""
}
return paginatedResponse
}
type Model struct {
ID string `json:"id"`
CanonicalSlug *string `json:"canonical_slug,omitempty"`
Name *string `json:"name,omitempty"`
Alias *string `json:"alias,omitempty"` // Provider API identifier this model alias maps to (e.g. Azure deployment name, Bedrock ARN)
Created *int64 `json:"created,omitempty"`
ContextLength *int `json:"context_length,omitempty"`
MaxInputTokens *int `json:"max_input_tokens,omitempty"`
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
Architecture *Architecture `json:"architecture,omitempty"`
Pricing *Pricing `json:"pricing,omitempty"`
TopProvider *TopProvider `json:"top_provider,omitempty"`
PerRequestLimits *PerRequestLimits `json:"per_request_limits,omitempty"`
SupportedParameters []string `json:"supported_parameters,omitempty"`
DefaultParameters *DefaultParameters `json:"default_parameters,omitempty"`
HuggingFaceID *string `json:"hugging_face_id,omitempty"`
Description *string `json:"description,omitempty"`
OwnedBy *string `json:"owned_by,omitempty"`
SupportedMethods []string `json:"supported_methods,omitempty"`
// ProviderExtra carries opaque provider-specific data (e.g. Anthropic capabilities)
// through the Bifrost pipeline for integration reverse-conversion. Never serialized.
ProviderExtra json.RawMessage `json:"-"`
}
type Architecture struct {
Modality *string `json:"modality,omitempty"`
Tokenizer *string `json:"tokenizer,omitempty"`
InstructType *string `json:"instruct_type,omitempty"`
InputModalities []string `json:"input_modalities,omitempty"`
OutputModalities []string `json:"output_modalities,omitempty"`
}
type Pricing struct {
Prompt *string `json:"prompt,omitempty"`
Completion *string `json:"completion,omitempty"`
Request *string `json:"request,omitempty"`
Image *string `json:"image,omitempty"`
WebSearch *string `json:"web_search,omitempty"`
InternalReasoning *string `json:"internal_reasoning,omitempty"`
InputCacheRead *string `json:"input_cache_read,omitempty"`
InputCacheWrite *string `json:"input_cache_write,omitempty"`
}
type TopProvider struct {
IsModerated *bool `json:"is_moderated,omitempty"`
ContextLength *int `json:"context_length,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
}
type PerRequestLimits struct {
PromptTokens *int `json:"prompt_tokens,omitempty"`
CompletionTokens *int `json:"completion_tokens,omitempty"`
}
type DefaultParameters struct {
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
}
// paginationCursor represents the internal cursor structure for pagination.
type paginationCursor struct {
Offset int `json:"o"`
LastID string `json:"l,omitempty"`
}
// encodePaginationCursor creates an opaque base64-encoded page token from cursor data.
// Returns empty string if offset is 0 or negative.
func encodePaginationCursor(offset int, lastID string) (string, error) {
if offset <= 0 {
return "", nil
}
cursor := paginationCursor{
Offset: offset,
LastID: lastID,
}
jsonData, err := MarshalSorted(cursor)
if err != nil {
return "", fmt.Errorf("failed to marshal pagination cursor: %w", err)
}
// Use URL-safe base64 encoding without padding for opaque token
encoded := base64.RawURLEncoding.EncodeToString(jsonData)
return encoded, nil
}
// decodePaginationCursor extracts cursor data from an opaque base64-encoded page token.
// Returns cursor with 0 offset for empty or invalid tokens.
func decodePaginationCursor(token string) paginationCursor {
if token == "" {
return paginationCursor{}
}
// Decode base64
decoded, err := base64.RawURLEncoding.DecodeString(token)
if err != nil {
return paginationCursor{}
}
var cursor paginationCursor
if err := Unmarshal(decoded, &cursor); err != nil {
return paginationCursor{}
}
if cursor.Offset < 0 {
return paginationCursor{}
}
return cursor
}
// validatePaginationCursor validates that the cursor matches the expected position in the data.
// Returns true if the cursor is valid, false otherwise.
func validatePaginationCursor(cursor paginationCursor, data []Model) bool {
if cursor.LastID == "" {
return true
}
if cursor.Offset <= 0 || cursor.Offset > len(data) {
return false
}
prevIndex := cursor.Offset - 1
if prevIndex >= 0 && prevIndex < len(data) {
return data[prevIndex].ID == cursor.LastID
}
return true
}