Files
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

149 lines
3.8 KiB
Go

package replicate
import (
"fmt"
"strings"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// ToReplicateTextRequest converts a Bifrost text completion request into a
// Replicate prediction request.
//
// The prompt is taken from Input.PromptStr when set; otherwise a non-empty
// Input.PromptArray is joined with newlines. Sampling parameters present on
// Params are copied onto the prediction input, "top_k" is read from
// Params.ExtraParams, and the ExtraParams map itself is forwarded on the input.
// The request-level "webhook" and "webhook_events_filter" settings are also read
// from Params.ExtraParams. Version is set only when isVersionID accepts the model.
//
// An error is returned when bifrostReq or bifrostReq.Input is nil.
func ToReplicateTextRequest(bifrostReq *schemas.BifrostTextCompletionRequest) (*ReplicatePredictionRequest, error) {
	if bifrostReq == nil || bifrostReq.Input == nil {
		return nil, fmt.Errorf("bifrost request is nil or prompt is nil")
	}

	// Resolve the prompt: a single string wins over an array of strings.
	predictionInput := &ReplicatePredictionRequestInput{}
	switch {
	case bifrostReq.Input.PromptStr != nil:
		predictionInput.Prompt = bifrostReq.Input.PromptStr
	case len(bifrostReq.Input.PromptArray) > 0:
		joined := strings.Join(bifrostReq.Input.PromptArray, "\n")
		predictionInput.Prompt = &joined
	}

	// Copy over whichever optional parameters the caller supplied.
	if p := bifrostReq.Params; p != nil {
		if p.Temperature != nil {
			predictionInput.Temperature = p.Temperature
		}
		if p.TopP != nil {
			predictionInput.TopP = p.TopP
		}
		if p.MaxTokens != nil {
			predictionInput.MaxTokens = p.MaxTokens
		}
		if p.PresencePenalty != nil {
			predictionInput.PresencePenalty = p.PresencePenalty
		}
		if p.FrequencyPenalty != nil {
			predictionInput.FrequencyPenalty = p.FrequencyPenalty
		}
		// top_k has no dedicated field on Params, so it is looked up in ExtraParams.
		if k, ok := schemas.SafeExtractIntPointer(p.ExtraParams["top_k"]); ok {
			predictionInput.TopK = k
		}
		if p.Seed != nil {
			predictionInput.Seed = p.Seed
		}
		if p.ExtraParams != nil {
			predictionInput.ExtraParams = p.ExtraParams
		}
	}

	// The version field is populated only when the model string is a version ID.
	req := &ReplicatePredictionRequest{Input: predictionInput}
	if isVersionID(bifrostReq.Model) {
		req.Version = &bifrostReq.Model
	}

	// Webhook settings live in ExtraParams; without them the request is complete.
	if bifrostReq.Params == nil || bifrostReq.Params.ExtraParams == nil {
		return req, nil
	}
	extra := bifrostReq.Params.ExtraParams
	if hook, ok := schemas.SafeExtractStringPointer(extra["webhook"]); ok {
		req.Webhook = hook
	}
	if events, ok := schemas.SafeExtractStringSlice(extra["webhook_events_filter"]); ok {
		req.WebhookEventsFilter = events
	}
	return req, nil
}
// ToBifrostTextCompletionResponse converts a Replicate prediction response to
// the Bifrost text completion format.
//
// The result always carries exactly one choice at index 0. Its text comes from
// Output.OutputStr when set, otherwise from Output.OutputArray concatenated
// without a separator. The finish reason is "stop" for succeeded or canceled
// predictions, "error" for failed ones, and nil for any other status. Usage is
// filled in only when token counts are found in the prediction logs.
//
// A nil receiver yields a nil response.
func (response *ReplicatePredictionResponse) ToBifrostTextCompletionResponse() *schemas.BifrostTextCompletionResponse {
	if response == nil {
		return nil
	}

	// Flatten the prediction output into a single optional string.
	var text *string
	if out := response.Output; out != nil {
		switch {
		case out.OutputStr != nil:
			text = out.OutputStr
		case out.OutputArray != nil:
			combined := strings.Join(out.OutputArray, "")
			text = &combined
		}
	}

	// Translate the prediction status into a finish reason; other statuses
	// leave it unset.
	var finish *string
	switch response.Status {
	case ReplicatePredictionStatusSucceeded, ReplicatePredictionStatusCanceled:
		finish = schemas.Ptr("stop")
	case ReplicatePredictionStatusFailed:
		finish = schemas.Ptr("error")
	}

	result := &schemas.BifrostTextCompletionResponse{
		ID:     response.ID,
		Model:  response.Model,
		Object: "text_completion",
	}
	result.Choices = []schemas.BifrostResponseChoice{
		{
			Index: 0,
			TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
				Text: text,
			},
			FinishReason: finish,
		},
	}

	// Token usage is parsed out of the prediction logs when they are present.
	if response.Logs == nil {
		return result
	}
	if promptTokens, completionTokens, total, ok := parseTokenUsageFromLogs(response.Logs, schemas.TextCompletionRequest); ok {
		result.Usage = &schemas.BifrostLLMUsage{
			PromptTokens:     promptTokens,
			CompletionTokens: completionTokens,
			TotalTokens:      total,
		}
	}
	return result
}