first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,315 @@
package replicate
import (
"fmt"
"slices"
"strings"
"time"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// unsupportedSystemPromptModels lists the "owner/name" identifiers of models
// that don't support the system_prompt field.
// supportsSystemPrompt lowercases the requested model and strips any
// ":version" suffix before searching this slice, so entries must be written in
// lowercase and without a version.
var unsupportedSystemPromptModels = []string{
"meta/meta-llama-3-8b",
"meta/llama-2-70b",
"openai/gpt-oss-20b",
"openai/o1-mini",
"xai/grok-4",
}
// ToReplicateChatRequest converts a Bifrost chat request into a Replicate
// prediction request.
//
// OpenAI-prefixed models receive the conversation in structured form: as
// Responses-style input items when the model name contains "gpt-5-structured",
// and as chat messages otherwise. Every other model (and any request whose
// input is empty) goes through the flattening path: user and assistant text is
// joined into a single prompt, system text is sent as system_prompt when the
// model supports it or prepended to the prompt when it does not, and image
// URLs found in content blocks are collected into image_input.
//
// An error is returned when the request or its input is nil, or when the
// flattening path produces neither a prompt nor a system prompt.
func ToReplicateChatRequest(bifrostReq *schemas.BifrostChatRequest) (*ReplicatePredictionRequest, error) {
	if bifrostReq == nil || bifrostReq.Input == nil {
		return nil, fmt.Errorf("bifrost request is nil or input is nil")
	}

	model := bifrostReq.Model
	isOpenAI := strings.HasPrefix(model, string(schemas.OpenAI))
	isGPT5Structured := isOpenAI && strings.Contains(model, "gpt-5-structured")
	hasMessages := len(bifrostReq.Input) > 0

	input := &ReplicatePredictionRequestInput{}

	switch {
	case hasMessages && isGPT5Structured:
		// Structured GPT-5 models take Responses-style input items.
		var items []schemas.ResponsesMessage
		for _, msg := range bifrostReq.Input {
			items = append(items, msg.ToResponsesMessages()...)
		}
		if len(items) > 0 {
			input.InputItemList = items
		}

	case hasMessages && isOpenAI:
		// Other OpenAI models accept chat messages as-is.
		input.Messages = bifrostReq.Input

	default:
		// Flatten the conversation into prompt / system prompt / image list.
		var (
			systemParts []string
			turns       []string
			imageURLs   []string
		)
		for _, msg := range bifrostReq.Input {
			if msg.Content == nil {
				continue
			}

			text := ""
			switch {
			case msg.Content.ContentStr != nil:
				text = *msg.Content.ContentStr
			case msg.Content.ContentBlocks != nil:
				// Text blocks become the message text; image URLs are
				// collected separately.
				var pieces []string
				for _, block := range msg.Content.ContentBlocks {
					if block.Text != nil && *block.Text != "" {
						pieces = append(pieces, *block.Text)
					}
					if block.ImageURLStruct != nil && block.ImageURLStruct.URL != "" {
						imageURLs = append(imageURLs, block.ImageURLStruct.URL)
					}
				}
				text = strings.Join(pieces, "\n")
			}
			if text == "" {
				continue
			}

			// Messages with any other role are left out of the prompt.
			switch msg.Role {
			case schemas.ChatMessageRoleSystem:
				systemParts = append(systemParts, text)
			case schemas.ChatMessageRoleUser, schemas.ChatMessageRoleAssistant:
				turns = append(turns, text)
			}
		}

		if len(systemParts) > 0 {
			systemPrompt := strings.Join(systemParts, "\n")
			if supportsSystemPrompt(model) {
				input.SystemPrompt = &systemPrompt
			} else {
				// No system_prompt field: the system text leads the prompt.
				turns = append([]string{systemPrompt}, turns...)
			}
		}

		if len(turns) > 0 {
			prompt := strings.Join(turns, "\n\n")
			input.Prompt = &prompt
		}

		// Ensure we have at least some content (prompt or system prompt).
		if input.Prompt == nil && input.SystemPrompt == nil {
			return nil, fmt.Errorf("no content found in chat messages - need at least one user or system message")
		}

		if len(imageURLs) > 0 {
			input.ImageInput = imageURLs
		}
	}

	if params := bifrostReq.Params; params != nil {
		if params.Temperature != nil {
			input.Temperature = params.Temperature
		}
		if params.TopP != nil {
			input.TopP = params.TopP
		}
		// The token limit goes into a different input field per model family.
		if params.MaxCompletionTokens != nil {
			switch {
			case isGPT5Structured:
				input.MaxOutputTokens = params.MaxCompletionTokens
			case isOpenAI:
				input.MaxCompletionTokens = params.MaxCompletionTokens
			default:
				input.MaxTokens = params.MaxCompletionTokens
			}
		}
		if params.PresencePenalty != nil {
			input.PresencePenalty = params.PresencePenalty
		}
		if params.FrequencyPenalty != nil {
			input.FrequencyPenalty = params.FrequencyPenalty
		}
		if params.Seed != nil {
			input.Seed = params.Seed
		}
		if params.Reasoning != nil && params.Reasoning.Effort != nil {
			input.ReasoningEffort = params.Reasoning.Effort
		}
		// Tools are only forwarded to structured GPT-5 models.
		if isGPT5Structured && len(params.Tools) > 0 {
			tools := make([]schemas.ResponsesTool, 0, len(params.Tools))
			for _, tool := range params.Tools {
				tools = append(tools, *tool.ToResponsesTool())
			}
			input.Tools = tools
		}
		if params.ExtraParams != nil {
			input.ExtraParams = params.ExtraParams
		}
	}

	req := &ReplicatePredictionRequest{
		Input: input,
	}
	// The version field is only set when the model string is a version ID.
	if isVersionID(model) {
		req.Version = &bifrostReq.Model
	}

	// Webhook settings live on the request itself rather than in the input.
	if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil {
		extra := bifrostReq.Params.ExtraParams
		if webhook, ok := schemas.SafeExtractStringPointer(extra["webhook"]); ok {
			req.Webhook = webhook
		}
		if events, ok := schemas.SafeExtractStringSlice(extra["webhook_events_filter"]); ok {
			req.WebhookEventsFilter = events
		}
	}

	return req, nil
}
// ToBifrostChatResponse converts a Replicate prediction response into a
// Bifrost chat completion response holding a single assistant choice.
//
// The message text is taken from the string output, the concatenated string
// array output, or the text field of the object output, in that order of
// preference. Succeeded and canceled predictions finish with "stop", failed
// ones with "error"; any other status leaves the finish reason unset. Token
// usage is attached when it can be parsed from the prediction logs.
//
// A nil receiver yields nil.
func (response *ReplicatePredictionResponse) ToBifrostChatResponse() *schemas.BifrostChatResponse {
	if response == nil {
		return nil
	}

	// Fall back to the current time when the timestamp parses to zero.
	created := ParseReplicateTimestamp(response.CreatedAt)
	if created == 0 {
		created = time.Now().Unix()
	}

	var text *string
	if out := response.Output; out != nil {
		switch {
		case out.OutputStr != nil:
			text = out.OutputStr
		case out.OutputArray != nil:
			joined := strings.Join(out.OutputArray, "")
			text = &joined
		case out.OutputObject != nil && out.OutputObject.Text != nil:
			text = out.OutputObject.Text
		}
	}

	var finishReason *string
	switch response.Status {
	case ReplicatePredictionStatusSucceeded, ReplicatePredictionStatusCanceled:
		stop := "stop"
		finishReason = &stop
	case ReplicatePredictionStatusFailed:
		failed := "error"
		finishReason = &failed
	}

	bifrostResponse := &schemas.BifrostChatResponse{
		ID:      response.ID,
		Model:   response.Model,
		Object:  "chat.completion",
		Created: int(created),
		Choices: []schemas.BifrostResponseChoice{
			{
				Index: 0,
				ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
					Message: &schemas.ChatMessage{
						Role:    schemas.ChatMessageRoleAssistant,
						Content: &schemas.ChatMessageContent{ContentStr: text},
					},
				},
				FinishReason: finishReason,
			},
		},
	}

	if response.Logs == nil {
		return bifrostResponse
	}
	promptTokens, completionTokens, totalTokens, ok := parseTokenUsageFromLogs(response.Logs, schemas.ChatCompletionRequest)
	if ok {
		bifrostResponse.Usage = &schemas.BifrostLLMUsage{
			PromptTokens:     promptTokens,
			CompletionTokens: completionTokens,
			TotalTokens:      totalTokens,
		}
	}
	return bifrostResponse
}
// supportsSystemPrompt reports whether the given model accepts the
// system_prompt field. The model is compared case-insensitively and any
// ":version" suffix is ignored.
func supportsSystemPrompt(model string) bool {
	// Reduce "Owner/Name:version" to a lowercase "owner/name".
	identifier, _, _ := strings.Cut(strings.ToLower(model), ":")

	// All deepseek models don't support system prompt.
	if strings.HasPrefix(identifier, "deepseek-ai/deepseek") {
		return false
	}

	return !slices.Contains(unsupportedSystemPromptModels, identifier)
}

View File

