first commit
This commit is contained in:
148
core/providers/replicate/text.go
Normal file
148
core/providers/replicate/text.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package replicate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func ToReplicateTextRequest(bifrostReq *schemas.BifrostTextCompletionRequest) (*ReplicatePredictionRequest, error) {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil {
|
||||
return nil, fmt.Errorf("bifrost request is nil or prompt is nil")
|
||||
}
|
||||
|
||||
input := &ReplicatePredictionRequestInput{}
|
||||
if bifrostReq.Input.PromptStr != nil {
|
||||
input.Prompt = bifrostReq.Input.PromptStr
|
||||
} else if len(bifrostReq.Input.PromptArray) > 0 {
|
||||
prompt := strings.Join(bifrostReq.Input.PromptArray, "\n")
|
||||
input.Prompt = &prompt
|
||||
}
|
||||
|
||||
// Map parameters if present
|
||||
if bifrostReq.Params != nil {
|
||||
params := bifrostReq.Params
|
||||
|
||||
// Temperature
|
||||
if params.Temperature != nil {
|
||||
input.Temperature = params.Temperature
|
||||
}
|
||||
|
||||
// Top P
|
||||
if params.TopP != nil {
|
||||
input.TopP = params.TopP
|
||||
}
|
||||
|
||||
// Max tokens
|
||||
if params.MaxTokens != nil {
|
||||
input.MaxTokens = params.MaxTokens
|
||||
}
|
||||
|
||||
// Presence penalty
|
||||
if params.PresencePenalty != nil {
|
||||
input.PresencePenalty = params.PresencePenalty
|
||||
}
|
||||
|
||||
// Frequency penalty
|
||||
if params.FrequencyPenalty != nil {
|
||||
input.FrequencyPenalty = params.FrequencyPenalty
|
||||
}
|
||||
|
||||
// Top K (from ExtraParams)
|
||||
if topK, ok := schemas.SafeExtractIntPointer(params.ExtraParams["top_k"]); ok {
|
||||
input.TopK = topK
|
||||
}
|
||||
|
||||
// Seed
|
||||
if params.Seed != nil {
|
||||
input.Seed = params.Seed
|
||||
}
|
||||
|
||||
if params.ExtraParams != nil {
|
||||
input.ExtraParams = params.ExtraParams
|
||||
}
|
||||
}
|
||||
|
||||
// Check if model is a version ID and set version field accordingly
|
||||
req := &ReplicatePredictionRequest{
|
||||
Input: input,
|
||||
}
|
||||
|
||||
if isVersionID(bifrostReq.Model) {
|
||||
req.Version = &bifrostReq.Model
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil {
|
||||
if webhook, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["webhook"]); ok {
|
||||
req.Webhook = webhook
|
||||
}
|
||||
if webhookEventsFilter, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["webhook_events_filter"]); ok {
|
||||
req.WebhookEventsFilter = webhookEventsFilter
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// ToBifrostTextCompletionResponse converts a Replicate prediction response to Bifrost format
|
||||
func (response *ReplicatePredictionResponse) ToBifrostTextCompletionResponse() *schemas.BifrostTextCompletionResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize Bifrost response
|
||||
bifrostResponse := &schemas.BifrostTextCompletionResponse{
|
||||
ID: response.ID,
|
||||
Model: response.Model,
|
||||
Object: "text_completion",
|
||||
}
|
||||
|
||||
// Convert output to text
|
||||
var textOutput *string
|
||||
if response.Output != nil {
|
||||
if response.Output.OutputStr != nil {
|
||||
textOutput = response.Output.OutputStr
|
||||
} else if response.Output.OutputArray != nil {
|
||||
// Join array of strings into a single string
|
||||
joined := strings.Join(response.Output.OutputArray, "")
|
||||
textOutput = &joined
|
||||
}
|
||||
}
|
||||
|
||||
// Determine finish reason based on status
|
||||
var finishReason *string
|
||||
switch response.Status {
|
||||
case ReplicatePredictionStatusSucceeded:
|
||||
finishReason = schemas.Ptr("stop")
|
||||
case ReplicatePredictionStatusFailed:
|
||||
finishReason = schemas.Ptr("error")
|
||||
case ReplicatePredictionStatusCanceled:
|
||||
finishReason = schemas.Ptr("stop")
|
||||
}
|
||||
|
||||
// Create choice with text completion response choice
|
||||
choice := schemas.BifrostResponseChoice{
|
||||
Index: 0,
|
||||
TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
|
||||
Text: textOutput,
|
||||
},
|
||||
FinishReason: finishReason,
|
||||
}
|
||||
|
||||
bifrostResponse.Choices = []schemas.BifrostResponseChoice{choice}
|
||||
|
||||
// Extract usage information from logs
|
||||
if response.Logs != nil {
|
||||
inputTokens, outputTokens, totalTokens, found := parseTokenUsageFromLogs(response.Logs, schemas.TextCompletionRequest)
|
||||
if found {
|
||||
bifrostResponse.Usage = &schemas.BifrostLLMUsage{
|
||||
PromptTokens: inputTokens,
|
||||
CompletionTokens: outputTokens,
|
||||
TotalTokens: totalTokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
Reference in New Issue
Block a user