Files
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

316 lines
8.6 KiB
Go

package replicate
import (
"fmt"
"slices"
"strings"
"time"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// unsupportedSystemPromptModels is a set of models that don't support the system_prompt field.
var unsupportedSystemPromptModels = []string{
"meta/meta-llama-3-8b",
"meta/llama-2-70b",
"openai/gpt-oss-20b",
"openai/o1-mini",
"xai/grok-4",
}
func ToReplicateChatRequest(bifrostReq *schemas.BifrostChatRequest) (*ReplicatePredictionRequest, error) {
if bifrostReq == nil || bifrostReq.Input == nil {
return nil, fmt.Errorf("bifrost request is nil or input is nil")
}
// Build the input from messages
input := &ReplicatePredictionRequestInput{}
isGPT5Structured := strings.HasPrefix(bifrostReq.Model, string(schemas.OpenAI)) && strings.Contains(bifrostReq.Model, "gpt-5-structured")
// openai models support messages
if len(bifrostReq.Input) > 0 && strings.HasPrefix(bifrostReq.Model, string(schemas.OpenAI)) {
if isGPT5Structured {
responsesMessages := []schemas.ResponsesMessage{}
for _, msg := range bifrostReq.Input {
responsesMessages = append(responsesMessages, msg.ToResponsesMessages()...)
}
if len(responsesMessages) > 0 {
input.InputItemList = responsesMessages
}
} else {
input.Messages = bifrostReq.Input
}
} else {
// Extract system prompt and build conversation prompt
var systemPrompt string
var conversationParts []string
var imageInput []string
for _, msg := range bifrostReq.Input {
if msg.Content == nil {
continue
}
// Get message content as string
var contentStr string
if msg.Content.ContentStr != nil {
contentStr = *msg.Content.ContentStr
} else if msg.Content.ContentBlocks != nil {
// Concatenate text blocks only
var textParts []string
for _, block := range msg.Content.ContentBlocks {
if block.Text != nil && *block.Text != "" {
textParts = append(textParts, *block.Text)
}
if block.ImageURLStruct != nil && block.ImageURLStruct.URL != "" {
imageInput = append(imageInput, block.ImageURLStruct.URL)
}
}
contentStr = strings.Join(textParts, "\n")
}
if contentStr == "" {
continue
}
// Handle different roles
switch msg.Role {
case schemas.ChatMessageRoleSystem:
if systemPrompt == "" {
systemPrompt = contentStr
} else {
systemPrompt += "\n" + contentStr
}
case schemas.ChatMessageRoleUser:
conversationParts = append(conversationParts, contentStr)
case schemas.ChatMessageRoleAssistant:
// For assistant messages, we can include them in the conversation context
conversationParts = append(conversationParts, contentStr)
}
}
// Set system prompt if present and model supports it
modelSupportsSystemPrompt := supportsSystemPrompt(bifrostReq.Model)
if systemPrompt != "" {
if modelSupportsSystemPrompt {
// Model supports system_prompt field
input.SystemPrompt = &systemPrompt
} else {
// Model doesn't support system_prompt - prepend to prompt
if len(conversationParts) > 0 {
// Prepend system prompt to conversation
conversationParts = append([]string{systemPrompt}, conversationParts...)
} else {
// No conversation parts, use system prompt as the prompt
conversationParts = []string{systemPrompt}
}
}
}
// Build the final prompt from conversation parts
if len(conversationParts) > 0 {
prompt := strings.Join(conversationParts, "\n\n")
input.Prompt = &prompt
}
// Ensure we have at least some content (prompt or system prompt)
if input.Prompt == nil && input.SystemPrompt == nil {
return nil, fmt.Errorf("no content found in chat messages - need at least one user or system message")
}
if len(imageInput) > 0 {
input.ImageInput = imageInput
}
}
// Map parameters if present
if bifrostReq.Params != nil {
params := bifrostReq.Params
// Temperature
if params.Temperature != nil {
input.Temperature = params.Temperature
}
// Top P
if params.TopP != nil {
input.TopP = params.TopP
}
// Max tokens - use max_completion_tokens if available
if params.MaxCompletionTokens != nil {
if isGPT5Structured {
input.MaxOutputTokens = params.MaxCompletionTokens
} else if strings.HasPrefix(bifrostReq.Model, string(schemas.OpenAI)) {
input.MaxCompletionTokens = params.MaxCompletionTokens
} else {
input.MaxTokens = params.MaxCompletionTokens
}
}
// Presence penalty
if params.PresencePenalty != nil {
input.PresencePenalty = params.PresencePenalty
}
// Frequency penalty
if params.FrequencyPenalty != nil {
input.FrequencyPenalty = params.FrequencyPenalty
}
// Seed
if params.Seed != nil {
input.Seed = params.Seed
}
if params.Reasoning != nil {
if params.Reasoning.Effort != nil {
input.ReasoningEffort = params.Reasoning.Effort
}
}
if isGPT5Structured {
if len(params.Tools) > 0 {
responsesTools := []schemas.ResponsesTool{}
for _, tool := range params.Tools {
responsesTools = append(responsesTools, *tool.ToResponsesTool())
}
if len(responsesTools) > 0 {
input.Tools = responsesTools
}
}
}
if params.ExtraParams != nil {
input.ExtraParams = params.ExtraParams
}
}
// Check if model is a version ID and set version field accordingly
req := &ReplicatePredictionRequest{
Input: input,
}
if isVersionID(bifrostReq.Model) {
req.Version = &bifrostReq.Model
}
if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil {
if webhook, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["webhook"]); ok {
req.Webhook = webhook
}
if webhookEventsFilter, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["webhook_events_filter"]); ok {
req.WebhookEventsFilter = webhookEventsFilter
}
}
return req, nil
}
// ToBifrostChatResponse converts a Replicate prediction response to Bifrost format
func (response *ReplicatePredictionResponse) ToBifrostChatResponse() *schemas.BifrostChatResponse {
if response == nil {
return nil
}
// Parse timestamps
createdAt := ParseReplicateTimestamp(response.CreatedAt)
if createdAt == 0 {
createdAt = time.Now().Unix()
}
// Initialize Bifrost response
bifrostResponse := &schemas.BifrostChatResponse{
ID: response.ID,
Model: response.Model,
Object: "chat.completion",
Created: int(createdAt),
}
// Convert output to content
var contentStr *string
if response.Output != nil {
if response.Output.OutputStr != nil {
contentStr = response.Output.OutputStr
} else if response.Output.OutputArray != nil {
// Join array of strings into a single string
joined := strings.Join(response.Output.OutputArray, "")
contentStr = &joined
} else if response.Output.OutputObject != nil && response.Output.OutputObject.Text != nil {
contentStr = response.Output.OutputObject.Text
}
}
// Create message content
messageContent := schemas.ChatMessageContent{
ContentStr: contentStr,
}
// Create the assistant message
message := schemas.ChatMessage{
Role: schemas.ChatMessageRoleAssistant,
Content: &messageContent,
}
// Determine finish reason based on status
var finishReason *string
switch response.Status {
case ReplicatePredictionStatusSucceeded:
reason := "stop"
finishReason = &reason
case ReplicatePredictionStatusFailed:
reason := "error"
finishReason = &reason
case ReplicatePredictionStatusCanceled:
reason := "stop"
finishReason = &reason
}
// Create choice
choice := schemas.BifrostResponseChoice{
Index: 0,
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
Message: &message,
},
FinishReason: finishReason,
}
bifrostResponse.Choices = []schemas.BifrostResponseChoice{choice}
// Extract usage information from logs
if response.Logs != nil {
inputTokens, outputTokens, totalTokens, found := parseTokenUsageFromLogs(response.Logs, schemas.ChatCompletionRequest)
if found {
bifrostResponse.Usage = &schemas.BifrostLLMUsage{
PromptTokens: inputTokens,
CompletionTokens: outputTokens,
TotalTokens: totalTokens,
}
}
}
return bifrostResponse
}
// supportsSystemPrompt checks if a model supports the system_prompt field.
func supportsSystemPrompt(model string) bool {
// Normalize model name to lowercase for comparison
modelLower := strings.ToLower(model)
// Extract model identifier (handle both "owner/name" and "owner/name:version" formats)
modelIdentifier := modelLower
if idx := strings.Index(modelLower, ":"); idx != -1 {
modelIdentifier = modelLower[:idx]
}
// All deepseek models don't support system prompt
if strings.HasPrefix(modelIdentifier, "deepseek-ai/deepseek") {
return false
}
isUnsupported := slices.Contains(unsupportedSystemPromptModels, modelIdentifier)
return !isUnsupported
}