@@ -0,0 +1,29 @@
package replicate
import (
"github.com/bytedance/sonic"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// parseReplicateError builds a Bifrost error from a Replicate API error
// response. The message is the "detail" value when the body decodes into a
// ReplicateError with a non-empty detail; otherwise the raw body is used.
func parseReplicateError(body []byte, statusCode int) *schemas.BifrostError {
	// Start from the generic fallback and upgrade it if the body is structured.
	message := string(body)

	var parsed ReplicateError
	if err := sonic.Unmarshal(body, &parsed); err == nil && parsed.Detail != "" {
		message = parsed.Detail
	}

	return &schemas.BifrostError{
		IsBifrostError: false,
		StatusCode:     &statusCode,
		Error: &schemas.ErrorField{
			Message: message,
		},
	}
}

View File

@@ -0,0 +1,89 @@
package replicate
import (
"time"
"github.com/maximhq/bifrost/core/schemas"
)
// Replicate File API Converters
// ToBifrostFileStatus derives a Bifrost file status from a Replicate file
// response. Replicate doesn't explicitly provide a status, so it is inferred:
// a file with a non-empty ID and a positive size counts as processed, anything
// else as uploaded.
func ToBifrostFileStatus(fileResp *ReplicateFileResponse) schemas.FileStatus {
	if fileResp.ID == "" || fileResp.Size <= 0 {
		return schemas.FileStatusUploaded
	}
	return schemas.FileStatusProcessed
}
// ToBifrostFileUploadResponse converts a Replicate file response to a Bifrost
// file upload response.
//
// latency is reported in milliseconds. rawRequest and rawResponse are attached
// to the extra fields only when the matching sendBack flag is set. ExpiresAt is
// populated only when the Replicate value is present and parses to a positive
// timestamp. providerName is accepted for signature compatibility and is not
// used here.
//
// A nil receiver yields nil, matching the converters on
// ReplicatePredictionResponse, instead of panicking on the first field access.
func (r *ReplicateFileResponse) ToBifrostFileUploadResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse {
	if r == nil {
		return nil
	}

	resp := &schemas.BifrostFileUploadResponse{
		ID:             r.ID,
		Object:         "file",
		Bytes:          r.Size,
		CreatedAt:      ParseReplicateTimestamp(r.CreatedAt),
		Filename:       r.Name,
		Purpose:        schemas.FilePurposeBatch, // Replicate uses files primarily for batch/general purposes
		Status:         ToBifrostFileStatus(r),
		StorageBackend: schemas.FileStorageAPI,
		ExtraFields: schemas.BifrostResponseExtraFields{
			Latency: latency.Milliseconds(),
		},
	}

	// Add ExpiresAt if present and parseable.
	if r.ExpiresAt != "" {
		if expiresAt := ParseReplicateTimestamp(r.ExpiresAt); expiresAt > 0 {
			resp.ExpiresAt = &expiresAt
		}
	}

	if sendBackRawRequest {
		resp.ExtraFields.RawRequest = rawRequest
	}
	if sendBackRawResponse {
		resp.ExtraFields.RawResponse = rawResponse
	}

	return resp
}
// ToBifrostFileRetrieveResponse converts a Replicate file response to a
// Bifrost file retrieve response.
//
// latency is reported in milliseconds. rawRequest and rawResponse are attached
// to the extra fields only when the matching sendBack flag is set. ExpiresAt is
// populated only when the Replicate value is present and parses to a positive
// timestamp. providerName is accepted for signature compatibility and is not
// used here.
//
// A nil receiver yields nil, matching the converters on
// ReplicatePredictionResponse, instead of panicking on the first field access.
func (r *ReplicateFileResponse) ToBifrostFileRetrieveResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileRetrieveResponse {
	if r == nil {
		return nil
	}

	resp := &schemas.BifrostFileRetrieveResponse{
		ID:             r.ID,
		Object:         "file",
		Bytes:          r.Size,
		CreatedAt:      ParseReplicateTimestamp(r.CreatedAt),
		Filename:       r.Name,
		Purpose:        schemas.FilePurposeBatch,
		Status:         ToBifrostFileStatus(r),
		StorageBackend: schemas.FileStorageAPI,
		ExtraFields: schemas.BifrostResponseExtraFields{
			Latency: latency.Milliseconds(),
		},
	}

	// Add ExpiresAt if present and parseable.
	if r.ExpiresAt != "" {
		if expiresAt := ParseReplicateTimestamp(r.ExpiresAt); expiresAt > 0 {
			resp.ExpiresAt = &expiresAt
		}
	}

	if sendBackRawRequest {
		resp.ExtraFields.RawRequest = rawRequest
	}
	if sendBackRawResponse {
		resp.ExtraFields.RawResponse = rawResponse
	}

	return resp
}

View File

@@ -0,0 +1,292 @@
package replicate
import (
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// modelInputImageFieldMap maps "owner/name" model identifiers to the name of
// the input field that carries their source image.
// getInputImageFieldName lowercases the requested model and strips any
// ":version" suffix before looking it up, so keys must be lowercase and
// unversioned. Models that are not listed fall back to "input_images".
var modelInputImageFieldMap = map[string]string{
// Models that take the image in the "image_prompt" field.
"black-forest-labs/flux-1.1-pro": "image_prompt",
"black-forest-labs/flux-1.1-pro-ultra": "image_prompt",
"black-forest-labs/flux-pro": "image_prompt",
"black-forest-labs/flux-1.1-pro-ultra-finetuned": "image_prompt",
// Models that take the image in the "input_image" field (kontext variants).
"black-forest-labs/flux-kontext-pro": "input_image",
"black-forest-labs/flux-kontext-max": "input_image",
"black-forest-labs/flux-kontext-dev": "input_image",
// Models that take the image in the "image" field.
"black-forest-labs/flux-dev": "image",
"black-forest-labs/flux-fill-pro": "image",
"black-forest-labs/flux-dev-lora": "image",
"black-forest-labs/flux-krea-dev": "image",
}
// ToReplicateImageGenerationInput converts a Bifrost image generation request
// into a Replicate prediction request.
//
// The prompt is always forwarded. An explicit aspect ratio parameter wins over
// one derived from the size parameter, and a resolution derived from size is
// only applied when ExtraParams carries no "resolution" key. The "webhook" and
// "webhook_events_filter" extra params are additionally lifted onto the request
// itself. Returns nil when the request or its input is nil.
func ToReplicateImageGenerationInput(bifrostReq *schemas.BifrostImageGenerationRequest) *ReplicatePredictionRequest {
	if bifrostReq == nil || bifrostReq.Input == nil {
		return nil
	}

	input := &ReplicatePredictionRequestInput{
		Prompt: &bifrostReq.Input.Prompt,
	}
	request := &ReplicatePredictionRequest{
		Input: input,
	}
	// The version field is only set when the model string is a version ID.
	if isVersionID(bifrostReq.Model) {
		request.Version = &bifrostReq.Model
	}

	params := bifrostReq.Params
	if params == nil {
		return request
	}

	if params.N != nil {
		input.NumberOfImages = params.N
	}
	if params.AspectRatio != nil {
		input.AspectRatio = params.AspectRatio
	}
	if params.Size != nil {
		derivedRatio, derivedResolution := providerUtils.ConvertSizeToAspectRatioAndResolution(*params.Size)
		if params.AspectRatio == nil && derivedRatio != "" {
			input.AspectRatio = &derivedRatio
		}
		if _, explicit := params.ExtraParams["resolution"]; !explicit && derivedResolution != "" {
			input.Resolution = &derivedResolution
		}
	}
	if params.OutputFormat != nil {
		input.OutputFormat = params.OutputFormat
	}
	if params.Quality != nil {
		input.Quality = params.Quality
	}
	if params.Background != nil {
		input.Background = params.Background
	}
	if params.Seed != nil {
		input.Seed = params.Seed
	}
	if params.NegativePrompt != nil {
		input.NegativePrompt = params.NegativePrompt
	}
	if params.NumInferenceSteps != nil {
		input.NumInferenceStep = params.NumInferenceSteps
	}

	if extra := params.ExtraParams; extra != nil {
		input.ExtraParams = extra
		// Webhook settings live on the request itself.
		if webhook, ok := schemas.SafeExtractStringPointer(extra["webhook"]); ok {
			request.Webhook = webhook
		}
		if events, ok := schemas.SafeExtractStringSlice(extra["webhook_events_filter"]); ok {
			request.WebhookEventsFilter = events
		}
	}

	return request
}
// ToBifrostImageGenerationResponse converts a Replicate prediction response
// into a Bifrost image generation response.
//
// A non-empty string output becomes a single image at index 0; otherwise each
// entry of a non-empty array output becomes one image, indexed by position.
// Data is always a non-nil slice, even when no image was produced. Token usage
// is attached when it can be parsed from the prediction logs. A nil prediction
// produces a Bifrost error.
func ToBifrostImageGenerationResponse(
	prediction *ReplicatePredictionResponse,
) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	if prediction == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: true,
			Error: &schemas.ErrorField{
				Message: "prediction response is nil",
			},
		}
	}

	// Replicate output can be either a single URL or an array of URLs.
	images := []schemas.ImageData{}
	if out := prediction.Output; out != nil {
		switch {
		case out.OutputStr != nil && *out.OutputStr != "":
			images = append(images, schemas.ImageData{
				URL:   *out.OutputStr,
				Index: 0,
			})
		case len(out.OutputArray) > 0:
			for idx, imageURL := range out.OutputArray {
				images = append(images, schemas.ImageData{
					URL:   imageURL,
					Index: idx,
				})
			}
		}
	}

	response := &schemas.BifrostImageGenerationResponse{
		ID:      prediction.ID,
		Created: ParseReplicateTimestamp(prediction.CreatedAt),
		Model:   prediction.Model,
		Data:    images,
	}

	if prediction.Logs == nil {
		return response, nil
	}
	inTokens, outTokens, allTokens, ok := parseTokenUsageFromLogs(prediction.Logs, schemas.ImageGenerationRequest)
	if ok {
		response.Usage = &schemas.ImageUsage{
			InputTokens:  inTokens,
			OutputTokens: outTokens,
			TotalTokens:  allTokens,
		}
	}
	return response, nil
}
// getInputImageFieldName returns the name of the input field that should carry
// the source image(s) for the given model. The model is matched
// case-insensitively with any ":version" suffix removed; models without an
// entry in modelInputImageFieldMap use "input_images".
func getInputImageFieldName(model string) string {
	// Reduce "Owner/Name:version" to a lowercase "owner/name".
	identifier, _, _ := strings.Cut(strings.ToLower(model), ":")

	fieldName, known := modelInputImageFieldMap[identifier]
	if !known {
		// Default to input_images for all other models.
		return "input_images"
	}
	return fieldName
}
// ToReplicateImageEditInput converts a Bifrost image edit request into a
// Replicate prediction request.
//
// Source images are sent inline as base64 data URLs built from the image
// bytes; images with no bytes are skipped. The input field that receives them
// depends on the model (see getInputImageFieldName): models with a
// single-image field get only the first image, all others get the full list
// under input_images.
//
// A size parameter is translated into aspect_ratio / resolution unless the
// caller already supplied those keys in ExtraParams. The "webhook" and
// "webhook_events_filter" extra params are lifted onto the request itself, as
// the other request converters in this package do.
//
// Returns nil when the request or its input is nil.
func ToReplicateImageEditInput(bifrostReq *schemas.BifrostImageEditRequest) *ReplicatePredictionRequest {
	if bifrostReq == nil || bifrostReq.Input == nil {
		return nil
	}

	input := &ReplicatePredictionRequestInput{
		Prompt: &bifrostReq.Input.Prompt,
	}

	// Encode the image bytes as base64 data URLs.
	if len(bifrostReq.Input.Images) > 0 {
		images := make([]string, 0, len(bifrostReq.Input.Images))
		for _, img := range bifrostReq.Input.Images {
			if len(img.Image) > 0 {
				images = append(images, providerUtils.FileBytesToBase64DataURL(img.Image))
			}
		}

		if len(images) > 0 {
			// Determine the appropriate field based on model.
			switch getInputImageFieldName(bifrostReq.Model) {
			case "image_prompt":
				// For flux-1.1-pro variants: use first image as image_prompt.
				input.ImagePrompt = &images[0]
			case "input_image":
				// For flux-kontext variants: use first image as input_image.
				input.InputImage = &images[0]
			case "image":
				// For flux-dev variants: use first image as image field.
				input.Image = &images[0]
			case "input_images":
				// For all other models: use input_images array (preserves multi-image support).
				input.InputImages = images
			}
		}
	}

	// Map parameters if available.
	if bifrostReq.Params != nil {
		params := bifrostReq.Params

		if params.N != nil {
			input.NumberOfImages = params.N
		}

		if params.Size != nil {
			aspectRatio, imageSize := providerUtils.ConvertSizeToAspectRatioAndResolution(*params.Size)
			// Explicit extra params take precedence over values derived from size.
			_, hasExplicitAspectRatio := params.ExtraParams["aspect_ratio"]
			_, hasExplicitResolution := params.ExtraParams["resolution"]
			if aspectRatio != "" && !hasExplicitAspectRatio {
				input.AspectRatio = &aspectRatio
			}
			if imageSize != "" && !hasExplicitResolution {
				input.Resolution = &imageSize
			}
		}

		if params.OutputFormat != nil {
			input.OutputFormat = params.OutputFormat
		}
		if params.Quality != nil {
			input.Quality = params.Quality
		}
		if params.Background != nil {
			input.Background = params.Background
		}
		if params.Seed != nil {
			input.Seed = params.Seed
		}
		if params.NegativePrompt != nil {
			input.NegativePrompt = params.NegativePrompt
		}
		if params.NumInferenceSteps != nil {
			input.NumInferenceStep = params.NumInferenceSteps
		}
		if params.ExtraParams != nil {
			input.ExtraParams = params.ExtraParams
		}
	}

	request := &ReplicatePredictionRequest{
		Input: input,
	}

	// Check if model is a version ID and set version field accordingly.
	if isVersionID(bifrostReq.Model) {
		request.Version = &bifrostReq.Model
	}

	// Webhook settings belong on the request itself, not only in the input.
	if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil {
		if webhook, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["webhook"]); ok {
			request.Webhook = webhook
		}
		if webhookEventsFilter, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["webhook_events_filter"]); ok {
			request.WebhookEventsFilter = webhookEventsFilter
		}
	}

	return request
}

