149 lines
3.8 KiB
Go
149 lines
3.8 KiB
Go
package replicate
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
|
|
schemas "github.com/maximhq/bifrost/core/schemas"
|
|
)
|
|
|
|
func ToReplicateTextRequest(bifrostReq *schemas.BifrostTextCompletionRequest) (*ReplicatePredictionRequest, error) {
|
|
if bifrostReq == nil || bifrostReq.Input == nil {
|
|
return nil, fmt.Errorf("bifrost request is nil or prompt is nil")
|
|
}
|
|
|
|
input := &ReplicatePredictionRequestInput{}
|
|
if bifrostReq.Input.PromptStr != nil {
|
|
input.Prompt = bifrostReq.Input.PromptStr
|
|
} else if len(bifrostReq.Input.PromptArray) > 0 {
|
|
prompt := strings.Join(bifrostReq.Input.PromptArray, "\n")
|
|
input.Prompt = &prompt
|
|
}
|
|
|
|
// Map parameters if present
|
|
if bifrostReq.Params != nil {
|
|
params := bifrostReq.Params
|
|
|
|
// Temperature
|
|
if params.Temperature != nil {
|
|
input.Temperature = params.Temperature
|
|
}
|
|
|
|
// Top P
|
|
if params.TopP != nil {
|
|
input.TopP = params.TopP
|
|
}
|
|
|
|
// Max tokens
|
|
if params.MaxTokens != nil {
|
|
input.MaxTokens = params.MaxTokens
|
|
}
|
|
|
|
// Presence penalty
|
|
if params.PresencePenalty != nil {
|
|
input.PresencePenalty = params.PresencePenalty
|
|
}
|
|
|
|
// Frequency penalty
|
|
if params.FrequencyPenalty != nil {
|
|
input.FrequencyPenalty = params.FrequencyPenalty
|
|
}
|
|
|
|
// Top K (from ExtraParams)
|
|
if topK, ok := schemas.SafeExtractIntPointer(params.ExtraParams["top_k"]); ok {
|
|
input.TopK = topK
|
|
}
|
|
|
|
// Seed
|
|
if params.Seed != nil {
|
|
input.Seed = params.Seed
|
|
}
|
|
|
|
if params.ExtraParams != nil {
|
|
input.ExtraParams = params.ExtraParams
|
|
}
|
|
}
|
|
|
|
// Check if model is a version ID and set version field accordingly
|
|
req := &ReplicatePredictionRequest{
|
|
Input: input,
|
|
}
|
|
|
|
if isVersionID(bifrostReq.Model) {
|
|
req.Version = &bifrostReq.Model
|
|
}
|
|
|
|
if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil {
|
|
if webhook, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["webhook"]); ok {
|
|
req.Webhook = webhook
|
|
}
|
|
if webhookEventsFilter, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["webhook_events_filter"]); ok {
|
|
req.WebhookEventsFilter = webhookEventsFilter
|
|
}
|
|
}
|
|
|
|
return req, nil
|
|
}
|
|
|
|
// ToBifrostTextCompletionResponse converts a Replicate prediction response to Bifrost format
|
|
func (response *ReplicatePredictionResponse) ToBifrostTextCompletionResponse() *schemas.BifrostTextCompletionResponse {
|
|
if response == nil {
|
|
return nil
|
|
}
|
|
|
|
// Initialize Bifrost response
|
|
bifrostResponse := &schemas.BifrostTextCompletionResponse{
|
|
ID: response.ID,
|
|
Model: response.Model,
|
|
Object: "text_completion",
|
|
}
|
|
|
|
// Convert output to text
|
|
var textOutput *string
|
|
if response.Output != nil {
|
|
if response.Output.OutputStr != nil {
|
|
textOutput = response.Output.OutputStr
|
|
} else if response.Output.OutputArray != nil {
|
|
// Join array of strings into a single string
|
|
joined := strings.Join(response.Output.OutputArray, "")
|
|
textOutput = &joined
|
|
}
|
|
}
|
|
|
|
// Determine finish reason based on status
|
|
var finishReason *string
|
|
switch response.Status {
|
|
case ReplicatePredictionStatusSucceeded:
|
|
finishReason = schemas.Ptr("stop")
|
|
case ReplicatePredictionStatusFailed:
|
|
finishReason = schemas.Ptr("error")
|
|
case ReplicatePredictionStatusCanceled:
|
|
finishReason = schemas.Ptr("stop")
|
|
}
|
|
|
|
// Create choice with text completion response choice
|
|
choice := schemas.BifrostResponseChoice{
|
|
Index: 0,
|
|
TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
|
|
Text: textOutput,
|
|
},
|
|
FinishReason: finishReason,
|
|
}
|
|
|
|
bifrostResponse.Choices = []schemas.BifrostResponseChoice{choice}
|
|
|
|
// Extract usage information from logs
|
|
if response.Logs != nil {
|
|
inputTokens, outputTokens, totalTokens, found := parseTokenUsageFromLogs(response.Logs, schemas.TextCompletionRequest)
|
|
if found {
|
|
bifrostResponse.Usage = &schemas.BifrostLLMUsage{
|
|
PromptTokens: inputTokens,
|
|
CompletionTokens: outputTokens,
|
|
TotalTokens: totalTokens,
|
|
}
|
|
}
|
|
}
|
|
|
|
return bifrostResponse
|
|
}
|