first commit
This commit is contained in:
315
core/providers/replicate/chat.go
Normal file
315
core/providers/replicate/chat.go
Normal file
@@ -0,0 +1,315 @@
|
||||
package replicate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// unsupportedSystemPromptModels lists the lowercase "owner/name" identifiers of
// models that do not accept the system_prompt input field. It is consulted by
// supportsSystemPrompt via a linear membership check.
var unsupportedSystemPromptModels = []string{
	// Meta
	"meta/meta-llama-3-8b",
	"meta/llama-2-70b",
	// OpenAI
	"openai/gpt-oss-20b",
	"openai/o1-mini",
	// xAI
	"xai/grok-4",
}
|
||||
|
||||
func ToReplicateChatRequest(bifrostReq *schemas.BifrostChatRequest) (*ReplicatePredictionRequest, error) {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil {
|
||||
return nil, fmt.Errorf("bifrost request is nil or input is nil")
|
||||
}
|
||||
|
||||
// Build the input from messages
|
||||
input := &ReplicatePredictionRequestInput{}
|
||||
|
||||
isGPT5Structured := strings.HasPrefix(bifrostReq.Model, string(schemas.OpenAI)) && strings.Contains(bifrostReq.Model, "gpt-5-structured")
|
||||
|
||||
// openai models support messages
|
||||
if len(bifrostReq.Input) > 0 && strings.HasPrefix(bifrostReq.Model, string(schemas.OpenAI)) {
|
||||
if isGPT5Structured {
|
||||
responsesMessages := []schemas.ResponsesMessage{}
|
||||
for _, msg := range bifrostReq.Input {
|
||||
responsesMessages = append(responsesMessages, msg.ToResponsesMessages()...)
|
||||
}
|
||||
if len(responsesMessages) > 0 {
|
||||
input.InputItemList = responsesMessages
|
||||
}
|
||||
} else {
|
||||
input.Messages = bifrostReq.Input
|
||||
}
|
||||
} else {
|
||||
// Extract system prompt and build conversation prompt
|
||||
var systemPrompt string
|
||||
var conversationParts []string
|
||||
var imageInput []string
|
||||
|
||||
for _, msg := range bifrostReq.Input {
|
||||
if msg.Content == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get message content as string
|
||||
var contentStr string
|
||||
if msg.Content.ContentStr != nil {
|
||||
contentStr = *msg.Content.ContentStr
|
||||
} else if msg.Content.ContentBlocks != nil {
|
||||
// Concatenate text blocks only
|
||||
var textParts []string
|
||||
for _, block := range msg.Content.ContentBlocks {
|
||||
if block.Text != nil && *block.Text != "" {
|
||||
textParts = append(textParts, *block.Text)
|
||||
}
|
||||
if block.ImageURLStruct != nil && block.ImageURLStruct.URL != "" {
|
||||
imageInput = append(imageInput, block.ImageURLStruct.URL)
|
||||
}
|
||||
}
|
||||
contentStr = strings.Join(textParts, "\n")
|
||||
}
|
||||
|
||||
if contentStr == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle different roles
|
||||
switch msg.Role {
|
||||
case schemas.ChatMessageRoleSystem:
|
||||
if systemPrompt == "" {
|
||||
systemPrompt = contentStr
|
||||
} else {
|
||||
systemPrompt += "\n" + contentStr
|
||||
}
|
||||
case schemas.ChatMessageRoleUser:
|
||||
conversationParts = append(conversationParts, contentStr)
|
||||
case schemas.ChatMessageRoleAssistant:
|
||||
// For assistant messages, we can include them in the conversation context
|
||||
conversationParts = append(conversationParts, contentStr)
|
||||
}
|
||||
}
|
||||
|
||||
// Set system prompt if present and model supports it
|
||||
modelSupportsSystemPrompt := supportsSystemPrompt(bifrostReq.Model)
|
||||
|
||||
if systemPrompt != "" {
|
||||
if modelSupportsSystemPrompt {
|
||||
// Model supports system_prompt field
|
||||
input.SystemPrompt = &systemPrompt
|
||||
} else {
|
||||
// Model doesn't support system_prompt - prepend to prompt
|
||||
if len(conversationParts) > 0 {
|
||||
// Prepend system prompt to conversation
|
||||
conversationParts = append([]string{systemPrompt}, conversationParts...)
|
||||
} else {
|
||||
// No conversation parts, use system prompt as the prompt
|
||||
conversationParts = []string{systemPrompt}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build the final prompt from conversation parts
|
||||
if len(conversationParts) > 0 {
|
||||
prompt := strings.Join(conversationParts, "\n\n")
|
||||
input.Prompt = &prompt
|
||||
}
|
||||
|
||||
// Ensure we have at least some content (prompt or system prompt)
|
||||
if input.Prompt == nil && input.SystemPrompt == nil {
|
||||
return nil, fmt.Errorf("no content found in chat messages - need at least one user or system message")
|
||||
}
|
||||
|
||||
if len(imageInput) > 0 {
|
||||
input.ImageInput = imageInput
|
||||
}
|
||||
}
|
||||
|
||||
// Map parameters if present
|
||||
if bifrostReq.Params != nil {
|
||||
params := bifrostReq.Params
|
||||
|
||||
// Temperature
|
||||
if params.Temperature != nil {
|
||||
input.Temperature = params.Temperature
|
||||
}
|
||||
|
||||
// Top P
|
||||
if params.TopP != nil {
|
||||
input.TopP = params.TopP
|
||||
}
|
||||
|
||||
// Max tokens - use max_completion_tokens if available
|
||||
if params.MaxCompletionTokens != nil {
|
||||
if isGPT5Structured {
|
||||
input.MaxOutputTokens = params.MaxCompletionTokens
|
||||
} else if strings.HasPrefix(bifrostReq.Model, string(schemas.OpenAI)) {
|
||||
input.MaxCompletionTokens = params.MaxCompletionTokens
|
||||
} else {
|
||||
input.MaxTokens = params.MaxCompletionTokens
|
||||
}
|
||||
}
|
||||
|
||||
// Presence penalty
|
||||
if params.PresencePenalty != nil {
|
||||
input.PresencePenalty = params.PresencePenalty
|
||||
}
|
||||
|
||||
// Frequency penalty
|
||||
if params.FrequencyPenalty != nil {
|
||||
input.FrequencyPenalty = params.FrequencyPenalty
|
||||
}
|
||||
|
||||
// Seed
|
||||
if params.Seed != nil {
|
||||
input.Seed = params.Seed
|
||||
}
|
||||
|
||||
if params.Reasoning != nil {
|
||||
if params.Reasoning.Effort != nil {
|
||||
input.ReasoningEffort = params.Reasoning.Effort
|
||||
}
|
||||
}
|
||||
|
||||
if isGPT5Structured {
|
||||
if len(params.Tools) > 0 {
|
||||
responsesTools := []schemas.ResponsesTool{}
|
||||
for _, tool := range params.Tools {
|
||||
responsesTools = append(responsesTools, *tool.ToResponsesTool())
|
||||
}
|
||||
if len(responsesTools) > 0 {
|
||||
input.Tools = responsesTools
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if params.ExtraParams != nil {
|
||||
input.ExtraParams = params.ExtraParams
|
||||
}
|
||||
}
|
||||
|
||||
// Check if model is a version ID and set version field accordingly
|
||||
req := &ReplicatePredictionRequest{
|
||||
Input: input,
|
||||
}
|
||||
|
||||
if isVersionID(bifrostReq.Model) {
|
||||
req.Version = &bifrostReq.Model
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil {
|
||||
if webhook, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["webhook"]); ok {
|
||||
req.Webhook = webhook
|
||||
}
|
||||
if webhookEventsFilter, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["webhook_events_filter"]); ok {
|
||||
req.WebhookEventsFilter = webhookEventsFilter
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// ToBifrostChatResponse converts a Replicate prediction response to Bifrost format
|
||||
func (response *ReplicatePredictionResponse) ToBifrostChatResponse() *schemas.BifrostChatResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse timestamps
|
||||
createdAt := ParseReplicateTimestamp(response.CreatedAt)
|
||||
if createdAt == 0 {
|
||||
createdAt = time.Now().Unix()
|
||||
}
|
||||
|
||||
// Initialize Bifrost response
|
||||
bifrostResponse := &schemas.BifrostChatResponse{
|
||||
ID: response.ID,
|
||||
Model: response.Model,
|
||||
Object: "chat.completion",
|
||||
Created: int(createdAt),
|
||||
}
|
||||
|
||||
// Convert output to content
|
||||
var contentStr *string
|
||||
if response.Output != nil {
|
||||
if response.Output.OutputStr != nil {
|
||||
contentStr = response.Output.OutputStr
|
||||
} else if response.Output.OutputArray != nil {
|
||||
// Join array of strings into a single string
|
||||
joined := strings.Join(response.Output.OutputArray, "")
|
||||
contentStr = &joined
|
||||
} else if response.Output.OutputObject != nil && response.Output.OutputObject.Text != nil {
|
||||
contentStr = response.Output.OutputObject.Text
|
||||
}
|
||||
}
|
||||
|
||||
// Create message content
|
||||
messageContent := schemas.ChatMessageContent{
|
||||
ContentStr: contentStr,
|
||||
}
|
||||
|
||||
// Create the assistant message
|
||||
message := schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &messageContent,
|
||||
}
|
||||
|
||||
// Determine finish reason based on status
|
||||
var finishReason *string
|
||||
switch response.Status {
|
||||
case ReplicatePredictionStatusSucceeded:
|
||||
reason := "stop"
|
||||
finishReason = &reason
|
||||
case ReplicatePredictionStatusFailed:
|
||||
reason := "error"
|
||||
finishReason = &reason
|
||||
case ReplicatePredictionStatusCanceled:
|
||||
reason := "stop"
|
||||
finishReason = &reason
|
||||
}
|
||||
|
||||
// Create choice
|
||||
choice := schemas.BifrostResponseChoice{
|
||||
Index: 0,
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: &message,
|
||||
},
|
||||
FinishReason: finishReason,
|
||||
}
|
||||
|
||||
bifrostResponse.Choices = []schemas.BifrostResponseChoice{choice}
|
||||
|
||||
// Extract usage information from logs
|
||||
if response.Logs != nil {
|
||||
inputTokens, outputTokens, totalTokens, found := parseTokenUsageFromLogs(response.Logs, schemas.ChatCompletionRequest)
|
||||
if found {
|
||||
bifrostResponse.Usage = &schemas.BifrostLLMUsage{
|
||||
PromptTokens: inputTokens,
|
||||
CompletionTokens: outputTokens,
|
||||
TotalTokens: totalTokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
// supportsSystemPrompt checks if a model supports the system_prompt field.
|
||||
func supportsSystemPrompt(model string) bool {
|
||||
// Normalize model name to lowercase for comparison
|
||||
modelLower := strings.ToLower(model)
|
||||
|
||||
// Extract model identifier (handle both "owner/name" and "owner/name:version" formats)
|
||||
modelIdentifier := modelLower
|
||||
if idx := strings.Index(modelLower, ":"); idx != -1 {
|
||||
modelIdentifier = modelLower[:idx]
|
||||
}
|
||||
|
||||
// All deepseek models don't support system prompt
|
||||
if strings.HasPrefix(modelIdentifier, "deepseek-ai/deepseek") {
|
||||
return false
|
||||
}
|
||||
|
||||
isUnsupported := slices.Contains(unsupportedSystemPromptModels, modelIdentifier)
|
||||
return !isUnsupported
|
||||
}
|
||||
29
core/providers/replicate/errors.go
Normal file
29
core/providers/replicate/errors.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package replicate
|
||||
|
||||
import (
|
||||
"github.com/bytedance/sonic"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// parseReplicateError parses Replicate API error response
|
||||
func parseReplicateError(body []byte, statusCode int) *schemas.BifrostError {
|
||||
var replicateErr ReplicateError
|
||||
if err := sonic.Unmarshal(body, &replicateErr); err == nil && replicateErr.Detail != "" {
|
||||
return &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
StatusCode: &statusCode,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: replicateErr.Detail,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to generic error
|
||||
return &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
StatusCode: &statusCode,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: string(body),
|
||||
},
|
||||
}
|
||||
}
|
||||
89
core/providers/replicate/files.go
Normal file
89
core/providers/replicate/files.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package replicate
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// Replicate File API Converters
|
||||
|
||||
// ToBifrostFileStatus converts Replicate file status to Bifrost file status.
|
||||
// Replicate doesn't explicitly provide status, so we infer from the response.
|
||||
func ToBifrostFileStatus(fileResp *ReplicateFileResponse) schemas.FileStatus {
|
||||
// If file has all required fields and is accessible, it's processed
|
||||
if fileResp.ID != "" && fileResp.Size > 0 {
|
||||
return schemas.FileStatusProcessed
|
||||
}
|
||||
return schemas.FileStatusUploaded
|
||||
}
|
||||
|
||||
// ToBifrostFileUploadResponse converts Replicate file response to Bifrost file upload response.
|
||||
func (r *ReplicateFileResponse) ToBifrostFileUploadResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse {
|
||||
resp := &schemas.BifrostFileUploadResponse{
|
||||
ID: r.ID,
|
||||
Object: "file",
|
||||
Bytes: r.Size,
|
||||
CreatedAt: ParseReplicateTimestamp(r.CreatedAt),
|
||||
Filename: r.Name,
|
||||
Purpose: schemas.FilePurposeBatch, // Replicate uses files primarily for batch/general purposes
|
||||
Status: ToBifrostFileStatus(r),
|
||||
StorageBackend: schemas.FileStorageAPI,
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Latency: latency.Milliseconds(),
|
||||
},
|
||||
}
|
||||
|
||||
// Add ExpiresAt if present
|
||||
if r.ExpiresAt != "" {
|
||||
expiresAt := ParseReplicateTimestamp(r.ExpiresAt)
|
||||
if expiresAt > 0 {
|
||||
resp.ExpiresAt = &expiresAt
|
||||
}
|
||||
}
|
||||
|
||||
if sendBackRawRequest {
|
||||
resp.ExtraFields.RawRequest = rawRequest
|
||||
}
|
||||
|
||||
if sendBackRawResponse {
|
||||
resp.ExtraFields.RawResponse = rawResponse
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// ToBifrostFileRetrieveResponse converts Replicate file response to Bifrost file retrieve response.
|
||||
func (r *ReplicateFileResponse) ToBifrostFileRetrieveResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileRetrieveResponse {
|
||||
resp := &schemas.BifrostFileRetrieveResponse{
|
||||
ID: r.ID,
|
||||
Object: "file",
|
||||
Bytes: r.Size,
|
||||
CreatedAt: ParseReplicateTimestamp(r.CreatedAt),
|
||||
Filename: r.Name,
|
||||
Purpose: schemas.FilePurposeBatch,
|
||||
Status: ToBifrostFileStatus(r),
|
||||
StorageBackend: schemas.FileStorageAPI,
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Latency: latency.Milliseconds(),
|
||||
},
|
||||
}
|
||||
|
||||
// Add ExpiresAt if present
|
||||
if r.ExpiresAt != "" {
|
||||
expiresAt := ParseReplicateTimestamp(r.ExpiresAt)
|
||||
if expiresAt > 0 {
|
||||
resp.ExpiresAt = &expiresAt
|
||||
}
|
||||
}
|
||||
|
||||
if sendBackRawRequest {
|
||||
resp.ExtraFields.RawRequest = rawRequest
|
||||
}
|
||||
|
||||
if sendBackRawResponse {
|
||||
resp.ExtraFields.RawResponse = rawResponse
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
292
core/providers/replicate/images.go
Normal file
292
core/providers/replicate/images.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package replicate
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// modelInputImageFieldMap maps lowercase "owner/name" model identifiers to the
// name of the input field that carries the source image for that model. Models
// that are not listed fall back to the default chosen by getInputImageFieldName.
var modelInputImageFieldMap = map[string]string{
	// Models that take the image via "image_prompt".
	"black-forest-labs/flux-pro":                     "image_prompt",
	"black-forest-labs/flux-1.1-pro":                 "image_prompt",
	"black-forest-labs/flux-1.1-pro-ultra":           "image_prompt",
	"black-forest-labs/flux-1.1-pro-ultra-finetuned": "image_prompt",

	// Kontext variants take the image via "input_image".
	"black-forest-labs/flux-kontext-dev": "input_image",
	"black-forest-labs/flux-kontext-pro": "input_image",
	"black-forest-labs/flux-kontext-max": "input_image",

	// Models that take the image via "image".
	"black-forest-labs/flux-dev":      "image",
	"black-forest-labs/flux-dev-lora": "image",
	"black-forest-labs/flux-krea-dev": "image",
	"black-forest-labs/flux-fill-pro": "image",
}
|
||||
|
||||
// ToReplicateImageGenerationInput converts a Bifrost image generation request to Replicate prediction input
|
||||
func ToReplicateImageGenerationInput(bifrostReq *schemas.BifrostImageGenerationRequest) *ReplicatePredictionRequest {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
input := &ReplicatePredictionRequestInput{
|
||||
Prompt: &bifrostReq.Input.Prompt,
|
||||
}
|
||||
|
||||
// Map parameters if available
|
||||
if bifrostReq.Params != nil {
|
||||
params := bifrostReq.Params
|
||||
|
||||
if bifrostReq.Params.N != nil {
|
||||
input.NumberOfImages = bifrostReq.Params.N
|
||||
}
|
||||
|
||||
if params.AspectRatio != nil {
|
||||
input.AspectRatio = params.AspectRatio
|
||||
}
|
||||
|
||||
if params.Size != nil {
|
||||
aspectRatio, imageSize := providerUtils.ConvertSizeToAspectRatioAndResolution(*params.Size)
|
||||
_, hasExplicitResolution := params.ExtraParams["resolution"]
|
||||
if params.AspectRatio == nil && aspectRatio != "" {
|
||||
input.AspectRatio = &aspectRatio
|
||||
}
|
||||
if imageSize != "" && !hasExplicitResolution {
|
||||
input.Resolution = &imageSize
|
||||
}
|
||||
}
|
||||
|
||||
// Map OutputFormat
|
||||
if params.OutputFormat != nil {
|
||||
input.OutputFormat = params.OutputFormat
|
||||
}
|
||||
|
||||
if params.Quality != nil {
|
||||
input.Quality = params.Quality
|
||||
}
|
||||
|
||||
if params.Background != nil {
|
||||
input.Background = params.Background
|
||||
}
|
||||
|
||||
// Map Seed
|
||||
if params.Seed != nil {
|
||||
input.Seed = params.Seed
|
||||
}
|
||||
|
||||
// Map NegativePrompt
|
||||
if params.NegativePrompt != nil {
|
||||
input.NegativePrompt = params.NegativePrompt
|
||||
}
|
||||
|
||||
// Map NumInferenceSteps
|
||||
if params.NumInferenceSteps != nil {
|
||||
input.NumInferenceStep = params.NumInferenceSteps
|
||||
}
|
||||
|
||||
if params.ExtraParams != nil {
|
||||
input.ExtraParams = params.ExtraParams
|
||||
}
|
||||
}
|
||||
|
||||
request := &ReplicatePredictionRequest{
|
||||
Input: input,
|
||||
}
|
||||
|
||||
// Check if model is a version ID and set version field accordingly
|
||||
if isVersionID(bifrostReq.Model) {
|
||||
request.Version = &bifrostReq.Model
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil {
|
||||
if webhook, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["webhook"]); ok {
|
||||
request.Webhook = webhook
|
||||
}
|
||||
if webhookEventsFilter, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["webhook_events_filter"]); ok {
|
||||
request.WebhookEventsFilter = webhookEventsFilter
|
||||
}
|
||||
}
|
||||
|
||||
return request
|
||||
}
|
||||
|
||||
// ToBifrostImageGenerationResponse converts a Replicate prediction response to Bifrost format
|
||||
func ToBifrostImageGenerationResponse(
|
||||
prediction *ReplicatePredictionResponse,
|
||||
) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
if prediction == nil {
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "prediction response is nil",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
response := &schemas.BifrostImageGenerationResponse{
|
||||
ID: prediction.ID,
|
||||
Created: ParseReplicateTimestamp(prediction.CreatedAt),
|
||||
Model: prediction.Model,
|
||||
Data: []schemas.ImageData{},
|
||||
}
|
||||
|
||||
// Convert output to ImageData
|
||||
// Replicate output can be either a string (single URL) or array of strings
|
||||
if prediction.Output != nil {
|
||||
if prediction.Output.OutputStr != nil && *prediction.Output.OutputStr != "" {
|
||||
response.Data = append(response.Data, schemas.ImageData{
|
||||
URL: *prediction.Output.OutputStr,
|
||||
Index: 0,
|
||||
})
|
||||
} else if len(prediction.Output.OutputArray) > 0 {
|
||||
for i, url := range prediction.Output.OutputArray {
|
||||
response.Data = append(response.Data, schemas.ImageData{
|
||||
URL: url,
|
||||
Index: i,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract usage information from logs
|
||||
if prediction.Logs != nil {
|
||||
inputTokens, outputTokens, totalTokens, found := parseTokenUsageFromLogs(prediction.Logs, schemas.ImageGenerationRequest)
|
||||
if found {
|
||||
response.Usage = &schemas.ImageUsage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
TotalTokens: totalTokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// getInputImageFieldName returns the appropriate input image field name based on the model.
|
||||
// Uses O(1) map lookup for high RPS performance.
|
||||
func getInputImageFieldName(model string) string {
|
||||
// Normalize model name to lowercase for comparison
|
||||
modelLower := strings.ToLower(model)
|
||||
|
||||
// Extract model identifier (handle both "owner/name" and "owner/name:version" formats)
|
||||
modelIdentifier := modelLower
|
||||
if before, _, ok := strings.Cut(modelLower, ":"); ok {
|
||||
modelIdentifier = before
|
||||
}
|
||||
|
||||
if fieldName, exists := modelInputImageFieldMap[modelIdentifier]; exists {
|
||||
return fieldName
|
||||
}
|
||||
|
||||
// Default to input_images for all other models
|
||||
return "input_images"
|
||||
}
|
||||
|
||||
// ToReplicateImageEditInput converts a Bifrost image edit request to Replicate prediction input
|
||||
func ToReplicateImageEditInput(bifrostReq *schemas.BifrostImageEditRequest) *ReplicatePredictionRequest {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
input := &ReplicatePredictionRequestInput{
|
||||
Prompt: &bifrostReq.Input.Prompt,
|
||||
}
|
||||
|
||||
// Map image URLs - Replicate requires image URLs, not file bytes
|
||||
if len(bifrostReq.Input.Images) > 0 {
|
||||
images := make([]string, 0, len(bifrostReq.Input.Images))
|
||||
for _, img := range bifrostReq.Input.Images {
|
||||
if len(img.Image) > 0 {
|
||||
images = append(images, providerUtils.FileBytesToBase64DataURL(img.Image))
|
||||
}
|
||||
}
|
||||
|
||||
if len(images) > 0 {
|
||||
// Determine the appropriate field based on model
|
||||
fieldName := getInputImageFieldName(bifrostReq.Model)
|
||||
|
||||
switch fieldName {
|
||||
case "image_prompt":
|
||||
// For flux-1.1-pro variants: use first image as image_prompt
|
||||
input.ImagePrompt = &images[0]
|
||||
|
||||
case "input_image":
|
||||
// For flux-kontext variants: use first image as input_image
|
||||
input.InputImage = &images[0]
|
||||
|
||||
case "image":
|
||||
// For flux-dev variants: use first image as image field
|
||||
input.Image = &images[0]
|
||||
|
||||
case "input_images":
|
||||
// For all other models: use input_images array (preserves multi-image support)
|
||||
input.InputImages = images
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Map parameters if available
|
||||
if bifrostReq.Params != nil {
|
||||
params := bifrostReq.Params
|
||||
|
||||
if params.N != nil {
|
||||
input.NumberOfImages = params.N
|
||||
}
|
||||
|
||||
if params.Size != nil {
|
||||
aspectRatio, imageSize := providerUtils.ConvertSizeToAspectRatioAndResolution(*params.Size)
|
||||
_, hasExplicitAspectRatio := params.ExtraParams["aspect_ratio"]
|
||||
_, hasExplicitResolution := params.ExtraParams["resolution"]
|
||||
if aspectRatio != "" && !hasExplicitAspectRatio {
|
||||
input.AspectRatio = &aspectRatio
|
||||
}
|
||||
if imageSize != "" && !hasExplicitResolution {
|
||||
input.Resolution = &imageSize
|
||||
}
|
||||
}
|
||||
|
||||
if params.OutputFormat != nil {
|
||||
input.OutputFormat = params.OutputFormat
|
||||
}
|
||||
|
||||
if params.Quality != nil {
|
||||
input.Quality = params.Quality
|
||||
}
|
||||
|
||||
if params.Background != nil {
|
||||
input.Background = params.Background
|
||||
}
|
||||
|
||||
if params.Seed != nil {
|
||||
input.Seed = params.Seed
|
||||
}
|
||||
|
||||
if params.NegativePrompt != nil {
|
||||
input.NegativePrompt = params.NegativePrompt
|
||||
}
|
||||
|
||||
if params.NumInferenceSteps != nil {
|
||||
input.NumInferenceStep = params.NumInferenceSteps
|
||||
}
|
||||
|
||||
if params.ExtraParams != nil {
|
||||
input.ExtraParams = params.ExtraParams
|
||||
}
|
||||
}
|
||||
|
||||
request := &ReplicatePredictionRequest{
|
||||
Input: input,
|
||||
}
|
||||
|
||||
// Check if model is a version ID and set version field accordingly
|
||||
if isVersionID(bifrostReq.Model) {
|
||||
request.Version = &bifrostReq.Model
|
||||
}
|
||||
|
||||
return request
|
||||
}
|
||||
75
core/providers/replicate/models.go
Normal file
75
core/providers/replicate/models.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package replicate
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBifrostListModelsResponse converts Replicate deployments to a Bifrost list models response.
|
||||
// Replicate model IDs are composite: "{owner}/{name}" (e.g. "stability-ai/stable-diffusion").
|
||||
func ToBifrostListModelsResponse(
|
||||
deploymentsResponse *ReplicateDeploymentListResponse,
|
||||
providerKey schemas.ModelProvider,
|
||||
allowedModels schemas.WhiteList,
|
||||
blacklistedModels schemas.BlackList,
|
||||
aliases map[string]string,
|
||||
unfiltered bool,
|
||||
) *schemas.BifrostListModelsResponse {
|
||||
bifrostResponse := &schemas.BifrostListModelsResponse{
|
||||
Data: make([]schemas.Model, 0),
|
||||
}
|
||||
|
||||
pipeline := &providerUtils.ListModelsPipeline{
|
||||
AllowedModels: allowedModels,
|
||||
BlacklistedModels: blacklistedModels,
|
||||
Aliases: aliases,
|
||||
Unfiltered: unfiltered,
|
||||
ProviderKey: providerKey,
|
||||
MatchFns: providerUtils.DefaultMatchFns(),
|
||||
}
|
||||
if pipeline.ShouldEarlyExit() {
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
included := make(map[string]bool)
|
||||
|
||||
if deploymentsResponse != nil {
|
||||
for _, deployment := range deploymentsResponse.Results {
|
||||
// Replicate model IDs are composite owner/name
|
||||
deploymentID := deployment.Owner + "/" + deployment.Name
|
||||
|
||||
var created *int64
|
||||
if deployment.CurrentRelease != nil && deployment.CurrentRelease.CreatedAt != "" {
|
||||
createdTimestamp := ParseReplicateTimestamp(deployment.CurrentRelease.CreatedAt)
|
||||
if createdTimestamp > 0 {
|
||||
created = schemas.Ptr(createdTimestamp)
|
||||
}
|
||||
}
|
||||
|
||||
for _, result := range pipeline.FilterModel(deploymentID) {
|
||||
bifrostModel := schemas.Model{
|
||||
ID: string(providerKey) + "/" + result.ResolvedID,
|
||||
Name: schemas.Ptr(deployment.Name),
|
||||
OwnedBy: schemas.Ptr(deployment.Owner),
|
||||
Created: created,
|
||||
}
|
||||
if result.AliasValue != "" {
|
||||
bifrostModel.Alias = schemas.Ptr(result.AliasValue)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, bifrostModel)
|
||||
included[strings.ToLower(result.ResolvedID)] = true
|
||||
}
|
||||
}
|
||||
|
||||
if deploymentsResponse.Next != nil {
|
||||
bifrostResponse.NextPageToken = *deploymentsResponse.Next
|
||||
}
|
||||
}
|
||||
|
||||
bifrostResponse.Data = append(bifrostResponse.Data,
|
||||
pipeline.BackfillModels(included)...)
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
3292
core/providers/replicate/replicate.go
Normal file
3292
core/providers/replicate/replicate.go
Normal file
File diff suppressed because it is too large
Load Diff
1440
core/providers/replicate/replicate_test.go
Normal file
1440
core/providers/replicate/replicate_test.go
Normal file
File diff suppressed because it is too large
Load Diff
300
core/providers/replicate/responses.go
Normal file
300
core/providers/replicate/responses.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package replicate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func ToReplicateResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (*ReplicatePredictionRequest, error) {
|
||||
if bifrostReq == nil {
|
||||
return nil, fmt.Errorf("bifrost request is nil")
|
||||
}
|
||||
|
||||
input := &ReplicatePredictionRequestInput{}
|
||||
|
||||
if strings.HasPrefix(bifrostReq.Model, "openai/") && strings.Contains(bifrostReq.Model, "gpt-5-structured") {
|
||||
// handle responses style request
|
||||
if len(bifrostReq.Input) > 0 {
|
||||
input.InputItemList = bifrostReq.Input
|
||||
}
|
||||
if bifrostReq.Params != nil {
|
||||
if bifrostReq.Params.Instructions != nil {
|
||||
input.Instructions = bifrostReq.Params.Instructions
|
||||
}
|
||||
if bifrostReq.Params.Tools != nil {
|
||||
input.Tools = bifrostReq.Params.Tools
|
||||
}
|
||||
if bifrostReq.Params.MaxOutputTokens != nil {
|
||||
input.MaxOutputTokens = bifrostReq.Params.MaxOutputTokens
|
||||
}
|
||||
if bifrostReq.Params.Text != nil {
|
||||
input.JsonSchema = bifrostReq.Params.Text
|
||||
}
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
input.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// handle chat style request (same logic as chat converter)
|
||||
if len(bifrostReq.Input) > 0 {
|
||||
// if model is from openai family, use messages
|
||||
if strings.HasPrefix(bifrostReq.Model, string(schemas.OpenAI)) {
|
||||
input.Messages = schemas.ToChatMessages(bifrostReq.Input)
|
||||
} else {
|
||||
// convert input to prompt and system prompt
|
||||
var systemPrompt string
|
||||
var conversationParts []string
|
||||
var imageInput []string
|
||||
|
||||
for _, msg := range bifrostReq.Input {
|
||||
if msg.Content == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get message content as string
|
||||
var contentStr string
|
||||
if msg.Content.ContentStr != nil {
|
||||
contentStr = *msg.Content.ContentStr
|
||||
} else if msg.Content.ContentBlocks != nil {
|
||||
// Concatenate text blocks only
|
||||
var textParts []string
|
||||
for _, block := range msg.Content.ContentBlocks {
|
||||
if block.Text != nil && *block.Text != "" {
|
||||
textParts = append(textParts, *block.Text)
|
||||
}
|
||||
if block.ResponsesInputMessageContentBlockImage != nil && block.ResponsesInputMessageContentBlockImage.ImageURL != nil && *block.ResponsesInputMessageContentBlockImage.ImageURL != "" {
|
||||
imageInput = append(imageInput, *block.ResponsesInputMessageContentBlockImage.ImageURL)
|
||||
}
|
||||
}
|
||||
contentStr = strings.Join(textParts, "\n")
|
||||
}
|
||||
|
||||
if contentStr == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle different roles
|
||||
if msg.Role != nil {
|
||||
switch *msg.Role {
|
||||
case schemas.ResponsesInputMessageRoleSystem:
|
||||
if systemPrompt == "" {
|
||||
systemPrompt = contentStr
|
||||
} else {
|
||||
systemPrompt += "\n" + contentStr
|
||||
}
|
||||
case schemas.ResponsesInputMessageRoleUser:
|
||||
conversationParts = append(conversationParts, contentStr)
|
||||
case schemas.ResponsesInputMessageRoleAssistant:
|
||||
// For assistant messages, we can include them in the conversation context
|
||||
conversationParts = append(conversationParts, contentStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set system prompt if present and model supports it
|
||||
modelSupportsSystemPrompt := supportsSystemPrompt(bifrostReq.Model)
|
||||
|
||||
if systemPrompt != "" {
|
||||
if modelSupportsSystemPrompt {
|
||||
// Model supports system_prompt field
|
||||
input.SystemPrompt = &systemPrompt
|
||||
} else {
|
||||
// Model doesn't support system_prompt - prepend to prompt
|
||||
if len(conversationParts) > 0 {
|
||||
// Prepend system prompt to conversation
|
||||
conversationParts = append([]string{systemPrompt}, conversationParts...)
|
||||
} else {
|
||||
// No conversation parts, use system prompt as the prompt
|
||||
conversationParts = []string{systemPrompt}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build the final prompt from conversation parts
|
||||
if len(conversationParts) > 0 {
|
||||
prompt := strings.Join(conversationParts, "\n\n")
|
||||
input.Prompt = &prompt
|
||||
}
|
||||
|
||||
if len(imageInput) > 0 {
|
||||
input.ImageInput = imageInput
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Map parameters if present
|
||||
if bifrostReq.Params != nil {
|
||||
params := bifrostReq.Params
|
||||
|
||||
// Temperature
|
||||
if params.Temperature != nil {
|
||||
input.Temperature = params.Temperature
|
||||
}
|
||||
|
||||
// Top P
|
||||
if params.TopP != nil {
|
||||
input.TopP = params.TopP
|
||||
}
|
||||
|
||||
// Max tokens - use max_completion_tokens if available
|
||||
if params.MaxOutputTokens != nil {
|
||||
if strings.HasPrefix(bifrostReq.Model, string(schemas.OpenAI)) {
|
||||
input.MaxCompletionTokens = params.MaxOutputTokens
|
||||
} else {
|
||||
input.MaxTokens = params.MaxOutputTokens
|
||||
}
|
||||
}
|
||||
|
||||
// Reasoning effort
|
||||
if params.Reasoning != nil {
|
||||
if params.Reasoning.Effort != nil {
|
||||
input.ReasoningEffort = params.Reasoning.Effort
|
||||
}
|
||||
}
|
||||
|
||||
if params.Instructions != nil && *params.Instructions != "" {
|
||||
if supportsSystemPrompt(bifrostReq.Model) {
|
||||
if input.SystemPrompt == nil {
|
||||
input.SystemPrompt = params.Instructions
|
||||
}
|
||||
} else {
|
||||
if input.Prompt != nil && *input.Prompt != "" {
|
||||
prefixed := *params.Instructions + "\n\n" + *input.Prompt
|
||||
input.Prompt = schemas.Ptr(prefixed)
|
||||
} else if input.Prompt == nil {
|
||||
input.Prompt = params.Instructions
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if params.ExtraParams != nil {
|
||||
input.ExtraParams = params.ExtraParams
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if model is a version ID and set version field accordingly
|
||||
req := &ReplicatePredictionRequest{
|
||||
Input: input,
|
||||
}
|
||||
|
||||
if isVersionID(bifrostReq.Model) {
|
||||
req.Version = &bifrostReq.Model
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil {
|
||||
if webhook, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["webhook"]); ok {
|
||||
req.Webhook = webhook
|
||||
}
|
||||
if webhookEventsFilter, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["webhook_events_filter"]); ok {
|
||||
req.WebhookEventsFilter = webhookEventsFilter
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (response *ReplicatePredictionResponse) ToBifrostResponsesResponse() *schemas.BifrostResponsesResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse timestamps
|
||||
createdAt := ParseReplicateTimestamp(response.CreatedAt)
|
||||
if createdAt == 0 {
|
||||
createdAt = time.Now().Unix()
|
||||
}
|
||||
|
||||
var completedAt *int
|
||||
if response.CompletedAt != nil {
|
||||
completed := int(ParseReplicateTimestamp(*response.CompletedAt))
|
||||
if completed > 0 {
|
||||
completedAt = &completed
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize Bifrost response
|
||||
bifrostResponse := &schemas.BifrostResponsesResponse{
|
||||
ID: schemas.Ptr(response.ID),
|
||||
Model: response.Model,
|
||||
CreatedAt: int(createdAt),
|
||||
CompletedAt: completedAt,
|
||||
}
|
||||
|
||||
// Convert output to ResponsesMessage
|
||||
var outputMessages []schemas.ResponsesMessage
|
||||
if response.Output != nil {
|
||||
var contentStr *string
|
||||
|
||||
// Handle different output types
|
||||
if response.Output.OutputStr != nil {
|
||||
contentStr = response.Output.OutputStr
|
||||
} else if response.Output.OutputArray != nil {
|
||||
// Join array of strings into a single string
|
||||
joined := strings.Join(response.Output.OutputArray, "")
|
||||
contentStr = &joined
|
||||
} else if response.Output.OutputObject != nil && response.Output.OutputObject.Text != nil {
|
||||
// Use text field from OutputObject
|
||||
contentStr = response.Output.OutputObject.Text
|
||||
}
|
||||
|
||||
if contentStr != nil && *contentStr != "" {
|
||||
messageType := schemas.ResponsesMessageTypeMessage
|
||||
role := schemas.ResponsesInputMessageRoleAssistant
|
||||
|
||||
outputMsg := schemas.ResponsesMessage{
|
||||
Type: &messageType,
|
||||
Role: &role,
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentStr: contentStr,
|
||||
},
|
||||
}
|
||||
outputMessages = append(outputMessages, outputMsg)
|
||||
}
|
||||
}
|
||||
|
||||
bifrostResponse.Output = outputMessages
|
||||
|
||||
// Set status based on prediction status
|
||||
var status string
|
||||
switch response.Status {
|
||||
case ReplicatePredictionStatusSucceeded:
|
||||
status = "completed"
|
||||
case ReplicatePredictionStatusFailed:
|
||||
status = "failed"
|
||||
case ReplicatePredictionStatusCanceled:
|
||||
status = "cancelled"
|
||||
case ReplicatePredictionStatusProcessing:
|
||||
status = "in_progress"
|
||||
case ReplicatePredictionStatusStarting:
|
||||
status = "queued"
|
||||
default:
|
||||
status = string(response.Status)
|
||||
}
|
||||
bifrostResponse.Status = &status
|
||||
|
||||
// Set error if present
|
||||
if response.Error != nil && *response.Error != "" {
|
||||
bifrostResponse.Error = &schemas.ResponsesResponseError{
|
||||
Code: "provider_error",
|
||||
Message: *response.Error,
|
||||
}
|
||||
}
|
||||
|
||||
// Extract usage information from logs
|
||||
if response.Logs != nil {
|
||||
inputTokens, outputTokens, totalTokens, found := parseTokenUsageFromLogs(response.Logs, schemas.ResponsesRequest)
|
||||
if found {
|
||||
bifrostResponse.Usage = &schemas.ResponsesResponseUsage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
TotalTokens: totalTokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
148
core/providers/replicate/text.go
Normal file
148
core/providers/replicate/text.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package replicate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func ToReplicateTextRequest(bifrostReq *schemas.BifrostTextCompletionRequest) (*ReplicatePredictionRequest, error) {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil {
|
||||
return nil, fmt.Errorf("bifrost request is nil or prompt is nil")
|
||||
}
|
||||
|
||||
input := &ReplicatePredictionRequestInput{}
|
||||
if bifrostReq.Input.PromptStr != nil {
|
||||
input.Prompt = bifrostReq.Input.PromptStr
|
||||
} else if len(bifrostReq.Input.PromptArray) > 0 {
|
||||
prompt := strings.Join(bifrostReq.Input.PromptArray, "\n")
|
||||
input.Prompt = &prompt
|
||||
}
|
||||
|
||||
// Map parameters if present
|
||||
if bifrostReq.Params != nil {
|
||||
params := bifrostReq.Params
|
||||
|
||||
// Temperature
|
||||
if params.Temperature != nil {
|
||||
input.Temperature = params.Temperature
|
||||
}
|
||||
|
||||
// Top P
|
||||
if params.TopP != nil {
|
||||
input.TopP = params.TopP
|
||||
}
|
||||
|
||||
// Max tokens
|
||||
if params.MaxTokens != nil {
|
||||
input.MaxTokens = params.MaxTokens
|
||||
}
|
||||
|
||||
// Presence penalty
|
||||
if params.PresencePenalty != nil {
|
||||
input.PresencePenalty = params.PresencePenalty
|
||||
}
|
||||
|
||||
// Frequency penalty
|
||||
if params.FrequencyPenalty != nil {
|
||||
input.FrequencyPenalty = params.FrequencyPenalty
|
||||
}
|
||||
|
||||
// Top K (from ExtraParams)
|
||||
if topK, ok := schemas.SafeExtractIntPointer(params.ExtraParams["top_k"]); ok {
|
||||
input.TopK = topK
|
||||
}
|
||||
|
||||
// Seed
|
||||
if params.Seed != nil {
|
||||
input.Seed = params.Seed
|
||||
}
|
||||
|
||||
if params.ExtraParams != nil {
|
||||
input.ExtraParams = params.ExtraParams
|
||||
}
|
||||
}
|
||||
|
||||
// Check if model is a version ID and set version field accordingly
|
||||
req := &ReplicatePredictionRequest{
|
||||
Input: input,
|
||||
}
|
||||
|
||||
if isVersionID(bifrostReq.Model) {
|
||||
req.Version = &bifrostReq.Model
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil {
|
||||
if webhook, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["webhook"]); ok {
|
||||
req.Webhook = webhook
|
||||
}
|
||||
if webhookEventsFilter, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["webhook_events_filter"]); ok {
|
||||
req.WebhookEventsFilter = webhookEventsFilter
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// ToBifrostTextCompletionResponse converts a Replicate prediction response to Bifrost format
|
||||
func (response *ReplicatePredictionResponse) ToBifrostTextCompletionResponse() *schemas.BifrostTextCompletionResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize Bifrost response
|
||||
bifrostResponse := &schemas.BifrostTextCompletionResponse{
|
||||
ID: response.ID,
|
||||
Model: response.Model,
|
||||
Object: "text_completion",
|
||||
}
|
||||
|
||||
// Convert output to text
|
||||
var textOutput *string
|
||||
if response.Output != nil {
|
||||
if response.Output.OutputStr != nil {
|
||||
textOutput = response.Output.OutputStr
|
||||
} else if response.Output.OutputArray != nil {
|
||||
// Join array of strings into a single string
|
||||
joined := strings.Join(response.Output.OutputArray, "")
|
||||
textOutput = &joined
|
||||
}
|
||||
}
|
||||
|
||||
// Determine finish reason based on status
|
||||
var finishReason *string
|
||||
switch response.Status {
|
||||
case ReplicatePredictionStatusSucceeded:
|
||||
finishReason = schemas.Ptr("stop")
|
||||
case ReplicatePredictionStatusFailed:
|
||||
finishReason = schemas.Ptr("error")
|
||||
case ReplicatePredictionStatusCanceled:
|
||||
finishReason = schemas.Ptr("stop")
|
||||
}
|
||||
|
||||
// Create choice with text completion response choice
|
||||
choice := schemas.BifrostResponseChoice{
|
||||
Index: 0,
|
||||
TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
|
||||
Text: textOutput,
|
||||
},
|
||||
FinishReason: finishReason,
|
||||
}
|
||||
|
||||
bifrostResponse.Choices = []schemas.BifrostResponseChoice{choice}
|
||||
|
||||
// Extract usage information from logs
|
||||
if response.Logs != nil {
|
||||
inputTokens, outputTokens, totalTokens, found := parseTokenUsageFromLogs(response.Logs, schemas.TextCompletionRequest)
|
||||
if found {
|
||||
bifrostResponse.Usage = &schemas.BifrostLLMUsage{
|
||||
PromptTokens: inputTokens,
|
||||
CompletionTokens: outputTokens,
|
||||
TotalTokens: totalTokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
507
core/providers/replicate/types.go
Normal file
507
core/providers/replicate/types.go
Normal file
@@ -0,0 +1,507 @@
|
||||
package replicate
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ==================== REQUEST TYPES ====================
|
||||
|
||||
// ReplicatePredictionRequest represents a request to create a prediction
|
||||
type ReplicatePredictionRequest struct {
|
||||
Version *string `json:"version,omitempty"` // Required: Model version ID
|
||||
Input *ReplicatePredictionRequestInput `json:"input"` // Required: Input parameters for the model
|
||||
Stream *bool `json:"stream,omitempty"` // Enable streaming output
|
||||
Webhook *string `json:"webhook,omitempty"` // Webhook URL for notifications
|
||||
WebhookEventsFilter []string `json:"webhook_events_filter,omitempty"` // Filter webhook events: start, output, logs, completed
|
||||
OutputFileURLPrefix *string `json:"output_file_url_prefix,omitempty"` // Custom prefix for output file URLs
|
||||
PollTimeout *int `json:"poll_timeout,omitempty"` // Timeout in seconds for polling (used with Prefer: wait header)
|
||||
UseFileOutput *bool `json:"use_file_output,omitempty"` // Output files as URLs instead of data URIs
|
||||
ExtraParams map[string]interface{} `json:"-"` // Extra parameters to merge into the request
|
||||
}
|
||||
|
||||
// GetExtraParams implements the RequestBodyWithExtraParams interface
|
||||
func (req *ReplicatePredictionRequest) GetExtraParams() map[string]interface{} {
|
||||
return req.ExtraParams
|
||||
}
|
||||
|
||||
// ReplicatePredictionRequestInput represents the input parameters for a model prediction
|
||||
// This is flexible to support different model types - exact fields depend on the model
|
||||
type ReplicatePredictionRequestInput struct {
|
||||
Prompt *string `json:"prompt,omitempty"`
|
||||
Messages []schemas.ChatMessage `json:"messages,omitempty"`
|
||||
SystemPrompt *string `json:"system_prompt,omitempty"`
|
||||
Image *string `json:"image,omitempty"` // URL or data URI
|
||||
NumberOfImages *int `json:"number_of_images,omitempty"` // Number of images to generate
|
||||
Quality *string `json:"quality,omitempty"` // Quality of the image
|
||||
Background *string `json:"background,omitempty"` // Background of the image
|
||||
Seed *int `json:"seed,omitempty"` // Random seed
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"` // Reasoning effort
|
||||
NumInferenceStep *int `json:"num_inference_steps,omitempty"` // Number of inference steps
|
||||
NegativePrompt *string `json:"negative_prompt,omitempty"` // Negative prompt
|
||||
|
||||
// Responses parameters
|
||||
Instructions *string `json:"instructions,omitempty"`
|
||||
InputItemList []schemas.ResponsesMessage `json:"input_item_list,omitempty"`
|
||||
Tools []schemas.ResponsesTool `json:"tools,omitempty"`
|
||||
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
||||
JsonSchema *schemas.ResponsesTextConfig `json:"json_schema,omitempty"`
|
||||
|
||||
// Chat parameters
|
||||
Temperature *float64 `json:"temperature,omitempty"` // Temperature for sampling
|
||||
TopP *float64 `json:"top_p,omitempty"` // Top-p sampling
|
||||
TopK *int `json:"top_k,omitempty"` // Top-k sampling
|
||||
MaxTokens *int `json:"max_tokens,omitempty"` // Maximum tokens to generate
|
||||
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` // Maximum completion tokens to generate
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Presence penalty
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Frequency penalty
|
||||
|
||||
// Image generation parameters
|
||||
AspectRatio *string `json:"aspect_ratio,omitempty"`
|
||||
Resolution *string `json:"resolution,omitempty"` // Resolution tier: "1k", "2k", "4k"
|
||||
OutputFormat *string `json:"output_format,omitempty"`
|
||||
InputImages []string `json:"input_images,omitempty"` // Image input for image-to-image models
|
||||
ImagePrompt *string `json:"image_prompt,omitempty"` // Image prompt for image models (flux family)
|
||||
ImageInput []string `json:"image_input,omitempty"` // Image input for chat models (openai family)
|
||||
InputImage *string `json:"input_image,omitempty"` // Image input for image-to-image models
|
||||
|
||||
// video generation parameters
|
||||
Duration *int `json:"duration,omitempty"`
|
||||
InputReference *string `json:"input_reference,omitempty"`
|
||||
ExtraParams map[string]interface{} `json:"-"` // Additional model-specific parameters
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshalling for ReplicatePredictionRequestInput.
|
||||
// It marshals all defined fields and then flattens ExtraParams at the top level.
|
||||
func (r *ReplicatePredictionRequestInput) MarshalJSON() ([]byte, error) {
|
||||
if r == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
// Create a temporary type to avoid infinite recursion
|
||||
type Alias ReplicatePredictionRequestInput
|
||||
|
||||
// Marshal the struct normally (ExtraParams will be omitted due to json:"-" tag)
|
||||
aliasData, err := providerUtils.MarshalSorted((*Alias)(r))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If there are no ExtraParams, return the marshaled data as-is
|
||||
if len(r.ExtraParams) == 0 {
|
||||
return aliasData, nil
|
||||
}
|
||||
|
||||
// Use order-preserving merge to avoid destroying key ordering in the serialized JSON.
|
||||
return providerUtils.MergeExtraParamsIntoJSON(aliasData, r.ExtraParams)
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshalling for ReplicatePredictionRequestInput.
|
||||
// It unmarshals known fields and captures any unrecognized fields in ExtraParams.
|
||||
func (r *ReplicatePredictionRequestInput) UnmarshalJSON(data []byte) error {
|
||||
// Create a temporary type to avoid infinite recursion
|
||||
type Alias ReplicatePredictionRequestInput
|
||||
|
||||
// Unmarshal into the alias type
|
||||
aux := (*Alias)(r)
|
||||
if err := sonic.Unmarshal(data, aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Unmarshal into a map to find extra fields
|
||||
var rawMap map[string]interface{}
|
||||
if err := sonic.Unmarshal(data, &rawMap); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// List of known field names (in JSON format)
|
||||
knownFields := map[string]bool{
|
||||
"prompt": true,
|
||||
"messages": true,
|
||||
"system_prompt": true,
|
||||
"image": true,
|
||||
"number_of_images": true,
|
||||
"quality": true,
|
||||
"background": true,
|
||||
"seed": true,
|
||||
"reasoning_effort": true,
|
||||
"num_inference_steps": true,
|
||||
"negative_prompt": true,
|
||||
"instructions": true,
|
||||
"input_item_list": true,
|
||||
"tools": true,
|
||||
"max_output_tokens": true,
|
||||
"json_schema": true,
|
||||
"temperature": true,
|
||||
"top_p": true,
|
||||
"top_k": true,
|
||||
"max_tokens": true,
|
||||
"max_completion_tokens": true,
|
||||
"presence_penalty": true,
|
||||
"frequency_penalty": true,
|
||||
"aspect_ratio": true,
|
||||
"resolution": true,
|
||||
"output_format": true,
|
||||
"input_images": true,
|
||||
"image_prompt": true,
|
||||
"input_image": true,
|
||||
"image_input": true,
|
||||
"duration": true,
|
||||
"input_reference": true,
|
||||
}
|
||||
|
||||
// Collect extra fields
|
||||
r.ExtraParams = make(map[string]interface{})
|
||||
for key, value := range rawMap {
|
||||
if !knownFields[key] {
|
||||
r.ExtraParams[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
// If no extra params found, keep it as nil instead of empty map
|
||||
if len(r.ExtraParams) == 0 {
|
||||
r.ExtraParams = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReplicateModelListRequest represents a request to list/search models.
// Both fields are optional and omitted from the JSON when nil.
type ReplicateModelListRequest struct {
	Query *string `json:"query,omitempty"` // Search query
	Limit *int    `json:"limit,omitempty"` // Maximum results (1-50, default 20)
}
|
||||
|
||||
// ==================== RESPONSE TYPES ====================
|
||||
|
||||
// ReplicatePredictionStatus represents the status of a prediction.
type ReplicatePredictionStatus string

// Prediction statuses as they appear in the "status" field of a prediction.
const (
	ReplicatePredictionStatusStarting   ReplicatePredictionStatus = "starting"
	ReplicatePredictionStatusProcessing ReplicatePredictionStatus = "processing"
	ReplicatePredictionStatusSucceeded  ReplicatePredictionStatus = "succeeded"
	ReplicatePredictionStatusFailed     ReplicatePredictionStatus = "failed"
	ReplicatePredictionStatusCanceled   ReplicatePredictionStatus = "canceled"
)
|
||||
|
||||
// ReplicatePredictionResponse represents a prediction response
|
||||
type ReplicatePredictionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"` // Model identifier (owner/name or owner/name:version)
|
||||
Version string `json:"version"` // Model version ID
|
||||
Input json.RawMessage `json:"input"` // Input parameters used (json.RawMessage preserves key ordering)
|
||||
Output *ReplicateOutput `json:"output,omitempty"` // Output data (can be various types)
|
||||
Logs *string `json:"logs,omitempty"` // Execution logs
|
||||
Error *string `json:"error,omitempty"` // Error message if failed
|
||||
Status ReplicatePredictionStatus `json:"status"` // Current status
|
||||
CreatedAt string `json:"created_at"` // ISO 8601 timestamp
|
||||
StartedAt *string `json:"started_at,omitempty"` // ISO 8601 timestamp
|
||||
CompletedAt *string `json:"completed_at,omitempty"` // ISO 8601 timestamp
|
||||
URLs *ReplicatePredictionURLs `json:"urls,omitempty"` // URLs for API operations
|
||||
Metrics *ReplicateMetrics `json:"metrics,omitempty"` // Execution metrics
|
||||
DataRemoved *bool `json:"data_removed,omitempty"` // Whether data has been removed
|
||||
Source *string `json:"source,omitempty"` // Source of the prediction (web/api)
|
||||
WebhookCompleted *bool `json:"webhook_completed,omitempty"` // Whether webhook was completed
|
||||
Stream *bool `json:"stream,omitempty"` // Whether the prediction is streaming
|
||||
}
|
||||
|
||||
// ReplicateOutputText is the object-shaped form of a prediction output: a JSON
// object with optional "response_id" and "text" fields.
type ReplicateOutputText struct {
	ResponseId *string `json:"response_id,omitempty"`
	Text       *string `json:"text,omitempty"`
}
|
||||
type ReplicateOutput struct {
|
||||
OutputStr *string
|
||||
OutputArray []string
|
||||
OutputObject *ReplicateOutputText
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshalling for ReplicateOutput.
|
||||
// It marshals either OutputStr, OutputArray, or OutputObject directly without wrapping.
|
||||
func (mc ReplicateOutput) MarshalJSON() ([]byte, error) {
|
||||
// Validation: ensure only one field is set at a time
|
||||
fieldsSet := 0
|
||||
if mc.OutputStr != nil {
|
||||
fieldsSet++
|
||||
}
|
||||
if mc.OutputArray != nil {
|
||||
fieldsSet++
|
||||
}
|
||||
if mc.OutputObject != nil {
|
||||
fieldsSet++
|
||||
}
|
||||
if fieldsSet > 1 {
|
||||
return nil, fmt.Errorf("multiple output fields are set; only one should be non-nil")
|
||||
}
|
||||
|
||||
if mc.OutputStr != nil {
|
||||
return providerUtils.MarshalSorted(*mc.OutputStr)
|
||||
}
|
||||
if mc.OutputArray != nil {
|
||||
return providerUtils.MarshalSorted(mc.OutputArray)
|
||||
}
|
||||
if mc.OutputObject != nil {
|
||||
return providerUtils.MarshalSorted(mc.OutputObject)
|
||||
}
|
||||
// If all are nil, return null
|
||||
return providerUtils.MarshalSorted(nil)
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshalling for ReplicateOutput.
|
||||
// It determines whether "content" is a string, array, or object and assigns to the appropriate field.
|
||||
func (mc *ReplicateOutput) UnmarshalJSON(data []byte) error {
|
||||
mc.OutputStr = nil
|
||||
mc.OutputArray = nil
|
||||
mc.OutputObject = nil
|
||||
|
||||
if string(data) == "null" || len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// First, try to unmarshal as a direct string
|
||||
var stringContent string
|
||||
if err := sonic.Unmarshal(data, &stringContent); err == nil {
|
||||
mc.OutputStr = schemas.Ptr(stringContent)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as a direct array of strings
|
||||
var arrayContent []string
|
||||
if err := sonic.Unmarshal(data, &arrayContent); err == nil {
|
||||
mc.OutputArray = arrayContent
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as an object (ReplicateOutputText)
|
||||
var objectContent ReplicateOutputText
|
||||
if err := sonic.Unmarshal(data, &objectContent); err == nil {
|
||||
mc.OutputObject = &objectContent
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("output field is neither a string, array of strings, nor a valid object")
|
||||
}
|
||||
|
||||
// ReplicatePredictionURLs represents URLs for prediction operations.
type ReplicatePredictionURLs struct {
	Get    string  `json:"get"`              // URL to get prediction details
	Cancel string  `json:"cancel"`           // URL to cancel prediction
	Stream *string `json:"stream,omitempty"` // URL for streaming output (if applicable)
	Web    *string `json:"web,omitempty"`    // URL for web output (if applicable)
}
|
||||
|
||||
// ReplicateMetrics represents execution metrics. Every field is optional.
type ReplicateMetrics struct {
	PredictTime      *float64 `json:"predict_time,omitempty"`        // Time spent in prediction (seconds)
	TotalTime        *float64 `json:"total_time,omitempty"`          // Total time including queue (seconds)
	ImageCount       *int     `json:"image_count,omitempty"`         // Number of images generated
	TimeToFirstToken *float64 `json:"time_to_first_token,omitempty"` // Time to first token (seconds)
	TokensPerSecond  *float64 `json:"tokens_per_second,omitempty"`   // Tokens generated per second
}
|
||||
|
||||
// ReplicatePredictionListResponse represents a paginated list of predictions
|
||||
type ReplicatePredictionListResponse struct {
|
||||
Next *string `json:"next"` // URL for next page
|
||||
Previous *string `json:"previous"` // URL for previous page
|
||||
Results []ReplicatePredictionResponse `json:"results"` // List of predictions
|
||||
}
|
||||
|
||||
// ReplicateModelResponse represents a model response
|
||||
type ReplicateModelResponse struct {
|
||||
URL string `json:"url"` // Model API URL
|
||||
Owner string `json:"owner"` // Owner username or org name
|
||||
Name string `json:"name"` // Model name
|
||||
Description *string `json:"description,omitempty"` // Model description
|
||||
Visibility string `json:"visibility"` // "public" or "private"
|
||||
GithubURL *string `json:"github_url,omitempty"` // GitHub repository URL
|
||||
PaperURL *string `json:"paper_url,omitempty"` // Research paper URL
|
||||
LicenseURL *string `json:"license_url,omitempty"` // License URL
|
||||
RunCount *int `json:"run_count,omitempty"` // Number of times run
|
||||
CoverImageURL *string `json:"cover_image_url,omitempty"` // Cover image URL
|
||||
DefaultExample *json.RawMessage `json:"default_example,omitempty"` // Default example prediction (json.RawMessage preserves key ordering)
|
||||
LatestVersion *ReplicateModelVersion `json:"latest_version,omitempty"` // Latest version details
|
||||
FeaturedVersion *ReplicateModelVersion `json:"featured_version,omitempty"` // Featured version details
|
||||
}
|
||||
|
||||
// ReplicateModelVersion represents a model version.
type ReplicateModelVersion struct {
	ID            string          `json:"id"`                        // Version ID
	CreatedAt     string          `json:"created_at"`                // ISO 8601 timestamp
	CogVersion    *string         `json:"cog_version,omitempty"`     // Cog version used
	OpenAPISchema json.RawMessage `json:"openapi_schema,omitempty"`  // OpenAPI schema for the model (json.RawMessage preserves key ordering)
	DockerImageID *string         `json:"docker_image_id,omitempty"` // Docker image ID
}
|
||||
|
||||
// ReplicateModelListResponse represents a paginated list of models
|
||||
type ReplicateModelListResponse struct {
|
||||
Next *string `json:"next"` // URL for next page
|
||||
Previous *string `json:"previous"` // URL for previous page
|
||||
Results []ReplicateModelResponse `json:"results"` // List of models
|
||||
}
|
||||
|
||||
// ReplicateDeploymentOwner represents the owner of a deployment.
type ReplicateDeploymentOwner struct {
	Type      string  `json:"type"`                 // "user" or "organization"
	Username  string  `json:"username"`             // Username or organization name
	Name      *string `json:"name,omitempty"`       // Display name
	AvatarURL *string `json:"avatar_url,omitempty"` // Avatar URL
	GithubURL *string `json:"github_url,omitempty"` // GitHub URL
}
|
||||
|
||||
// ReplicateDeploymentConfiguration represents the deployment configuration.
type ReplicateDeploymentConfiguration struct {
	Hardware     string `json:"hardware"`      // Hardware type (e.g., "gpu-t4")
	MinInstances int    `json:"min_instances"` // Minimum number of instances
	MaxInstances int    `json:"max_instances"` // Maximum number of instances
}
|
||||
|
||||
// ReplicateDeploymentRelease represents a deployment release
|
||||
type ReplicateDeploymentRelease struct {
|
||||
Number int `json:"number"` // Release number
|
||||
Model string `json:"model"` // Model identifier (owner/name)
|
||||
Version string `json:"version"` // Model version ID
|
||||
CreatedAt string `json:"created_at"` // ISO 8601 timestamp
|
||||
CreatedBy *ReplicateDeploymentOwner `json:"created_by"` // User or organization that created the release
|
||||
Configuration *ReplicateDeploymentConfiguration `json:"configuration"` // Deployment configuration
|
||||
}
|
||||
|
||||
// ReplicateDeployment represents a deployment
|
||||
type ReplicateDeployment struct {
|
||||
Owner string `json:"owner"` // Owner username or org name
|
||||
Name string `json:"name"` // Deployment name
|
||||
CurrentRelease *ReplicateDeploymentRelease `json:"current_release"` // Current active release
|
||||
}
|
||||
|
||||
// ReplicateDeploymentListResponse represents a paginated list of deployments
|
||||
type ReplicateDeploymentListResponse struct {
|
||||
Next *string `json:"next"` // URL for next page
|
||||
Previous *string `json:"previous"` // URL for previous page
|
||||
Results []ReplicateDeployment `json:"results"` // List of deployments
|
||||
}
|
||||
|
||||
// ==================== ERROR TYPES ====================
|
||||
|
||||
// ReplicateError represents an error response from the Replicate API.
type ReplicateError struct {
	Detail string  `json:"detail"`          // Error message
	Status int     `json:"status"`          // HTTP status code
	Title  *string `json:"title,omitempty"` // Error title
	Type   *string `json:"type,omitempty"`  // Error type
}
|
||||
|
||||
// ==================== STREAMING TYPES ====================
|
||||
|
||||
// ReplicateStreamEvent represents a streaming event.
type ReplicateStreamEvent struct {
	Event string      `json:"event,omitempty"` // Event type (output, logs, done, error)
	Data  interface{} `json:"data,omitempty"`  // Event data
	Error *string     `json:"error,omitempty"` // Error message if event is error
}
|
||||
|
||||
// ==================== WEBHOOK TYPES ====================
|
||||
|
||||
// ReplicateWebhookPayload represents a webhook payload
|
||||
type ReplicateWebhookPayload struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Version string `json:"version"`
|
||||
Input json.RawMessage `json:"input"`
|
||||
Output interface{} `json:"output,omitempty"`
|
||||
Logs *string `json:"logs,omitempty"`
|
||||
Error *string `json:"error,omitempty"`
|
||||
Status ReplicatePredictionStatus `json:"status"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
StartedAt *string `json:"started_at,omitempty"`
|
||||
CompletedAt *string `json:"completed_at,omitempty"`
|
||||
URLs *ReplicatePredictionURLs `json:"urls,omitempty"`
|
||||
Metrics *ReplicateMetrics `json:"metrics,omitempty"`
|
||||
}
|
||||
|
||||
// ==================== SSE TYPES ====================
|
||||
|
||||
// ReplicateSSEEvent represents a Server-Sent Event from Replicate streaming API.
type ReplicateSSEEvent struct {
	Event string // Event type: "output", "done", "error"
	Data  string // Event data - can be plain text, JSON object, or data URI
	ID    string // Event ID (e.g., "1690212292:0")
}
|
||||
|
||||
// ReplicateDoneEvent represents the data payload of a "done" event.
type ReplicateDoneEvent struct {
	Reason string      `json:"reason,omitempty"` // Reason for completion: "canceled", "error", or empty for success
	Output interface{} `json:"output,omitempty"` // Output data if available (e.g., error message)
}
|
||||
|
||||
// ReplicateErrorEvent represents the data payload of an "error" event.
type ReplicateErrorEvent struct {
	Detail string `json:"detail"` // Error message
}
|
||||
|
||||
// ==================== UTILITY FUNCTIONS ====================
|
||||
|
||||
// ParseReplicateTimestamp converts a Replicate timestamp string into Unix
// seconds. The string is parsed with the time.RFC3339Nano layout.
//
// An empty or unparsable timestamp yields 0, so callers cannot distinguish
// those cases from the Unix epoch itself.
func ParseReplicateTimestamp(timestamp string) int64 {
	if timestamp != "" {
		if parsed, err := time.Parse(time.RFC3339Nano, timestamp); err == nil {
			return parsed.Unix()
		}
	}
	return 0
}
|
||||
|
||||
// ToBifrostPredictionStatus converts Replicate status to Bifrost status
|
||||
func ToBifrostPredictionStatus(status ReplicatePredictionStatus) string {
|
||||
switch status {
|
||||
case ReplicatePredictionStatusStarting:
|
||||
return "starting"
|
||||
case ReplicatePredictionStatusProcessing:
|
||||
return "processing"
|
||||
case ReplicatePredictionStatusSucceeded:
|
||||
return "succeeded"
|
||||
case ReplicatePredictionStatusFailed:
|
||||
return "failed"
|
||||
case ReplicatePredictionStatusCanceled:
|
||||
return "canceled"
|
||||
default:
|
||||
return string(status)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== FILE API TYPES ====================
|
||||
|
||||
// ReplicateFileResponse represents a file resource from Replicate API
|
||||
type ReplicateFileResponse struct {
|
||||
ID string `json:"id"` // Unique file identifier
|
||||
Checksums *ReplicateFileChecksum `json:"checksums,omitempty"` // File checksums
|
||||
ContentType string `json:"content_type"` // MIME type
|
||||
CreatedAt string `json:"created_at"` // ISO 8601 timestamp
|
||||
ExpiresAt string `json:"expires_at,omitempty"` // ISO 8601 timestamp
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"` // User-provided metadata (json.RawMessage preserves key ordering)
|
||||
Name string `json:"name,omitempty"` // File name
|
||||
Size int64 `json:"size"` // File size in bytes
|
||||
URLs *ReplicateFileURLs `json:"urls,omitempty"` // Associated URLs
|
||||
}
|
||||
|
||||
// ReplicateFileChecksum holds the checksums reported for a file.
type ReplicateFileChecksum struct {
	// SHA256 is the SHA256 checksum; omitted from JSON when empty.
	SHA256 string `json:"sha256,omitempty"`
}
|
||||
|
||||
// ReplicateFileURLs holds the URLs associated with a file.
type ReplicateFileURLs struct {
	// Get is the URL to retrieve file metadata.
	Get string `json:"get"`
}
|
||||
|
||||
// ReplicateFileListResponse represents a paginated list of files
|
||||
type ReplicateFileListResponse struct {
|
||||
Next *string `json:"next,omitempty"` // URL for next page
|
||||
Previous *string `json:"previous,omitempty"` // URL for previous page
|
||||
Results []ReplicateFileResponse `json:"results"` // List of files
|
||||
}
|
||||
280
core/providers/replicate/utils.go
Normal file
280
core/providers/replicate/utils.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package replicate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// isTerminalStatus checks if a prediction status is terminal (completed/failed/canceled)
|
||||
func isTerminalStatus(status ReplicatePredictionStatus) bool {
|
||||
return status == ReplicatePredictionStatusSucceeded ||
|
||||
status == ReplicatePredictionStatusFailed ||
|
||||
status == ReplicatePredictionStatusCanceled
|
||||
}
|
||||
|
||||
// checkForErrorStatus returns an error if the prediction failed
|
||||
func checkForErrorStatus(prediction *ReplicatePredictionResponse) *schemas.BifrostError {
|
||||
if prediction.Status == ReplicatePredictionStatusFailed {
|
||||
errorMsg := "prediction failed"
|
||||
if prediction.Error != nil && *prediction.Error != "" {
|
||||
errorMsg = *prediction.Error
|
||||
}
|
||||
return providerUtils.NewBifrostOperationError(
|
||||
"prediction failed",
|
||||
fmt.Errorf("%s", errorMsg))
|
||||
}
|
||||
|
||||
if prediction.Status == ReplicatePredictionStatusCanceled {
|
||||
return providerUtils.NewBifrostOperationError(
|
||||
"prediction was canceled",
|
||||
fmt.Errorf("prediction was canceled"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parsePreferHeader reports whether the extra headers request sync mode,
// i.e. whether they contain a Prefer header whose value starts with "wait".
// Examples of matching values: "wait", "wait=30", "wait=60".
//
// The header name is compared case-insensitively because HTTP field names
// are case-insensitive: a "prefer" key requests sync mode just like "Prefer".
// A nil or empty map yields false.
func parsePreferHeader(extraHeaders map[string]string) bool {
	for name, value := range extraHeaders {
		if strings.EqualFold(name, "Prefer") && strings.HasPrefix(value, "wait") {
			return true
		}
	}
	return false
}
|
||||
|
||||
// stripPreferHeader returns the extra headers without any Prefer header.
//
// Streaming requests should always be async and not wait for completion,
// so the Prefer header (which enables sync mode) must be excluded.
//
// The header name is matched case-insensitively because HTTP field names are
// case-insensitive; otherwise a "prefer" key would slip through. The input
// map is never modified: when it holds no Prefer header it is returned as-is
// (nil stays nil), otherwise a filtered copy is returned.
func stripPreferHeader(extraHeaders map[string]string) map[string]string {
	if extraHeaders == nil {
		return nil
	}

	hasPrefer := false
	for name := range extraHeaders {
		if strings.EqualFold(name, "Prefer") {
			hasPrefer = true
			break
		}
	}
	if !hasPrefer {
		// Nothing to strip; avoid allocating a copy.
		return extraHeaders
	}

	filtered := make(map[string]string, len(extraHeaders))
	for name, value := range extraHeaders {
		if !strings.EqualFold(name, "Prefer") {
			filtered[name] = value
		}
	}
	return filtered
}
|
||||
|
||||
// listenToReplicateStreamURL listens to a Replicate stream URL and processes SSE events.
|
||||
// This is a reusable utility for any Replicate streaming endpoint.
|
||||
// It returns the response body stream (as io.Reader) and any error that occurred during connection.
|
||||
func listenToReplicateStreamURL(
|
||||
ctx *schemas.BifrostContext,
|
||||
client *fasthttp.Client,
|
||||
streamURL string,
|
||||
key schemas.Key,
|
||||
) (io.Reader, *fasthttp.Response, *schemas.BifrostError) {
|
||||
// Create request
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
resp.StreamBody = true
|
||||
|
||||
// Set URL and headers
|
||||
req.SetRequestURI(streamURL)
|
||||
req.Header.SetMethod(http.MethodGet)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("Cache-Control", "no-cache")
|
||||
|
||||
// Set authorization header
|
||||
if value := key.Value.GetValue(); value != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+value)
|
||||
}
|
||||
|
||||
// Make request
|
||||
err := client.Do(req, resp)
|
||||
fasthttp.ReleaseRequest(req)
|
||||
|
||||
if err != nil {
|
||||
providerUtils.ReleaseStreamingResponse(resp)
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, nil, &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Type: schemas.Ptr(schemas.RequestCancelled),
|
||||
Message: schemas.ErrRequestCancelled,
|
||||
Error: err,
|
||||
},
|
||||
}
|
||||
}
|
||||
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
|
||||
}
|
||||
return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err)
|
||||
}
|
||||
|
||||
// Extract provider response headers before status check so error responses also forward them
|
||||
if ctx != nil {
|
||||
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
||||
}
|
||||
|
||||
// Check for HTTP errors
|
||||
if resp.StatusCode() != fasthttp.StatusOK {
|
||||
defer providerUtils.ReleaseStreamingResponse(resp)
|
||||
return nil, nil, parseReplicateError(resp.Body(), resp.StatusCode())
|
||||
}
|
||||
|
||||
return resp.BodyStream(), resp, nil
|
||||
}
|
||||
|
||||
// parseDataURIImage splits a data URI into its payload and MIME type.
// Example: "data:image/webp;base64,UklGRmSu..." -> ("UklGRmSu...", "image/webp")
//
// The payload is everything after the first comma and is returned verbatim;
// it is neither decoded nor validated. The MIME type is the metadata up to
// the first ';'. Input that does not start with "data:" or that contains no
// comma is returned unchanged with an empty MIME type.
func parseDataURIImage(dataURI string) (base64Data string, mimeType string) {
	rest, isDataURI := strings.CutPrefix(dataURI, "data:")
	if !isDataURI {
		return dataURI, ""
	}

	// rest looks like "image/webp;base64,<payload>".
	metadata, payload, hasComma := strings.Cut(rest, ",")
	if !hasComma {
		return dataURI, ""
	}

	mimeType, _, _ = strings.Cut(metadata, ";")
	return payload, mimeType
}
|
||||
|
||||
// versionIDPattern matches exactly 64 lowercase hexadecimal characters, the
// Replicate version ID format.
var versionIDPattern = regexp.MustCompile(`^[a-f0-9]{64}$`)

// versionIDLength is the number of characters in a Replicate version ID.
const versionIDLength = 64

// isVersionID reports whether s is a Replicate version ID (a 64-character
// lowercase hex string). The length guard rejects most non-IDs before the
// pattern is consulted; the pattern alone already implies that length.
func isVersionID(s string) bool {
	if len(s) != versionIDLength {
		return false
	}
	return versionIDPattern.MatchString(s)
}
|
||||
|
||||
// buildPredictionURL builds the appropriate URL for creating a prediction
|
||||
// Returns the URL for the appropriate prediction endpoint.
|
||||
func buildPredictionURL(ctx *schemas.BifrostContext, baseURL, model string, customProviderConfig *schemas.CustomProviderConfig, requestType schemas.RequestType, useDeploymentsEndpoint bool) string {
|
||||
var defaultPath string
|
||||
|
||||
if useDeploymentsEndpoint {
|
||||
defaultPath = "/v1/deployments/" + model + "/predictions"
|
||||
} else if isVersionID(model) {
|
||||
// If model is a version ID, use base predictions endpoint
|
||||
defaultPath = "/v1/predictions"
|
||||
} else {
|
||||
// If model is a name (owner/name), use model-specific endpoint
|
||||
defaultPath = "/v1/models/" + model + "/predictions"
|
||||
}
|
||||
|
||||
path, isCompleteURL := providerUtils.GetRequestPath(ctx, defaultPath, customProviderConfig, requestType)
|
||||
if isCompleteURL {
|
||||
return path
|
||||
}
|
||||
return baseURL + path
|
||||
}
|
||||
|
||||
// parseTokenUsageFromLogs extracts token counts from Replicate's logs field
|
||||
// Handles multiple log formats with varying levels of detail
|
||||
func parseTokenUsageFromLogs(logs *string, requestType schemas.RequestType) (inputTokens, outputTokens, totalTokens int, found bool) {
|
||||
if logs == nil || *logs == "" {
|
||||
return 0, 0, 0, false
|
||||
}
|
||||
|
||||
logText := *logs
|
||||
foundAny := false
|
||||
|
||||
// Pattern 1: Detailed format with input/output breakdown
|
||||
// "Input token count: 20"
|
||||
// "Input text token count: 15"
|
||||
inputPatterns := []string{
|
||||
`Input token count:\s*(\d+)`,
|
||||
`Input text token count:\s*(\d+)`,
|
||||
}
|
||||
for _, pattern := range inputPatterns {
|
||||
if matches := regexp.MustCompile(pattern).FindStringSubmatch(logText); len(matches) > 1 {
|
||||
if val, err := strconv.Atoi(matches[1]); err == nil {
|
||||
inputTokens = val
|
||||
foundAny = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// "Input image token count: 0" (for image generation)
|
||||
if matches := regexp.MustCompile(`Input image token count:\s*(\d+)`).FindStringSubmatch(logText); len(matches) > 1 {
|
||||
if val, err := strconv.Atoi(matches[1]); err == nil {
|
||||
inputTokens += val // Add to text input tokens
|
||||
foundAny = true
|
||||
}
|
||||
}
|
||||
|
||||
// "Output token count: 28"
|
||||
if matches := regexp.MustCompile(`Output token count:\s*(\d+)`).FindStringSubmatch(logText); len(matches) > 1 {
|
||||
if val, err := strconv.Atoi(matches[1]); err == nil {
|
||||
outputTokens = val
|
||||
foundAny = true
|
||||
}
|
||||
}
|
||||
|
||||
// "Total token count: 48"
|
||||
if matches := regexp.MustCompile(`Total token count:\s*(\d+)`).FindStringSubmatch(logText); len(matches) > 1 {
|
||||
if val, err := strconv.Atoi(matches[1]); err == nil {
|
||||
totalTokens = val
|
||||
foundAny = true
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern 2: Simple "Tokens: X" format (ambiguous - need heuristic)
|
||||
// Only use if detailed format not found
|
||||
if !foundAny {
|
||||
if matches := regexp.MustCompile(`Tokens:\s*(\d+)`).FindStringSubmatch(logText); len(matches) > 1 {
|
||||
if val, err := strconv.Atoi(matches[1]); err == nil {
|
||||
// Heuristic based on response type
|
||||
switch requestType {
|
||||
case schemas.ImageGenerationRequest:
|
||||
// For image generation, "Tokens: X" typically means output tokens
|
||||
outputTokens = val
|
||||
totalTokens = val
|
||||
case schemas.TextCompletionRequest, schemas.ChatCompletionRequest, schemas.ResponsesRequest:
|
||||
// For text, unclear - could be total or output
|
||||
// Conservative approach: treat as total tokens
|
||||
totalTokens = val
|
||||
default:
|
||||
// Unknown type - treat as total
|
||||
totalTokens = val
|
||||
}
|
||||
foundAny = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we found input/output but not total, compute it
|
||||
if foundAny && totalTokens == 0 {
|
||||
totalTokens = inputTokens + outputTokens
|
||||
}
|
||||
|
||||
return inputTokens, outputTokens, totalTokens, foundAny
|
||||
}
|
||||
152
core/providers/replicate/videos.go
Normal file
152
core/providers/replicate/videos.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package replicate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func ToReplicateVideoGenerationInput(bifrostReq *schemas.BifrostVideoGenerationRequest) (*ReplicatePredictionRequest, error) {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil {
|
||||
return nil, fmt.Errorf("bifrost request or input is nil")
|
||||
}
|
||||
|
||||
input := &ReplicatePredictionRequestInput{
|
||||
Prompt: &bifrostReq.Input.Prompt,
|
||||
}
|
||||
|
||||
if bifrostReq.Input.InputReference != nil {
|
||||
// convert input reference to base64
|
||||
// if provider is openai, set input reference to base64
|
||||
sanitizedURL, err := schemas.SanitizeImageURL(*bifrostReq.Input.InputReference)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid input reference: %w", err)
|
||||
}
|
||||
if strings.HasPrefix(bifrostReq.Model, string(schemas.OpenAI)) {
|
||||
input.InputReference = schemas.Ptr(sanitizedURL)
|
||||
} else {
|
||||
input.Image = schemas.Ptr(sanitizedURL)
|
||||
}
|
||||
}
|
||||
|
||||
// Map parameters if available
|
||||
if bifrostReq.Params != nil {
|
||||
params := bifrostReq.Params
|
||||
|
||||
if params.Seconds != nil {
|
||||
seconds, err := strconv.Atoi(*params.Seconds)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid seconds value: %w", err)
|
||||
}
|
||||
input.Duration = &seconds
|
||||
}
|
||||
|
||||
if params.Seed != nil {
|
||||
input.Seed = params.Seed
|
||||
}
|
||||
|
||||
if params.NegativePrompt != nil {
|
||||
input.NegativePrompt = params.NegativePrompt
|
||||
}
|
||||
|
||||
if params.ExtraParams != nil {
|
||||
input.ExtraParams = params.ExtraParams
|
||||
}
|
||||
}
|
||||
|
||||
request := &ReplicatePredictionRequest{
|
||||
Input: input,
|
||||
}
|
||||
|
||||
// Check if model is a version ID and set version field accordingly
|
||||
if isVersionID(bifrostReq.Model) {
|
||||
request.Version = &bifrostReq.Model
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil {
|
||||
request.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
if webhook, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["webhook"]); ok {
|
||||
delete(request.ExtraParams, "webhook")
|
||||
request.Webhook = webhook
|
||||
}
|
||||
if webhookEventsFilter, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["webhook_events_filter"]); ok {
|
||||
delete(request.ExtraParams, "webhook_events_filter")
|
||||
request.WebhookEventsFilter = webhookEventsFilter
|
||||
}
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func ToBifrostVideoGenerationResponse(prediction *ReplicatePredictionResponse) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
if prediction == nil {
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "prediction response is nil",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
response := &schemas.BifrostVideoGenerationResponse{
|
||||
ID: prediction.ID,
|
||||
CreatedAt: ParseReplicateTimestamp(prediction.CreatedAt),
|
||||
Model: prediction.Model,
|
||||
Object: "video",
|
||||
}
|
||||
|
||||
// Map Replicate status to Bifrost video status.
|
||||
switch prediction.Status {
|
||||
case ReplicatePredictionStatusStarting:
|
||||
response.Status = schemas.VideoStatusQueued
|
||||
case ReplicatePredictionStatusProcessing:
|
||||
response.Status = schemas.VideoStatusInProgress
|
||||
case ReplicatePredictionStatusSucceeded:
|
||||
response.Status = schemas.VideoStatusCompleted
|
||||
case ReplicatePredictionStatusFailed, ReplicatePredictionStatusCanceled:
|
||||
response.Status = schemas.VideoStatusFailed
|
||||
default:
|
||||
response.Status = schemas.VideoStatusQueued
|
||||
}
|
||||
|
||||
// Surface provider error details on failed terminal states.
|
||||
if response.Status == schemas.VideoStatusFailed {
|
||||
errorMsg := "prediction failed"
|
||||
errorCode := string(prediction.Status)
|
||||
if prediction.Error != nil && *prediction.Error != "" {
|
||||
errorMsg = *prediction.Error
|
||||
}
|
||||
response.Error = &schemas.VideoCreateError{
|
||||
Code: errorCode,
|
||||
Message: errorMsg,
|
||||
}
|
||||
}
|
||||
|
||||
if prediction.CompletedAt != nil {
|
||||
response.CompletedAt = schemas.Ptr(ParseReplicateTimestamp(*prediction.CompletedAt))
|
||||
}
|
||||
|
||||
// Convert output to ImageData
|
||||
// Replicate output can be either a string (single URL) or array of strings
|
||||
if prediction.Output != nil {
|
||||
if prediction.Output.OutputStr != nil && *prediction.Output.OutputStr != "" {
|
||||
response.Videos = append(response.Videos, schemas.VideoOutput{
|
||||
Type: schemas.VideoOutputTypeURL,
|
||||
URL: schemas.Ptr(*prediction.Output.OutputStr),
|
||||
ContentType: "video/mp4",
|
||||
})
|
||||
} else if len(prediction.Output.OutputArray) > 0 {
|
||||
for _, url := range prediction.Output.OutputArray {
|
||||
response.Videos = append(response.Videos, schemas.VideoOutput{
|
||||
Type: schemas.VideoOutputTypeURL,
|
||||
URL: schemas.Ptr(url),
|
||||
ContentType: "video/mp4",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
Reference in New Issue
Block a user