View File

@@ -0,0 +1,75 @@
package replicate
import (
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBifrostListModelsResponse converts Replicate deployments to a Bifrost list
// models response.
//
// Replicate model IDs are composite: "{owner}/{name}" (e.g.
// "stability-ai/stable-diffusion"). Each deployment ID is run through the
// list-models pipeline built from the allow list, block list and aliases, and
// every result it yields becomes one model entry prefixed with the provider
// key. The pagination cursor is copied when present, and the pipeline's
// backfilled models are appended at the end.
func ToBifrostListModelsResponse(
	deploymentsResponse *ReplicateDeploymentListResponse,
	providerKey schemas.ModelProvider,
	allowedModels schemas.WhiteList,
	blacklistedModels schemas.BlackList,
	aliases map[string]string,
	unfiltered bool,
) *schemas.BifrostListModelsResponse {
	bifrostResponse := &schemas.BifrostListModelsResponse{
		Data: make([]schemas.Model, 0),
	}

	pipeline := &providerUtils.ListModelsPipeline{
		AllowedModels:     allowedModels,
		BlacklistedModels: blacklistedModels,
		Aliases:           aliases,
		Unfiltered:        unfiltered,
		ProviderKey:       providerKey,
		MatchFns:          providerUtils.DefaultMatchFns(),
	}
	if pipeline.ShouldEarlyExit() {
		return bifrostResponse
	}

	// Lowercased resolved IDs already emitted, handed to the backfill step.
	seen := make(map[string]bool)

	if deploymentsResponse != nil {
		prefix := string(providerKey) + "/"

		for _, deployment := range deploymentsResponse.Results {
			// Only a positive release timestamp is reported as the creation time.
			var created *int64
			if release := deployment.CurrentRelease; release != nil && release.CreatedAt != "" {
				if ts := ParseReplicateTimestamp(release.CreatedAt); ts > 0 {
					created = schemas.Ptr(ts)
				}
			}

			// Replicate model IDs are composite owner/name.
			for _, match := range pipeline.FilterModel(deployment.Owner + "/" + deployment.Name) {
				entry := schemas.Model{
					ID:      prefix + match.ResolvedID,
					Name:    schemas.Ptr(deployment.Name),
					OwnedBy: schemas.Ptr(deployment.Owner),
					Created: created,
				}
				if match.AliasValue != "" {
					entry.Alias = schemas.Ptr(match.AliasValue)
				}
				bifrostResponse.Data = append(bifrostResponse.Data, entry)
				seen[strings.ToLower(match.ResolvedID)] = true
			}
		}

		if next := deploymentsResponse.Next; next != nil {
			bifrostResponse.NextPageToken = *next
		}
	}

	bifrostResponse.Data = append(bifrostResponse.Data, pipeline.BackfillModels(seen)...)
	return bifrostResponse
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,300 @@
package replicate
import (
"fmt"
"strings"
"time"
"github.com/maximhq/bifrost/core/schemas"
)
// ToReplicateResponsesRequest converts a Bifrost Responses request into a
// Replicate prediction request.
//
// Models prefixed "openai/" whose name contains "gpt-5-structured" are sent in
// Responses form: input items, instructions, tools, output token limit, text
// config and extra params are passed through. All other models are handled
// like chat requests: OpenAI-prefixed models receive chat messages, the rest
// get a flattened prompt with system text routed to system_prompt when the
// model supports it (or prepended to the prompt otherwise) and image URLs
// collected into image_input. Instructions are then used as the system prompt
// if none was set, or prepended to the prompt for models without system prompt
// support.
//
// An error is returned only when the request is nil.
func ToReplicateResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (*ReplicatePredictionRequest, error) {
	if bifrostReq == nil {
		return nil, fmt.Errorf("bifrost request is nil")
	}

	model := bifrostReq.Model
	params := bifrostReq.Params
	input := &ReplicatePredictionRequestInput{}

	if strings.HasPrefix(model, "openai/") && strings.Contains(model, "gpt-5-structured") {
		// Responses-style request: fields map across directly.
		if len(bifrostReq.Input) > 0 {
			input.InputItemList = bifrostReq.Input
		}
		if params != nil {
			if params.Instructions != nil {
				input.Instructions = params.Instructions
			}
			if params.Tools != nil {
				input.Tools = params.Tools
			}
			if params.MaxOutputTokens != nil {
				input.MaxOutputTokens = params.MaxOutputTokens
			}
			if params.Text != nil {
				input.JsonSchema = params.Text
			}
			if params.ExtraParams != nil {
				input.ExtraParams = params.ExtraParams
			}
		}
	} else {
		// Chat-style request (same logic as the chat converter).
		isOpenAI := strings.HasPrefix(model, string(schemas.OpenAI))

		switch {
		case len(bifrostReq.Input) == 0:
			// Nothing to convert.

		case isOpenAI:
			// OpenAI-family models take chat messages.
			input.Messages = schemas.ToChatMessages(bifrostReq.Input)

		default:
			// Flatten the input into prompt / system prompt / image list.
			var (
				systemParts []string
				turns       []string
				imageURLs   []string
			)
			for _, msg := range bifrostReq.Input {
				if msg.Content == nil {
					continue
				}

				text := ""
				switch {
				case msg.Content.ContentStr != nil:
					text = *msg.Content.ContentStr
				case msg.Content.ContentBlocks != nil:
					// Text blocks become the message text; image URLs are
					// collected separately.
					var pieces []string
					for _, block := range msg.Content.ContentBlocks {
						if block.Text != nil && *block.Text != "" {
							pieces = append(pieces, *block.Text)
						}
						if img := block.ResponsesInputMessageContentBlockImage; img != nil && img.ImageURL != nil && *img.ImageURL != "" {
							imageURLs = append(imageURLs, *img.ImageURL)
						}
					}
					text = strings.Join(pieces, "\n")
				}

				// Messages without text or without a role contribute nothing.
				if text == "" || msg.Role == nil {
					continue
				}

				switch *msg.Role {
				case schemas.ResponsesInputMessageRoleSystem:
					systemParts = append(systemParts, text)
				case schemas.ResponsesInputMessageRoleUser, schemas.ResponsesInputMessageRoleAssistant:
					turns = append(turns, text)
				}
			}

			if len(systemParts) > 0 {
				systemPrompt := strings.Join(systemParts, "\n")
				if supportsSystemPrompt(model) {
					input.SystemPrompt = &systemPrompt
				} else {
					// No system_prompt field: the system text leads the prompt.
					turns = append([]string{systemPrompt}, turns...)
				}
			}

			if len(turns) > 0 {
				prompt := strings.Join(turns, "\n\n")
				input.Prompt = &prompt
			}

			if len(imageURLs) > 0 {
				input.ImageInput = imageURLs
			}
		}

		if params != nil {
			if params.Temperature != nil {
				input.Temperature = params.Temperature
			}
			if params.TopP != nil {
				input.TopP = params.TopP
			}
			// The output token limit goes into a different field per model family.
			if params.MaxOutputTokens != nil {
				if isOpenAI {
					input.MaxCompletionTokens = params.MaxOutputTokens
				} else {
					input.MaxTokens = params.MaxOutputTokens
				}
			}
			if params.Reasoning != nil && params.Reasoning.Effort != nil {
				input.ReasoningEffort = params.Reasoning.Effort
			}

			if params.Instructions != nil && *params.Instructions != "" {
				switch {
				case supportsSystemPrompt(model):
					// A system prompt taken from the input wins over instructions.
					if input.SystemPrompt == nil {
						input.SystemPrompt = params.Instructions
					}
				case input.Prompt == nil:
					input.Prompt = params.Instructions
				case *input.Prompt != "":
					input.Prompt = schemas.Ptr(*params.Instructions + "\n\n" + *input.Prompt)
				}
			}

			if params.ExtraParams != nil {
				input.ExtraParams = params.ExtraParams
			}
		}
	}

	req := &ReplicatePredictionRequest{
		Input: input,
	}
	// The version field is only set when the model string is a version ID.
	if isVersionID(model) {
		req.Version = &bifrostReq.Model
	}

	// Webhook settings live on the request itself rather than in the input.
	if params != nil && params.ExtraParams != nil {
		if webhook, ok := schemas.SafeExtractStringPointer(params.ExtraParams["webhook"]); ok {
			req.Webhook = webhook
		}
		if events, ok := schemas.SafeExtractStringSlice(params.ExtraParams["webhook_events_filter"]); ok {
			req.WebhookEventsFilter = events
		}
	}

	return req, nil
}
// ToBifrostResponsesResponse converts a Replicate prediction response into a
// Bifrost Responses response.
//
// Non-empty output text (string output, concatenated string array, or the text
// field of an object output, in that order of preference) becomes a single
// assistant message. The prediction status is mapped to "completed", "failed",
// "cancelled", "in_progress" or "queued"; any other status is passed through
// unchanged. A non-empty prediction error is reported with code
// "provider_error", and token usage is attached when it can be parsed from the
// logs.
//
// A nil receiver yields nil.
func (response *ReplicatePredictionResponse) ToBifrostResponsesResponse() *schemas.BifrostResponsesResponse {
	if response == nil {
		return nil
	}

	// Fall back to the current time when the timestamp parses to zero.
	created := ParseReplicateTimestamp(response.CreatedAt)
	if created == 0 {
		created = time.Now().Unix()
	}

	bifrostResponse := &schemas.BifrostResponsesResponse{
		ID:        schemas.Ptr(response.ID),
		Model:     response.Model,
		CreatedAt: int(created),
	}

	// Only a positive completion timestamp is reported.
	if response.CompletedAt != nil {
		if completed := int(ParseReplicateTimestamp(*response.CompletedAt)); completed > 0 {
			bifrostResponse.CompletedAt = &completed
		}
	}

	if out := response.Output; out != nil {
		var text *string
		switch {
		case out.OutputStr != nil:
			text = out.OutputStr
		case out.OutputArray != nil:
			joined := strings.Join(out.OutputArray, "")
			text = &joined
		case out.OutputObject != nil && out.OutputObject.Text != nil:
			text = out.OutputObject.Text
		}

		if text != nil && *text != "" {
			messageType := schemas.ResponsesMessageTypeMessage
			role := schemas.ResponsesInputMessageRoleAssistant
			bifrostResponse.Output = []schemas.ResponsesMessage{
				{
					Type: &messageType,
					Role: &role,
					Content: &schemas.ResponsesMessageContent{
						ContentStr: text,
					},
				},
			}
		}
	}

	// Unknown statuses are passed through verbatim.
	status := string(response.Status)
	switch response.Status {
	case ReplicatePredictionStatusSucceeded:
		status = "completed"
	case ReplicatePredictionStatusFailed:
		status = "failed"
	case ReplicatePredictionStatusCanceled:
		status = "cancelled"
	case ReplicatePredictionStatusProcessing:
		status = "in_progress"
	case ReplicatePredictionStatusStarting:
		status = "queued"
	}
	bifrostResponse.Status = &status

	if response.Error != nil && *response.Error != "" {
		bifrostResponse.Error = &schemas.ResponsesResponseError{
			Code:    "provider_error",
			Message: *response.Error,
		}
	}

	if response.Logs == nil {
		return bifrostResponse
	}
	inTokens, outTokens, allTokens, ok := parseTokenUsageFromLogs(response.Logs, schemas.ResponsesRequest)
	if ok {
		bifrostResponse.Usage = &schemas.ResponsesResponseUsage{
			InputTokens:  inTokens,
			OutputTokens: outTokens,
			TotalTokens:  allTokens,
		}
	}
	return bifrostResponse
}

View File

@@ -0,0 +1,148 @@
package replicate
import (
"fmt"
"strings"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// ToReplicateTextRequest converts a Bifrost text completion request into a
// Replicate prediction request.
//
// The prompt is the string prompt when set, otherwise the prompt array joined
// with newlines. Sampling parameters are copied when provided, top_k is read
// from ExtraParams, and the "webhook" / "webhook_events_filter" extra params
// are additionally lifted onto the request itself.
//
// An error is returned when the request or its input is nil.
func ToReplicateTextRequest(bifrostReq *schemas.BifrostTextCompletionRequest) (*ReplicatePredictionRequest, error) {
	if bifrostReq == nil || bifrostReq.Input == nil {
		return nil, fmt.Errorf("bifrost request is nil or prompt is nil")
	}

	input := &ReplicatePredictionRequestInput{}
	req := &ReplicatePredictionRequest{
		Input: input,
	}

	switch src := bifrostReq.Input; {
	case src.PromptStr != nil:
		input.Prompt = src.PromptStr
	case len(src.PromptArray) > 0:
		joined := strings.Join(src.PromptArray, "\n")
		input.Prompt = &joined
	}

	if params := bifrostReq.Params; params != nil {
		if params.Temperature != nil {
			input.Temperature = params.Temperature
		}
		if params.TopP != nil {
			input.TopP = params.TopP
		}
		if params.MaxTokens != nil {
			input.MaxTokens = params.MaxTokens
		}
		if params.PresencePenalty != nil {
			input.PresencePenalty = params.PresencePenalty
		}
		if params.FrequencyPenalty != nil {
			input.FrequencyPenalty = params.FrequencyPenalty
		}
		// Top K has no dedicated parameter; it comes from ExtraParams.
		if topK, ok := schemas.SafeExtractIntPointer(params.ExtraParams["top_k"]); ok {
			input.TopK = topK
		}
		if params.Seed != nil {
			input.Seed = params.Seed
		}
		if params.ExtraParams != nil {
			input.ExtraParams = params.ExtraParams
		}
	}

	// The version field is only set when the model string is a version ID.
	if isVersionID(bifrostReq.Model) {
		req.Version = &bifrostReq.Model
	}

	// Webhook settings live on the request itself rather than in the input.
	if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil {
		extra := bifrostReq.Params.ExtraParams
		if webhook, ok := schemas.SafeExtractStringPointer(extra["webhook"]); ok {
			req.Webhook = webhook
		}
		if events, ok := schemas.SafeExtractStringSlice(extra["webhook_events_filter"]); ok {
			req.WebhookEventsFilter = events
		}
	}

	return req, nil
}
// ToBifrostTextCompletionResponse converts a Replicate prediction response
// into a Bifrost text completion response holding a single choice.
//
// The text is taken from the string output, the concatenated string array
// output, or the text field of the object output, in that order of
// preference. The object form is handled here so that text completions agree
// with the chat and responses converters. Succeeded and canceled predictions
// finish with "stop", failed ones with "error"; any other status leaves the
// finish reason unset. Token usage is attached when it can be parsed from the
// prediction logs.
//
// A nil receiver yields nil.
func (response *ReplicatePredictionResponse) ToBifrostTextCompletionResponse() *schemas.BifrostTextCompletionResponse {
	if response == nil {
		return nil
	}

	// Initialize Bifrost response.
	bifrostResponse := &schemas.BifrostTextCompletionResponse{
		ID:     response.ID,
		Model:  response.Model,
		Object: "text_completion",
	}

	// Convert output to text.
	var textOutput *string
	if response.Output != nil {
		if response.Output.OutputStr != nil {
			textOutput = response.Output.OutputStr
		} else if response.Output.OutputArray != nil {
			// Join array of strings into a single string.
			joined := strings.Join(response.Output.OutputArray, "")
			textOutput = &joined
		} else if response.Output.OutputObject != nil && response.Output.OutputObject.Text != nil {
			// Object-shaped output carries the text in its text field.
			textOutput = response.Output.OutputObject.Text
		}
	}

	// Determine finish reason based on status.
	var finishReason *string
	switch response.Status {
	case ReplicatePredictionStatusSucceeded:
		finishReason = schemas.Ptr("stop")
	case ReplicatePredictionStatusFailed:
		finishReason = schemas.Ptr("error")
	case ReplicatePredictionStatusCanceled:
		finishReason = schemas.Ptr("stop")
	}

	// Create choice with text completion response choice.
	choice := schemas.BifrostResponseChoice{
		Index: 0,
		TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
			Text: textOutput,
		},
		FinishReason: finishReason,
	}
	bifrostResponse.Choices = []schemas.BifrostResponseChoice{choice}

	// Extract usage information from logs.
	if response.Logs != nil {
		inputTokens, outputTokens, totalTokens, found := parseTokenUsageFromLogs(response.Logs, schemas.TextCompletionRequest)
		if found {
			bifrostResponse.Usage = &schemas.BifrostLLMUsage{
				PromptTokens:     inputTokens,
				CompletionTokens: outputTokens,
				TotalTokens:      totalTokens,
			}
		}
	}

	return bifrostResponse
}

View File

@@ -0,0 +1,507 @@
package replicate
import (
"encoding/json"
"fmt"
"time"
"github.com/bytedance/sonic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// ==================== REQUEST TYPES ====================
// ReplicatePredictionRequest represents a request to create a prediction.
//
// Version is a pointer with omitempty: the request builders in this package
// (e.g. the video generation converter) only set it when the requested model
// string is itself a version ID, and leave it nil otherwise.
// ExtraParams is excluded from struct marshalling (json:"-") and is exposed
// through GetExtraParams instead.
type ReplicatePredictionRequest struct {
Version *string `json:"version,omitempty"` // Required: Model version ID
Input *ReplicatePredictionRequestInput `json:"input"` // Required: Input parameters for the model
Stream *bool `json:"stream,omitempty"` // Enable streaming output
Webhook *string `json:"webhook,omitempty"` // Webhook URL for notifications
WebhookEventsFilter []string `json:"webhook_events_filter,omitempty"` // Filter webhook events: start, output, logs, completed
OutputFileURLPrefix *string `json:"output_file_url_prefix,omitempty"` // Custom prefix for output file URLs
PollTimeout *int `json:"poll_timeout,omitempty"` // Timeout in seconds for polling (used with Prefer: wait header)
UseFileOutput *bool `json:"use_file_output,omitempty"` // Output files as URLs instead of data URIs
ExtraParams map[string]interface{} `json:"-"` // Extra parameters to merge into the request
}
// GetExtraParams implements the RequestBodyWithExtraParams interface.
// It returns the ExtraParams map itself (not a copy), so callers share the
// underlying map with the request. The receiver must be non-nil.
func (req *ReplicatePredictionRequest) GetExtraParams() map[string]interface{} {
return req.ExtraParams
}
// ReplicatePredictionRequestInput represents the input parameters for a model prediction
// This is flexible to support different model types - exact fields depend on the model
//
// Every declared field is optional (pointer or slice with omitempty). Anything
// that has no declared field travels in ExtraParams, which the custom
// MarshalJSON flattens into the same JSON object and UnmarshalJSON collects
// back from unrecognized keys.
//
// NOTE: UnmarshalJSON keeps its own list of known JSON names. When adding a
// field here, add its JSON name to that list as well; otherwise the key is
// decoded into the field AND duplicated into ExtraParams.
type ReplicatePredictionRequestInput struct {
Prompt *string `json:"prompt,omitempty"`
Messages []schemas.ChatMessage `json:"messages,omitempty"`
SystemPrompt *string `json:"system_prompt,omitempty"`
Image *string `json:"image,omitempty"` // URL or data URI
NumberOfImages *int `json:"number_of_images,omitempty"` // Number of images to generate
Quality *string `json:"quality,omitempty"` // Quality of the image
Background *string `json:"background,omitempty"` // Background of the image
Seed *int `json:"seed,omitempty"` // Random seed
ReasoningEffort *string `json:"reasoning_effort,omitempty"` // Reasoning effort
NumInferenceStep *int `json:"num_inference_steps,omitempty"` // Number of inference steps
NegativePrompt *string `json:"negative_prompt,omitempty"` // Negative prompt
// Responses parameters
Instructions *string `json:"instructions,omitempty"`
InputItemList []schemas.ResponsesMessage `json:"input_item_list,omitempty"`
Tools []schemas.ResponsesTool `json:"tools,omitempty"`
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
JsonSchema *schemas.ResponsesTextConfig `json:"json_schema,omitempty"`
// Chat parameters
Temperature *float64 `json:"temperature,omitempty"` // Temperature for sampling
TopP *float64 `json:"top_p,omitempty"` // Top-p sampling
TopK *int `json:"top_k,omitempty"` // Top-k sampling
MaxTokens *int `json:"max_tokens,omitempty"` // Maximum tokens to generate
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` // Maximum completion tokens to generate
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Presence penalty
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Frequency penalty
// Image generation parameters
AspectRatio *string `json:"aspect_ratio,omitempty"`
Resolution *string `json:"resolution,omitempty"` // Resolution tier: "1k", "2k", "4k"
OutputFormat *string `json:"output_format,omitempty"`
InputImages []string `json:"input_images,omitempty"` // Image input for image-to-image models
ImagePrompt *string `json:"image_prompt,omitempty"` // Image prompt for image models (flux family)
ImageInput []string `json:"image_input,omitempty"` // Image input for chat models (openai family)
InputImage *string `json:"input_image,omitempty"` // Image input for image-to-image models
// video generation parameters
Duration *int `json:"duration,omitempty"`
InputReference *string `json:"input_reference,omitempty"`
ExtraParams map[string]interface{} `json:"-"` // Additional model-specific parameters
}
// MarshalJSON serializes the declared fields first and then flattens
// ExtraParams into the same top-level JSON object. A nil receiver encodes as
// JSON null.
func (r *ReplicatePredictionRequestInput) MarshalJSON() ([]byte, error) {
	if r == nil {
		return []byte("null"), nil
	}

	// The local type has the same fields but none of the methods, so encoding
	// it does not re-enter this MarshalJSON. ExtraParams is skipped here
	// because of its json:"-" tag.
	type plainInput ReplicatePredictionRequestInput
	encoded, err := providerUtils.MarshalSorted((*plainInput)(r))
	switch {
	case err != nil:
		return nil, err
	case len(r.ExtraParams) == 0:
		// Nothing to merge: the struct encoding is already the final payload.
		return encoded, nil
	}

	// Merge through the order-preserving helper so the key order of the
	// already-serialized JSON is not disturbed.
	return providerUtils.MergeExtraParamsIntoJSON(encoded, r.ExtraParams)
}
// replicateInputKnownFields is the set of JSON keys that map onto declared
// fields of ReplicatePredictionRequestInput. It is built once at package
// initialization and only ever read afterwards. Keep it in sync with the
// struct's json tags: a key missing here is decoded into its field and also
// copied into ExtraParams.
var replicateInputKnownFields = map[string]struct{}{
	"prompt":                {},
	"messages":              {},
	"system_prompt":         {},
	"image":                 {},
	"number_of_images":      {},
	"quality":               {},
	"background":            {},
	"seed":                  {},
	"reasoning_effort":      {},
	"num_inference_steps":   {},
	"negative_prompt":       {},
	"instructions":          {},
	"input_item_list":       {},
	"tools":                 {},
	"max_output_tokens":     {},
	"json_schema":           {},
	"temperature":           {},
	"top_p":                 {},
	"top_k":                 {},
	"max_tokens":            {},
	"max_completion_tokens": {},
	"presence_penalty":      {},
	"frequency_penalty":     {},
	"aspect_ratio":          {},
	"resolution":            {},
	"output_format":         {},
	"input_images":          {},
	"image_prompt":          {},
	"input_image":           {},
	"image_input":           {},
	"duration":              {},
	"input_reference":       {},
}

// UnmarshalJSON implements custom JSON unmarshalling for ReplicatePredictionRequestInput.
// It decodes the declared fields and captures every unrecognized top-level key
// in ExtraParams. ExtraParams is left nil when there are no such keys.
func (r *ReplicatePredictionRequestInput) UnmarshalJSON(data []byte) error {
	// Decode through a method-less alias so this UnmarshalJSON is not re-entered.
	type Alias ReplicatePredictionRequestInput
	if err := sonic.Unmarshal(data, (*Alias)(r)); err != nil {
		return err
	}

	// Second pass into a generic map to discover keys without a declared field.
	var rawMap map[string]interface{}
	if err := sonic.Unmarshal(data, &rawMap); err != nil {
		return err
	}

	// Allocate the extras map lazily so the common "no extras" case keeps
	// ExtraParams nil without allocating anything.
	var extras map[string]interface{}
	for key, value := range rawMap {
		if _, known := replicateInputKnownFields[key]; known {
			continue
		}
		if extras == nil {
			extras = make(map[string]interface{})
		}
		extras[key] = value
	}
	r.ExtraParams = extras
	return nil
}
// ReplicateModelListRequest represents a request to list/search models.
// Both fields are optional and omitted from the JSON when nil.
type ReplicateModelListRequest struct {
Query *string `json:"query,omitempty"` // Search query
Limit *int `json:"limit,omitempty"` // Maximum results (1-50, default 20)
}
// ==================== RESPONSE TYPES ====================
// ReplicatePredictionStatus represents the status of a prediction.
// The constant values are the lowercase status strings used on the wire.
type ReplicatePredictionStatus string
// Known prediction statuses. Within this package, succeeded, failed and
// canceled are the ones treated as terminal (see isTerminalStatus).
const (
ReplicatePredictionStatusStarting ReplicatePredictionStatus = "starting"
ReplicatePredictionStatusProcessing ReplicatePredictionStatus = "processing"
ReplicatePredictionStatusSucceeded ReplicatePredictionStatus = "succeeded"
ReplicatePredictionStatusFailed ReplicatePredictionStatus = "failed"
ReplicatePredictionStatusCanceled ReplicatePredictionStatus = "canceled"
)
// ReplicatePredictionResponse represents a prediction response.
//
// Output is decoded by ReplicateOutput's custom unmarshalling (string, array
// of strings, or object). Logs is free-form text; the converters in this
// package scrape token usage out of it via parseTokenUsageFromLogs.
// Timestamps are kept as strings; ParseReplicateTimestamp converts them.
type ReplicatePredictionResponse struct {
ID string `json:"id"`
Model string `json:"model"` // Model identifier (owner/name or owner/name:version)
Version string `json:"version"` // Model version ID
Input json.RawMessage `json:"input"` // Input parameters used (json.RawMessage preserves key ordering)
Output *ReplicateOutput `json:"output,omitempty"` // Output data (can be various types)
Logs *string `json:"logs,omitempty"` // Execution logs
Error *string `json:"error,omitempty"` // Error message if failed
Status ReplicatePredictionStatus `json:"status"` // Current status
CreatedAt string `json:"created_at"` // ISO 8601 timestamp
StartedAt *string `json:"started_at,omitempty"` // ISO 8601 timestamp
CompletedAt *string `json:"completed_at,omitempty"` // ISO 8601 timestamp
URLs *ReplicatePredictionURLs `json:"urls,omitempty"` // URLs for API operations
Metrics *ReplicateMetrics `json:"metrics,omitempty"` // Execution metrics
DataRemoved *bool `json:"data_removed,omitempty"` // Whether data has been removed
Source *string `json:"source,omitempty"` // Source of the prediction (web/api)
WebhookCompleted *bool `json:"webhook_completed,omitempty"` // Whether webhook was completed
Stream *bool `json:"stream,omitempty"` // Whether the prediction is streaming
}
// ReplicateOutputText is the object form of a prediction output, used by
// ReplicateOutput when the output is a JSON object rather than a string or an
// array of strings. Both fields are optional.
type ReplicateOutputText struct {
ResponseId *string `json:"response_id,omitempty"`
Text *string `json:"text,omitempty"`
}
// ReplicateOutput holds a prediction output in one of three shapes: a plain
// string, an array of strings, or an object. The fields carry no JSON tags
// because the type's MarshalJSON/UnmarshalJSON encode and decode the value
// directly; at most one field is expected to be set at a time.
type ReplicateOutput struct {
OutputStr *string
OutputArray []string
OutputObject *ReplicateOutputText
}
// MarshalJSON encodes whichever of OutputStr, OutputArray or OutputObject is
// set, directly and without a wrapping object. Setting more than one field is
// an error; with none set the result is JSON null.
func (mc ReplicateOutput) MarshalJSON() ([]byte, error) {
	hasStr := mc.OutputStr != nil
	hasArray := mc.OutputArray != nil
	hasObject := mc.OutputObject != nil

	// Reject ambiguous values: any pair of populated fields is invalid.
	if (hasStr && hasArray) || (hasStr && hasObject) || (hasArray && hasObject) {
		return nil, fmt.Errorf("multiple output fields are set; only one should be non-nil")
	}

	switch {
	case hasStr:
		return providerUtils.MarshalSorted(*mc.OutputStr)
	case hasArray:
		return providerUtils.MarshalSorted(mc.OutputArray)
	case hasObject:
		return providerUtils.MarshalSorted(mc.OutputObject)
	default:
		// Nothing set: encode as null.
		return providerUtils.MarshalSorted(nil)
	}
}
// UnmarshalJSON decodes a prediction output that may be a string, an array of
// strings, or an object, storing it in the matching field. All fields are
// reset first; empty input and JSON null leave them nil.
func (mc *ReplicateOutput) UnmarshalJSON(data []byte) error {
	mc.OutputStr, mc.OutputArray, mc.OutputObject = nil, nil, nil
	if len(data) == 0 || string(data) == "null" {
		return nil
	}

	// The shapes are tried in a fixed order: string, []string, then object.
	var asString string
	if sonic.Unmarshal(data, &asString) == nil {
		mc.OutputStr = schemas.Ptr(asString)
		return nil
	}

	var asArray []string
	if sonic.Unmarshal(data, &asArray) == nil {
		mc.OutputArray = asArray
		return nil
	}

	var asObject ReplicateOutputText
	if sonic.Unmarshal(data, &asObject) == nil {
		mc.OutputObject = &asObject
		return nil
	}

	return fmt.Errorf("output field is neither a string, array of strings, nor a valid object")
}
// ReplicatePredictionURLs represents URLs for prediction operations.
// Get and Cancel are always serialized; Stream and Web are optional pointers
// that are omitted when nil.
type ReplicatePredictionURLs struct {
Get string `json:"get"` // URL to get prediction details
Cancel string `json:"cancel"` // URL to cancel prediction
Stream *string `json:"stream,omitempty"` // URL for streaming output (if applicable)
Web *string `json:"web,omitempty"` // URL for web output (if applicable)
}
// ReplicateMetrics represents execution metrics.
// Every field is an optional pointer, so a metric that was not reported (nil)
// stays distinguishable from one reported as zero.
type ReplicateMetrics struct {
PredictTime *float64 `json:"predict_time,omitempty"` // Time spent in prediction (seconds)
TotalTime *float64 `json:"total_time,omitempty"` // Total time including queue (seconds)
ImageCount *int `json:"image_count,omitempty"` // Number of images generated
TimeToFirstToken *float64 `json:"time_to_first_token,omitempty"` // Time to first token (seconds)
TokensPerSecond *float64 `json:"tokens_per_second,omitempty"` // Tokens generated per second
}
// ReplicatePredictionListResponse represents a paginated list of predictions.
// Next and Previous have no omitempty, so a nil pointer is serialized as an
// explicit null rather than being dropped.
type ReplicatePredictionListResponse struct {
Next *string `json:"next"` // URL for next page
Previous *string `json:"previous"` // URL for previous page
Results []ReplicatePredictionResponse `json:"results"` // List of predictions
}
// ReplicateModelResponse represents a model response.
// URL, Owner, Name and Visibility are always serialized; the remaining fields
// are optional and omitted when nil.
type ReplicateModelResponse struct {
URL string `json:"url"` // Model API URL
Owner string `json:"owner"` // Owner username or org name
Name string `json:"name"` // Model name
Description *string `json:"description,omitempty"` // Model description
Visibility string `json:"visibility"` // "public" or "private"
GithubURL *string `json:"github_url,omitempty"` // GitHub repository URL
PaperURL *string `json:"paper_url,omitempty"` // Research paper URL
LicenseURL *string `json:"license_url,omitempty"` // License URL
RunCount *int `json:"run_count,omitempty"` // Number of times run
CoverImageURL *string `json:"cover_image_url,omitempty"` // Cover image URL
DefaultExample *json.RawMessage `json:"default_example,omitempty"` // Default example prediction (json.RawMessage preserves key ordering)
LatestVersion *ReplicateModelVersion `json:"latest_version,omitempty"` // Latest version details
FeaturedVersion *ReplicateModelVersion `json:"featured_version,omitempty"` // Featured version details
}
// ReplicateModelVersion represents a model version.
// OpenAPISchema is kept as raw JSON and is not decoded into a typed structure
// by this type.
type ReplicateModelVersion struct {
ID string `json:"id"` // Version ID
CreatedAt string `json:"created_at"` // ISO 8601 timestamp
CogVersion *string `json:"cog_version,omitempty"` // Cog version used
OpenAPISchema json.RawMessage `json:"openapi_schema,omitempty"` // OpenAPI schema for the model (json.RawMessage preserves key ordering)
DockerImageID *string `json:"docker_image_id,omitempty"` // Docker image ID
}
// ReplicateModelListResponse represents a paginated list of models.
// Next and Previous have no omitempty, so a nil pointer is serialized as an
// explicit null rather than being dropped.
type ReplicateModelListResponse struct {
Next *string `json:"next"` // URL for next page
Previous *string `json:"previous"` // URL for previous page
Results []ReplicateModelResponse `json:"results"` // List of models
}
// ReplicateDeploymentOwner represents the owner of a deployment.
// It is referenced from ReplicateDeploymentRelease.CreatedBy.
type ReplicateDeploymentOwner struct {
Type string `json:"type"` // "user" or "organization"
Username string `json:"username"` // Username or organization name
Name *string `json:"name,omitempty"` // Display name
AvatarURL *string `json:"avatar_url,omitempty"` // Avatar URL
GithubURL *string `json:"github_url,omitempty"` // GitHub URL
}
// ReplicateDeploymentConfiguration represents the deployment configuration.
// All fields are plain values without omitempty, so zero values are serialized.
type ReplicateDeploymentConfiguration struct {
Hardware string `json:"hardware"` // Hardware type (e.g., "gpu-t4")
MinInstances int `json:"min_instances"` // Minimum number of instances
MaxInstances int `json:"max_instances"` // Maximum number of instances
}
// ReplicateDeploymentRelease represents a deployment release.
// CreatedBy and Configuration are pointers without omitempty: they may be nil
// after decoding and are serialized as null in that case.
type ReplicateDeploymentRelease struct {
Number int `json:"number"` // Release number
Model string `json:"model"` // Model identifier (owner/name)
Version string `json:"version"` // Model version ID
CreatedAt string `json:"created_at"` // ISO 8601 timestamp
CreatedBy *ReplicateDeploymentOwner `json:"created_by"` // User or organization that created the release
Configuration *ReplicateDeploymentConfiguration `json:"configuration"` // Deployment configuration
}
// ReplicateDeployment represents a deployment.
// CurrentRelease is a pointer without omitempty, so it may be nil and is then
// serialized as null.
type ReplicateDeployment struct {
Owner string `json:"owner"` // Owner username or org name
Name string `json:"name"` // Deployment name
CurrentRelease *ReplicateDeploymentRelease `json:"current_release"` // Current active release
}
// ReplicateDeploymentListResponse represents a paginated list of deployments.
// Next and Previous have no omitempty, so a nil pointer is serialized as an
// explicit null rather than being dropped.
type ReplicateDeploymentListResponse struct {
Next *string `json:"next"` // URL for next page
Previous *string `json:"previous"` // URL for previous page
Results []ReplicateDeployment `json:"results"` // List of deployments
}
// ==================== ERROR TYPES ====================
// ReplicateError represents an error response from the Replicate API.
// Detail and Status are always serialized; Title and Type are optional.
type ReplicateError struct {
Detail string `json:"detail"` // Error message
Status int `json:"status"` // HTTP status code
Title *string `json:"title,omitempty"` // Error title
Type *string `json:"type,omitempty"` // Error type
}
// ==================== STREAMING TYPES ====================
// ReplicateStreamEvent represents a streaming event.
// Data is untyped (interface{}); its concrete type after decoding depends on
// the JSON value that was received.
type ReplicateStreamEvent struct {
Event string `json:"event,omitempty"` // Event type (output, logs, done, error)
Data interface{} `json:"data,omitempty"` // Event data
Error *string `json:"error,omitempty"` // Error message if event is error
}
// ==================== WEBHOOK TYPES ====================
// ReplicateWebhookPayload represents a webhook payload.
// It has the same JSON keys as the leading fields of
// ReplicatePredictionResponse, except that Output is left untyped
// (interface{}) here instead of going through ReplicateOutput.
type ReplicateWebhookPayload struct {
ID string `json:"id"`
Model string `json:"model"`
Version string `json:"version"`
Input json.RawMessage `json:"input"`
Output interface{} `json:"output,omitempty"`
Logs *string `json:"logs,omitempty"`
Error *string `json:"error,omitempty"`
Status ReplicatePredictionStatus `json:"status"`
CreatedAt string `json:"created_at"`
StartedAt *string `json:"started_at,omitempty"`
CompletedAt *string `json:"completed_at,omitempty"`
URLs *ReplicatePredictionURLs `json:"urls,omitempty"`
Metrics *ReplicateMetrics `json:"metrics,omitempty"`
}
// ==================== SSE TYPES ====================
// ReplicateSSEEvent represents a Server-Sent Event from Replicate streaming API.
// The fields carry no JSON tags; presumably the struct is filled by an SSE
// line parser rather than by JSON decoding (confirm against the stream reader).
type ReplicateSSEEvent struct {
Event string // Event type: "output", "done", "error"
Data string // Event data - can be plain text, JSON object, or data URI
ID string // Event ID (e.g., "1690212292:0")
}
// ReplicateDoneEvent represents the data payload of a "done" event.
// Both fields are optional; Output is untyped (interface{}).
type ReplicateDoneEvent struct {
Reason string `json:"reason,omitempty"` // Reason for completion: "canceled", "error", or empty for success
Output interface{} `json:"output,omitempty"` // Output data if available (e.g., error message)
}
// ReplicateErrorEvent represents the data payload of an "error" event.
// Detail has no omitempty, so it is serialized even when empty.
type ReplicateErrorEvent struct {
Detail string `json:"detail"` // Error message
}
// ==================== UTILITY FUNCTIONS ====================
// ParseReplicateTimestamp converts a Replicate timestamp in RFC 3339 form
// (fractional seconds allowed) into Unix seconds. An empty or unparseable
// timestamp yields 0 rather than an error.
func ParseReplicateTimestamp(timestamp string) int64 {
	if len(timestamp) == 0 {
		return 0
	}
	if parsed, err := time.Parse(time.RFC3339Nano, timestamp); err == nil {
		return parsed.Unix()
	}
	return 0
}
// ToBifrostPredictionStatus converts a Replicate status to the Bifrost status
// string. The known statuses (starting, processing, succeeded, failed,
// canceled) map to strings identical to their own values, and any other
// status is passed through unchanged, so the mapping is a plain conversion.
func ToBifrostPredictionStatus(status ReplicatePredictionStatus) string {
	return string(status)
}
// ==================== FILE API TYPES ====================
// ReplicateFileResponse represents a file resource from Replicate API.
// Metadata is kept as raw JSON and is not decoded into a typed structure by
// this type.
type ReplicateFileResponse struct {
ID string `json:"id"` // Unique file identifier
Checksums *ReplicateFileChecksum `json:"checksums,omitempty"` // File checksums
ContentType string `json:"content_type"` // MIME type
CreatedAt string `json:"created_at"` // ISO 8601 timestamp
ExpiresAt string `json:"expires_at,omitempty"` // ISO 8601 timestamp
Metadata json.RawMessage `json:"metadata,omitempty"` // User-provided metadata (json.RawMessage preserves key ordering)
Name string `json:"name,omitempty"` // File name
Size int64 `json:"size"` // File size in bytes
URLs *ReplicateFileURLs `json:"urls,omitempty"` // Associated URLs
}
// ReplicateFileChecksum represents checksums for a file.
// Only a SHA256 checksum is modelled; it is omitted from the JSON when empty.
type ReplicateFileChecksum struct {
SHA256 string `json:"sha256,omitempty"` // SHA256 checksum
}
// ReplicateFileURLs represents URLs associated with a file.
// Only the metadata retrieval URL is modelled.
type ReplicateFileURLs struct {
Get string `json:"get"` // URL to retrieve file metadata
}
// ReplicateFileListResponse represents a paginated list of files.
// Unlike the other list responses in this file, Next and Previous use
// omitempty here and are dropped from the JSON when nil.
type ReplicateFileListResponse struct {
Next *string `json:"next,omitempty"` // URL for next page
Previous *string `json:"previous,omitempty"` // URL for previous page
Results []ReplicateFileResponse `json:"results"` // List of files
}

View File

@@ -0,0 +1,280 @@
package replicate
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// isTerminalStatus reports whether a prediction has reached a final state:
// succeeded, failed or canceled. Every other status is non-terminal.
func isTerminalStatus(status ReplicatePredictionStatus) bool {
	switch status {
	case ReplicatePredictionStatusSucceeded,
		ReplicatePredictionStatusFailed,
		ReplicatePredictionStatusCanceled:
		return true
	default:
		return false
	}
}
// checkForErrorStatus turns a failed or canceled prediction into a Bifrost
// operation error and returns nil for every other status. For failures, the
// provider's own error text is used as the underlying error when it is
// present and non-empty. The prediction must be non-nil.
func checkForErrorStatus(prediction *ReplicatePredictionResponse) *schemas.BifrostError {
	switch prediction.Status {
	case ReplicatePredictionStatusFailed:
		detail := "prediction failed"
		if reported := prediction.Error; reported != nil && *reported != "" {
			detail = *reported
		}
		return providerUtils.NewBifrostOperationError(
			"prediction failed",
			fmt.Errorf("%s", detail))
	case ReplicatePredictionStatusCanceled:
		return providerUtils.NewBifrostOperationError(
			"prediction was canceled",
			fmt.Errorf("prediction was canceled"))
	default:
		return nil
	}
}
// parsePreferHeader reports whether the extra headers request synchronous
// ("wait") mode through a Prefer header.
// Examples of values that enable it: "wait", "wait=30", "wait=60".
//
// HTTP field names are case-insensitive, so the header is matched regardless
// of how the caller capitalized it ("Prefer", "prefer", "PREFER", ...). If
// several case variants are present, any one with a "wait" value enables sync
// mode. A nil or empty map yields false.
func parsePreferHeader(extraHeaders map[string]string) bool {
	for name, value := range extraHeaders {
		if strings.EqualFold(name, "Prefer") && strings.HasPrefix(value, "wait") {
			return true
		}
	}
	return false
}
// stripPreferHeader returns the extra headers without any Prefer header.
// Streaming requests should always be async and not wait for completion,
// so the Prefer header (which enables sync mode) must be excluded.
//
// HTTP field names are case-insensitive, so every case variant of "Prefer" is
// removed. The input map is never modified: when it contains no Prefer header
// it is returned as-is, otherwise a filtered copy is returned. A nil map
// yields nil.
func stripPreferHeader(extraHeaders map[string]string) map[string]string {
	if extraHeaders == nil {
		return nil
	}

	// Avoid copying when there is nothing to strip.
	hasPrefer := false
	for name := range extraHeaders {
		if strings.EqualFold(name, "Prefer") {
			hasPrefer = true
			break
		}
	}
	if !hasPrefer {
		return extraHeaders
	}

	filtered := make(map[string]string, len(extraHeaders))
	for name, value := range extraHeaders {
		if !strings.EqualFold(name, "Prefer") {
			filtered[name] = value
		}
	}
	return filtered
}
// listenToReplicateStreamURL opens a GET request against a Replicate stream
// URL, asking for Server-Sent Events, and hands the still-open response back
// to the caller. It does not parse SSE events itself.
// This is a reusable utility for any Replicate streaming endpoint.
//
// Returns:
//   - the response body stream (as io.Reader) on HTTP 200,
//   - the underlying *fasthttp.Response that owns that stream,
//   - a BifrostError on connection failure, cancellation, timeout or any
//     non-200 status (both other results are nil in that case).
//
// On success the response is NOT released here; the caller owns it
// (presumably to be released with providerUtils.ReleaseStreamingResponse once
// the stream has been consumed - confirm at the call sites). On every error
// path the response is released before returning.
func listenToReplicateStreamURL(
ctx *schemas.BifrostContext,
client *fasthttp.Client,
streamURL string,
key schemas.Key,
) (io.Reader, *fasthttp.Response, *schemas.BifrostError) {
// Create request
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
// Stream the body instead of buffering it, so events can be read as they arrive
resp.StreamBody = true
// Set URL and headers
req.SetRequestURI(streamURL)
req.Header.SetMethod(http.MethodGet)
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
// Set authorization header (skipped entirely when the key has no value)
if value := key.Value.GetValue(); value != "" {
req.Header.Set("Authorization", "Bearer "+value)
}
// Make request
err := client.Do(req, resp)
// The request object is no longer needed once Do has returned, success or not
fasthttp.ReleaseRequest(req)
if err != nil {
providerUtils.ReleaseStreamingResponse(resp)
// Cancellation is reported as a non-Bifrost error with the RequestCancelled type
if errors.Is(err, context.Canceled) {
return nil, nil, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Type: schemas.Ptr(schemas.RequestCancelled),
Message: schemas.ErrRequestCancelled,
Error: err,
},
}
}
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
return nil, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
}
return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err)
}
// Extract provider response headers before status check so error responses also forward them
if ctx != nil {
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
}
// Check for HTTP errors
if resp.StatusCode() != fasthttp.StatusOK {
// Deferred so the body is still readable while parseReplicateError builds
// the error; the release runs after the return value has been evaluated.
defer providerUtils.ReleaseStreamingResponse(resp)
return nil, nil, parseReplicateError(resp.Body(), resp.StatusCode())
}
return resp.BodyStream(), resp, nil
}
// parseDataURIImage splits a data URI into its payload and MIME type.
// Example: "data:image/webp;base64,UklGRmSu..." -> ("UklGRmSu...", "image/webp")
//
// Input that does not start with "data:" or has no comma after the metadata
// is returned unchanged with an empty MIME type. The payload is everything
// after the first comma; the MIME type is the metadata up to the first ";".
func parseDataURIImage(dataURI string) (base64Data string, mimeType string) {
	rest, isDataURI := strings.CutPrefix(dataURI, "data:")
	if !isDataURI {
		return dataURI, ""
	}

	// metadata is e.g. "image/webp;base64"; payload is the encoded data.
	metadata, payload, hasComma := strings.Cut(rest, ",")
	if !hasComma {
		return dataURI, ""
	}

	mimeType, _, _ = strings.Cut(metadata, ";")
	return payload, mimeType
}
// versionIDPattern matches exactly 64 lowercase hexadecimal characters, the
// format of a Replicate version ID.
var versionIDPattern = regexp.MustCompile(`^[a-f0-9]{64}$`)

// isVersionID reports whether s is a Replicate version ID, i.e. a string of
// exactly 64 lowercase hex characters. Uppercase hex digits do not qualify.
func isVersionID(s string) bool {
	// The pattern can only ever match 64-byte strings, so reject other
	// lengths before running the regular expression.
	return len(s) == 64 && versionIDPattern.MatchString(s)
}
// buildPredictionURL returns the URL used to create a prediction.
//
// The default path is chosen in this order:
//   - deployments endpoint when useDeploymentsEndpoint is set,
//   - the base predictions endpoint when model is a version ID,
//   - the model-specific endpoint otherwise (model is an owner/name).
//
// The default is then passed through providerUtils.GetRequestPath together
// with the custom provider config and request type. If that reports a
// complete URL it is returned verbatim; otherwise the path is appended to
// baseURL.
func buildPredictionURL(ctx *schemas.BifrostContext, baseURL, model string, customProviderConfig *schemas.CustomProviderConfig, requestType schemas.RequestType, useDeploymentsEndpoint bool) string {
	var endpoint string
	switch {
	case useDeploymentsEndpoint:
		endpoint = "/v1/deployments/" + model + "/predictions"
	case isVersionID(model):
		endpoint = "/v1/predictions"
	default:
		endpoint = "/v1/models/" + model + "/predictions"
	}

	resolved, complete := providerUtils.GetRequestPath(ctx, endpoint, customProviderConfig, requestType)
	if !complete {
		resolved = baseURL + resolved
	}
	return resolved
}
// Regular expressions for the token-count lines found in Replicate logs.
// They are compiled once at package initialization instead of on every call.
var (
	// Tried in order; the first one that yields a number wins.
	tokenLogInputPatterns = []*regexp.Regexp{
		regexp.MustCompile(`Input token count:\s*(\d+)`),
		regexp.MustCompile(`Input text token count:\s*(\d+)`),
	}
	tokenLogInputImagePattern = regexp.MustCompile(`Input image token count:\s*(\d+)`)
	tokenLogOutputPattern     = regexp.MustCompile(`Output token count:\s*(\d+)`)
	tokenLogTotalPattern      = regexp.MustCompile(`Total token count:\s*(\d+)`)
	tokenLogSimplePattern     = regexp.MustCompile(`Tokens:\s*(\d+)`)
)

// firstLoggedCount returns the integer captured by the first match of re in
// text. ok is false when there is no match or the digits do not fit an int.
func firstLoggedCount(re *regexp.Regexp, text string) (count int, ok bool) {
	matches := re.FindStringSubmatch(text)
	if len(matches) < 2 {
		return 0, false
	}
	value, err := strconv.Atoi(matches[1])
	if err != nil {
		return 0, false
	}
	return value, true
}

// parseTokenUsageFromLogs extracts token counts from Replicate's logs field.
//
// Two log formats are recognized:
//   - Detailed lines: "Input token count: N" (or "Input text token count: N"),
//     "Input image token count: N" (added to the input count),
//     "Output token count: N" and "Total token count: N".
//   - A bare "Tokens: N" line, consulted only when no detailed line was found.
//     It is ambiguous, so it is interpreted by request type: for image
//     generation it counts as both output and total tokens, for everything
//     else as total tokens only.
//
// When counts were found but the total is still zero, the total is computed
// as input + output. found is false when logs is nil or empty, or when no
// recognized line was present; all counts are zero in that case.
func parseTokenUsageFromLogs(logs *string, requestType schemas.RequestType) (inputTokens, outputTokens, totalTokens int, found bool) {
	if logs == nil || *logs == "" {
		return 0, 0, 0, false
	}
	logText := *logs

	// Detailed format: text input count, first pattern that parses wins.
	for _, pattern := range tokenLogInputPatterns {
		if value, ok := firstLoggedCount(pattern, logText); ok {
			inputTokens = value
			found = true
			break
		}
	}
	// Image input tokens are added on top of the text input tokens.
	if value, ok := firstLoggedCount(tokenLogInputImagePattern, logText); ok {
		inputTokens += value
		found = true
	}
	if value, ok := firstLoggedCount(tokenLogOutputPattern, logText); ok {
		outputTokens = value
		found = true
	}
	if value, ok := firstLoggedCount(tokenLogTotalPattern, logText); ok {
		totalTokens = value
		found = true
	}

	// Simple "Tokens: X" format, only used if the detailed format is absent.
	if !found {
		if value, ok := firstLoggedCount(tokenLogSimplePattern, logText); ok {
			// For image generation "Tokens: X" is taken to mean output tokens;
			// for every other request type it is conservatively treated as
			// the total only.
			if requestType == schemas.ImageGenerationRequest {
				outputTokens = value
			}
			totalTokens = value
			found = true
		}
	}

	// If we found input/output but not total, compute it.
	if found && totalTokens == 0 {
		totalTokens = inputTokens + outputTokens
	}
	return inputTokens, outputTokens, totalTokens, found
}

View File

@@ -0,0 +1,152 @@
package replicate
import (
"fmt"
"strconv"
"strings"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// ToReplicateVideoGenerationInput converts a Bifrost video generation request
// into a Replicate prediction request.
//
// Mapping:
//   - The prompt is always set.
//   - An input reference, when present, is sanitized and sent as
//     "input_reference" for models whose name starts with the OpenAI provider
//     prefix, and as "image" for all other models.
//   - Params.Seconds (a decimal string) becomes the integer duration; seed and
//     negative prompt are copied as-is.
//   - Version is set only when the model string is itself a version ID.
//   - "webhook" and "webhook_events_filter" are lifted out of ExtraParams onto
//     the request; the remaining extra params are attached to both the request
//     and its input.
//
// The caller's ExtraParams map is never modified: the webhook keys are removed
// from a private copy, so the same Bifrost request can be converted again
// (e.g. on a retry) with identical results.
//
// Returns an error when the request or its input is nil, when the input
// reference cannot be sanitized, or when Seconds is not a valid integer.
func ToReplicateVideoGenerationInput(bifrostReq *schemas.BifrostVideoGenerationRequest) (*ReplicatePredictionRequest, error) {
	if bifrostReq == nil || bifrostReq.Input == nil {
		return nil, fmt.Errorf("bifrost request or input is nil")
	}

	input := &ReplicatePredictionRequestInput{
		Prompt: &bifrostReq.Input.Prompt,
	}

	if bifrostReq.Input.InputReference != nil {
		sanitizedURL, err := schemas.SanitizeImageURL(*bifrostReq.Input.InputReference)
		if err != nil {
			return nil, fmt.Errorf("invalid input reference: %w", err)
		}
		// OpenAI-prefixed models take the reference as "input_reference";
		// every other model receives it as "image".
		if strings.HasPrefix(bifrostReq.Model, string(schemas.OpenAI)) {
			input.InputReference = schemas.Ptr(sanitizedURL)
		} else {
			input.Image = schemas.Ptr(sanitizedURL)
		}
	}

	request := &ReplicatePredictionRequest{
		Input: input,
	}
	// Check if model is a version ID and set version field accordingly
	if isVersionID(bifrostReq.Model) {
		request.Version = &bifrostReq.Model
	}

	params := bifrostReq.Params
	if params == nil {
		return request, nil
	}

	if params.Seconds != nil {
		seconds, err := strconv.Atoi(*params.Seconds)
		if err != nil {
			return nil, fmt.Errorf("invalid seconds value: %w", err)
		}
		input.Duration = &seconds
	}
	if params.Seed != nil {
		input.Seed = params.Seed
	}
	if params.NegativePrompt != nil {
		input.NegativePrompt = params.NegativePrompt
	}

	if params.ExtraParams != nil {
		// Work on a copy so that removing the webhook keys below does not
		// mutate the caller's map.
		extra := make(map[string]interface{}, len(params.ExtraParams))
		for key, value := range params.ExtraParams {
			extra[key] = value
		}

		if webhook, ok := schemas.SafeExtractStringPointer(extra["webhook"]); ok {
			delete(extra, "webhook")
			request.Webhook = webhook
		}
		if webhookEventsFilter, ok := schemas.SafeExtractStringSlice(extra["webhook_events_filter"]); ok {
			delete(extra, "webhook_events_filter")
			request.WebhookEventsFilter = webhookEventsFilter
		}

		// Input and request share the same filtered set of extra params.
		input.ExtraParams = extra
		request.ExtraParams = extra
	}

	return request, nil
}
// ToBifrostVideoGenerationResponse converts a Replicate prediction into a
// Bifrost video generation response.
//
// Status mapping: processing -> in progress, succeeded -> completed,
// failed/canceled -> failed; starting and any unrecognized status are reported
// as queued. For failed responses the error code is the raw Replicate status
// and the message is the provider error text when present, otherwise
// "prediction failed". Output URLs (a single string or a list of strings) are
// returned as URL-typed video outputs with content type "video/mp4".
//
// A nil prediction yields a nil response and a BifrostError.
func ToBifrostVideoGenerationResponse(prediction *ReplicatePredictionResponse) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	if prediction == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: true,
			Error:          &schemas.ErrorField{Message: "prediction response is nil"},
		}
	}

	result := &schemas.BifrostVideoGenerationResponse{
		ID:        prediction.ID,
		CreatedAt: ParseReplicateTimestamp(prediction.CreatedAt),
		Model:     prediction.Model,
		Object:    "video",
	}

	// Start from the fallback status (also used for "starting") and override
	// it for the states that map elsewhere.
	failed := false
	result.Status = schemas.VideoStatusQueued
	switch prediction.Status {
	case ReplicatePredictionStatusProcessing:
		result.Status = schemas.VideoStatusInProgress
	case ReplicatePredictionStatusSucceeded:
		result.Status = schemas.VideoStatusCompleted
	case ReplicatePredictionStatusFailed, ReplicatePredictionStatusCanceled:
		result.Status = schemas.VideoStatusFailed
		failed = true
	}

	// Attach provider error details for failed terminal states.
	if failed {
		message := "prediction failed"
		if providerErr := prediction.Error; providerErr != nil && *providerErr != "" {
			message = *providerErr
		}
		result.Error = &schemas.VideoCreateError{
			Code:    string(prediction.Status),
			Message: message,
		}
	}

	if completedAt := prediction.CompletedAt; completedAt != nil {
		result.CompletedAt = schemas.Ptr(ParseReplicateTimestamp(*completedAt))
	}

	// Replicate output is either a single URL string or an array of URL
	// strings; a non-empty single string takes precedence over the array.
	if output := prediction.Output; output != nil {
		addVideo := func(videoURL string) {
			result.Videos = append(result.Videos, schemas.VideoOutput{
				Type:        schemas.VideoOutputTypeURL,
				URL:         schemas.Ptr(videoURL),
				ContentType: "video/mp4",
			})
		}

		switch {
		case output.OutputStr != nil && *output.OutputStr != "":
			addVideo(*output.OutputStr)
		default:
			for _, videoURL := range output.OutputArray {
				addVideo(videoURL)
			}
		}
	}

	return result, nil
}