1441 lines
45 KiB
Go
1441 lines
45 KiB
Go
package replicate_test
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"mime"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/maximhq/bifrost/core/internal/llmtests"
|
|
"github.com/maximhq/bifrost/core/providers/replicate"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type testLogger struct{}
|
|
|
|
func (l *testLogger) Debug(string, ...any) {}
|
|
func (l *testLogger) Info(string, ...any) {}
|
|
func (l *testLogger) Warn(string, ...any) {}
|
|
func (l *testLogger) Error(string, ...any) {}
|
|
func (l *testLogger) Fatal(string, ...any) {}
|
|
func (l *testLogger) SetLevel(schemas.LogLevel) {}
|
|
func (l *testLogger) SetOutputType(schemas.LoggerOutputType) {}
|
|
func (l *testLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
|
|
return schemas.NoopLogEvent
|
|
}
|
|
|
|
// multipartFieldOrderFromRequest parses r's multipart/form-data body and
// returns the form field names in the order their parts appear in the body.
// Part contents are drained and discarded.
//
// Errors:
//   - errors from mime.ParseMediaType (e.g. a missing Content-Type) are
//     returned unchanged;
//   - http.ErrNotMultipart when the media type is not multipart/form-data;
//   - http.ErrMissingBoundary when the boundary parameter is absent or empty;
//   - errors from reading a part are returned unchanged.
func multipartFieldOrderFromRequest(r *http.Request) ([]string, error) {
	mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
	if err != nil {
		return nil, err
	}
	if mediaType != "multipart/form-data" {
		return nil, http.ErrNotMultipart
	}
	// Without a boundary multipart.NewReader cannot split the body; report
	// that directly instead of surfacing a confusing parse error later.
	boundary, ok := params["boundary"]
	if !ok || boundary == "" {
		return nil, http.ErrMissingBoundary
	}

	reader := multipart.NewReader(r.Body, boundary)
	var order []string
	for {
		part, err := reader.NextPart()
		if err == io.EOF {
			break
		}
		if err != nil {
			return nil, err
		}
		order = append(order, part.FormName())
		// Drain the part so a truncated or corrupt body is reported here
		// rather than silently ignored.
		if _, err := io.Copy(io.Discard, part); err != nil {
			_ = part.Close()
			return nil, err
		}
		_ = part.Close() // contents were already drained above
	}
	return order, nil
}
|
|
|
|
func TestFileUpload_OrdersMetadataBeforeFile(t *testing.T) {
|
|
var (
|
|
order []string
|
|
handlerErr error
|
|
)
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
order, handlerErr = multipartFieldOrderFromRequest(r)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
_, _ = w.Write([]byte(`{"title":"forced error"}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
provider, err := replicate.NewReplicateProvider(&schemas.ProviderConfig{
|
|
NetworkConfig: schemas.NetworkConfig{BaseURL: server.URL},
|
|
}, &testLogger{})
|
|
require.NoError(t, err)
|
|
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
defer ctx.Cancel()
|
|
|
|
contentType := "application/json"
|
|
_, bifrostErr := provider.FileUpload(ctx, schemas.Key{}, &schemas.BifrostFileUploadRequest{
|
|
Provider: schemas.Replicate,
|
|
File: []byte(`{"hello":"world"}`),
|
|
Filename: "payload.json",
|
|
ContentType: &contentType,
|
|
ExtraParams: map[string]interface{}{
|
|
"metadata": map[string]interface{}{"owner": "oss", "purpose": "test"},
|
|
},
|
|
})
|
|
require.NotNil(t, bifrostErr)
|
|
require.NoError(t, handlerErr)
|
|
assert.Equal(t, []string{"filename", "type", "metadata", "content"}, order)
|
|
}
|
|
|
|
func TestReplicate(t *testing.T) {
|
|
t.Parallel()
|
|
if strings.TrimSpace(os.Getenv("REPLICATE_API_KEY")) == "" {
|
|
t.Skip("Skipping Replicate tests because REPLICATE_API_KEY is not set")
|
|
}
|
|
|
|
client, ctx, cancel, err := llmtests.SetupTest()
|
|
if err != nil {
|
|
t.Fatalf("Error initializing test setup: %v", err)
|
|
}
|
|
defer cancel()
|
|
defer client.Shutdown()
|
|
|
|
testConfig := llmtests.ComprehensiveTestConfig{
|
|
Provider: schemas.Replicate,
|
|
ChatModel: "openai/gpt-4.1-mini",
|
|
TextModel: "openai/gpt-4.1-mini",
|
|
ImageGenerationModel: "black-forest-labs/flux-dev",
|
|
ImageEditModel: "black-forest-labs/flux-dev",
|
|
VideoGenerationModel: "openai/sora-2-pro",
|
|
FileExtraParams: map[string]interface{}{
|
|
"owner": os.Getenv("REPLICATE_OWNER"),
|
|
"expiry": 1830297599,
|
|
},
|
|
Scenarios: llmtests.TestScenarios{
|
|
TextCompletion: true,
|
|
SimpleChat: true,
|
|
CompletionStream: true,
|
|
MultiTurnConversation: false,
|
|
ToolCalls: false,
|
|
ToolCallsStreaming: false,
|
|
MultipleToolCalls: false,
|
|
End2EndToolCalling: false,
|
|
AutomaticFunctionCall: false,
|
|
ImageURL: false,
|
|
ImageBase64: false,
|
|
ImageGeneration: true,
|
|
ImageEdit: true,
|
|
ImageEditStream: true,
|
|
ImageGenerationStream: true,
|
|
FileBase64: false,
|
|
FileURL: false,
|
|
MultipleImages: false,
|
|
CompleteEnd2End: false,
|
|
Reasoning: false,
|
|
Embedding: false,
|
|
ListModels: true,
|
|
FileUpload: true,
|
|
FileList: true,
|
|
FileRetrieve: true,
|
|
FileDelete: true,
|
|
FileContent: false,
|
|
VideoGeneration: false, // disabled for now because of long running operations
|
|
VideoRetrieve: false,
|
|
VideoRemix: false,
|
|
VideoDownload: false,
|
|
VideoList: false,
|
|
VideoDelete: false,
|
|
},
|
|
}
|
|
|
|
t.Run("ReplicateTests", func(t *testing.T) {
|
|
llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
|
|
})
|
|
}
|
|
|
|
// TestBifrostToReplicateChatRequestConversion tests the conversion from Bifrost chat request to Replicate format
|
|
// with special handling based on model names
|
|
func TestBifrostToReplicateChatRequestConversion(t *testing.T) {
|
|
maxTokens := 100
|
|
temp := 0.7
|
|
topP := 0.9
|
|
seed := 42
|
|
presencePenalty := 0.5
|
|
frequencyPenalty := 0.3
|
|
|
|
tests := []struct {
|
|
name string
|
|
input *schemas.BifrostChatRequest
|
|
validate func(t *testing.T, result *replicate.ReplicatePredictionRequest)
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "OpenAI_Model_With_Messages",
|
|
input: &schemas.BifrostChatRequest{
|
|
Model: "openai/gpt-4.1-mini",
|
|
Input: []schemas.ChatMessage{
|
|
{
|
|
Role: schemas.ChatMessageRoleSystem,
|
|
Content: &schemas.ChatMessageContent{
|
|
ContentStr: schemas.Ptr("You are a helpful assistant."),
|
|
},
|
|
},
|
|
{
|
|
Role: schemas.ChatMessageRoleUser,
|
|
Content: &schemas.ChatMessageContent{
|
|
ContentStr: schemas.Ptr("Hello!"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
// OpenAI models should use Messages field
|
|
assert.NotNil(t, result.Input.Messages)
|
|
assert.Len(t, result.Input.Messages, 2)
|
|
assert.Equal(t, schemas.ChatMessageRoleSystem, result.Input.Messages[0].Role)
|
|
assert.Equal(t, schemas.ChatMessageRoleUser, result.Input.Messages[1].Role)
|
|
},
|
|
},
|
|
{
|
|
name: "OpenAI_Model_MaxCompletionTokens",
|
|
input: &schemas.BifrostChatRequest{
|
|
Model: "openai/gpt-4o",
|
|
Input: []schemas.ChatMessage{
|
|
{
|
|
Role: schemas.ChatMessageRoleUser,
|
|
Content: &schemas.ChatMessageContent{
|
|
ContentStr: schemas.Ptr("Test"),
|
|
},
|
|
},
|
|
},
|
|
Params: &schemas.ChatParameters{
|
|
MaxCompletionTokens: &maxTokens,
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
// OpenAI models should use MaxCompletionTokens field
|
|
assert.NotNil(t, result.Input.MaxCompletionTokens)
|
|
assert.Equal(t, maxTokens, *result.Input.MaxCompletionTokens)
|
|
assert.Nil(t, result.Input.MaxTokens)
|
|
},
|
|
},
|
|
{
|
|
name: "Deepseek_Model_NoSystemPrompt",
|
|
input: &schemas.BifrostChatRequest{
|
|
Model: "deepseek-ai/deepseek-coder-33b-instruct",
|
|
Input: []schemas.ChatMessage{
|
|
{
|
|
Role: schemas.ChatMessageRoleSystem,
|
|
Content: &schemas.ChatMessageContent{
|
|
ContentStr: schemas.Ptr("You are a helpful assistant."),
|
|
},
|
|
},
|
|
{
|
|
Role: schemas.ChatMessageRoleUser,
|
|
Content: &schemas.ChatMessageContent{
|
|
ContentStr: schemas.Ptr("Hello!"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
// Deepseek models don't support system_prompt, it should be prepended to prompt
|
|
assert.Nil(t, result.Input.SystemPrompt)
|
|
assert.NotNil(t, result.Input.Prompt)
|
|
// System prompt should be prepended to conversation
|
|
assert.Contains(t, *result.Input.Prompt, "You are a helpful assistant.")
|
|
assert.Contains(t, *result.Input.Prompt, "Hello!")
|
|
},
|
|
},
|
|
{
|
|
name: "Meta_Llama_NoSystemPrompt",
|
|
input: &schemas.BifrostChatRequest{
|
|
Model: "meta/meta-llama-3-8b",
|
|
Input: []schemas.ChatMessage{
|
|
{
|
|
Role: schemas.ChatMessageRoleSystem,
|
|
Content: &schemas.ChatMessageContent{
|
|
ContentStr: schemas.Ptr("Be concise."),
|
|
},
|
|
},
|
|
{
|
|
Role: schemas.ChatMessageRoleUser,
|
|
Content: &schemas.ChatMessageContent{
|
|
ContentStr: schemas.Ptr("What is AI?"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
// Meta llama models don't support system_prompt
|
|
assert.Nil(t, result.Input.SystemPrompt)
|
|
assert.NotNil(t, result.Input.Prompt)
|
|
assert.Contains(t, *result.Input.Prompt, "Be concise.")
|
|
assert.Contains(t, *result.Input.Prompt, "What is AI?")
|
|
},
|
|
},
|
|
{
|
|
name: "Regular_Model_WithSystemPrompt",
|
|
input: &schemas.BifrostChatRequest{
|
|
Model: "meta/llama-3.1-70b-instruct",
|
|
Input: []schemas.ChatMessage{
|
|
{
|
|
Role: schemas.ChatMessageRoleSystem,
|
|
Content: &schemas.ChatMessageContent{
|
|
ContentStr: schemas.Ptr("You are helpful."),
|
|
},
|
|
},
|
|
{
|
|
Role: schemas.ChatMessageRoleUser,
|
|
Content: &schemas.ChatMessageContent{
|
|
ContentStr: schemas.Ptr("Hi there"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
// Regular models support system_prompt
|
|
assert.NotNil(t, result.Input.SystemPrompt)
|
|
assert.Equal(t, "You are helpful.", *result.Input.SystemPrompt)
|
|
assert.NotNil(t, result.Input.Prompt)
|
|
assert.Equal(t, "Hi there", *result.Input.Prompt)
|
|
},
|
|
},
|
|
{
|
|
name: "Non_OpenAI_Model_MaxTokens",
|
|
input: &schemas.BifrostChatRequest{
|
|
Model: "meta/llama-3.1-70b-instruct",
|
|
Input: []schemas.ChatMessage{
|
|
{
|
|
Role: schemas.ChatMessageRoleUser,
|
|
Content: &schemas.ChatMessageContent{
|
|
ContentStr: schemas.Ptr("Test"),
|
|
},
|
|
},
|
|
},
|
|
Params: &schemas.ChatParameters{
|
|
MaxCompletionTokens: &maxTokens,
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
// Non-OpenAI models should use MaxTokens field
|
|
assert.NotNil(t, result.Input.MaxTokens)
|
|
assert.Equal(t, maxTokens, *result.Input.MaxTokens)
|
|
assert.Nil(t, result.Input.MaxCompletionTokens)
|
|
},
|
|
},
|
|
{
|
|
name: "Model_With_Version_ID",
|
|
input: &schemas.BifrostChatRequest{
|
|
Model: "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef",
|
|
Input: []schemas.ChatMessage{
|
|
{
|
|
Role: schemas.ChatMessageRoleUser,
|
|
Content: &schemas.ChatMessageContent{
|
|
ContentStr: schemas.Ptr("Test version ID"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
// Version ID should be set in Version field
|
|
assert.NotNil(t, result.Version)
|
|
assert.Equal(t, "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", *result.Version)
|
|
},
|
|
},
|
|
{
|
|
name: "AllParameters",
|
|
input: &schemas.BifrostChatRequest{
|
|
Model: "meta/llama-3.1-70b-instruct",
|
|
Input: []schemas.ChatMessage{
|
|
{
|
|
Role: schemas.ChatMessageRoleUser,
|
|
Content: &schemas.ChatMessageContent{
|
|
ContentStr: schemas.Ptr("Test all params"),
|
|
},
|
|
},
|
|
},
|
|
Params: &schemas.ChatParameters{
|
|
MaxCompletionTokens: &maxTokens,
|
|
Temperature: &temp,
|
|
TopP: &topP,
|
|
Seed: &seed,
|
|
PresencePenalty: &presencePenalty,
|
|
FrequencyPenalty: &frequencyPenalty,
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
assert.Equal(t, maxTokens, *result.Input.MaxTokens)
|
|
assert.Equal(t, temp, *result.Input.Temperature)
|
|
assert.Equal(t, topP, *result.Input.TopP)
|
|
assert.Equal(t, seed, *result.Input.Seed)
|
|
assert.Equal(t, presencePenalty, *result.Input.PresencePenalty)
|
|
assert.Equal(t, frequencyPenalty, *result.Input.FrequencyPenalty)
|
|
},
|
|
},
|
|
{
|
|
name: "MultipartContent_WithImageURL",
|
|
input: &schemas.BifrostChatRequest{
|
|
Model: "meta/llama-3.1-70b-instruct",
|
|
Input: []schemas.ChatMessage{
|
|
{
|
|
Role: schemas.ChatMessageRoleUser,
|
|
Content: &schemas.ChatMessageContent{
|
|
ContentBlocks: []schemas.ChatContentBlock{
|
|
{
|
|
Text: schemas.Ptr("Describe this image"),
|
|
},
|
|
{
|
|
ImageURLStruct: &schemas.ChatInputImage{
|
|
URL: "https://example.com/image.jpg",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
// Image URL should be added to ImageInput
|
|
assert.NotNil(t, result.Input.ImageInput)
|
|
assert.Len(t, result.Input.ImageInput, 1)
|
|
assert.Equal(t, "https://example.com/image.jpg", result.Input.ImageInput[0])
|
|
// Text should be in prompt
|
|
assert.NotNil(t, result.Input.Prompt)
|
|
assert.Equal(t, "Describe this image", *result.Input.Prompt)
|
|
},
|
|
},
|
|
{
|
|
name: "MultipartContent_Base64Images",
|
|
input: &schemas.BifrostChatRequest{
|
|
Model: "meta/llama-3.1-70b-instruct",
|
|
Input: []schemas.ChatMessage{
|
|
{
|
|
Role: schemas.ChatMessageRoleUser,
|
|
Content: &schemas.ChatMessageContent{
|
|
ContentBlocks: []schemas.ChatContentBlock{
|
|
{
|
|
Text: schemas.Ptr("Test"),
|
|
},
|
|
{
|
|
ImageURLStruct: &schemas.ChatInputImage{
|
|
URL: "data:image/png;base64,iVBORw0KGgoAAAANS",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
assert.NotNil(t, result.Input.ImageInput)
|
|
},
|
|
},
|
|
{
|
|
name: "ReasoningEffort",
|
|
input: &schemas.BifrostChatRequest{
|
|
Model: "meta/llama-3.1-70b-instruct",
|
|
Input: []schemas.ChatMessage{
|
|
{
|
|
Role: schemas.ChatMessageRoleUser,
|
|
Content: &schemas.ChatMessageContent{
|
|
ContentStr: schemas.Ptr("Test reasoning"),
|
|
},
|
|
},
|
|
},
|
|
Params: &schemas.ChatParameters{
|
|
Reasoning: &schemas.ChatReasoning{
|
|
Effort: schemas.Ptr("high"),
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
assert.NotNil(t, result.Input.ReasoningEffort)
|
|
assert.Equal(t, "high", *result.Input.ReasoningEffort)
|
|
},
|
|
},
|
|
{
|
|
name: "NilRequest",
|
|
input: nil,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "NilInput",
|
|
input: &schemas.BifrostChatRequest{
|
|
Model: "meta/llama-3.1-70b-instruct",
|
|
Input: nil,
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "EmptyMessages",
|
|
input: &schemas.BifrostChatRequest{
|
|
Model: "meta/llama-3.1-70b-instruct",
|
|
Input: []schemas.ChatMessage{},
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
actual, err := replicate.ToReplicateChatRequest(tt.input)
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, actual)
|
|
} else {
|
|
require.NoError(t, err)
|
|
require.NotNil(t, actual)
|
|
if tt.validate != nil {
|
|
tt.validate(t, actual)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestBifrostToReplicateImageGenerationConversion tests the conversion from Bifrost image generation request
|
|
// to Replicate format with special handling for different model families
|
|
func TestBifrostToReplicateImageGenerationConversion(t *testing.T) {
|
|
prompt := "A beautiful sunset"
|
|
aspectRatio := "16:9"
|
|
numImages := 2
|
|
seed := 42
|
|
negativePrompt := "blurry"
|
|
numInferenceSteps := 50
|
|
quality := "high"
|
|
outputFormat := "png"
|
|
background := "white"
|
|
|
|
tests := []struct {
|
|
name string
|
|
input *schemas.BifrostImageGenerationRequest
|
|
validate func(t *testing.T, result *replicate.ReplicatePredictionRequest)
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "AllParameters",
|
|
input: &schemas.BifrostImageGenerationRequest{
|
|
Model: "black-forest-labs/flux-dev",
|
|
Input: &schemas.ImageGenerationInput{
|
|
Prompt: prompt,
|
|
},
|
|
Params: &schemas.ImageGenerationParameters{
|
|
N: &numImages,
|
|
AspectRatio: &aspectRatio,
|
|
Seed: &seed,
|
|
NegativePrompt: &negativePrompt,
|
|
NumInferenceSteps: &numInferenceSteps,
|
|
Quality: &quality,
|
|
OutputFormat: &outputFormat,
|
|
Background: &background,
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
assert.Equal(t, prompt, *result.Input.Prompt)
|
|
assert.Equal(t, numImages, *result.Input.NumberOfImages)
|
|
assert.Equal(t, aspectRatio, *result.Input.AspectRatio)
|
|
assert.Equal(t, seed, *result.Input.Seed)
|
|
assert.Equal(t, negativePrompt, *result.Input.NegativePrompt)
|
|
assert.Equal(t, numInferenceSteps, *result.Input.NumInferenceStep)
|
|
assert.Equal(t, quality, *result.Input.Quality)
|
|
assert.Equal(t, outputFormat, *result.Input.OutputFormat)
|
|
assert.Equal(t, background, *result.Input.Background)
|
|
},
|
|
},
|
|
{
|
|
name: "Version_ID",
|
|
input: &schemas.BifrostImageGenerationRequest{
|
|
Model: "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef",
|
|
Input: &schemas.ImageGenerationInput{
|
|
Prompt: prompt,
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
// Version ID should be set in Version field
|
|
assert.NotNil(t, result.Version)
|
|
assert.Equal(t, "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", *result.Version)
|
|
},
|
|
},
|
|
{
|
|
name: "ExtraParams",
|
|
input: &schemas.BifrostImageGenerationRequest{
|
|
Model: "black-forest-labs/flux-dev",
|
|
Input: &schemas.ImageGenerationInput{
|
|
Prompt: prompt,
|
|
},
|
|
Params: &schemas.ImageGenerationParameters{
|
|
ExtraParams: map[string]interface{}{
|
|
"custom_param": "value",
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
assert.NotNil(t, result.Input.ExtraParams)
|
|
assert.Equal(t, "value", result.Input.ExtraParams["custom_param"])
|
|
},
|
|
},
|
|
{
|
|
name: "NilRequest",
|
|
input: nil,
|
|
wantErr: false, // Function returns nil, not error
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
assert.Nil(t, result)
|
|
},
|
|
},
|
|
{
|
|
name: "NilInput",
|
|
input: &schemas.BifrostImageGenerationRequest{
|
|
Model: "black-forest-labs/flux-dev",
|
|
Input: nil,
|
|
},
|
|
wantErr: false,
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
assert.Nil(t, result)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
actual := replicate.ToReplicateImageGenerationInput(tt.input)
|
|
if tt.wantErr {
|
|
assert.Nil(t, actual)
|
|
} else {
|
|
if tt.validate != nil {
|
|
tt.validate(t, actual)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestBifrostToReplicateResponsesRequestConversion tests the conversion from Bifrost responses request to Replicate format
|
|
func TestBifrostToReplicateResponsesRequestConversion(t *testing.T) {
|
|
maxOutputTokens := 100
|
|
temp := 0.7
|
|
topP := 0.9
|
|
reasoningEffort := "medium"
|
|
instructions := "Be concise"
|
|
|
|
tests := []struct {
|
|
name string
|
|
input *schemas.BifrostResponsesRequest
|
|
validate func(t *testing.T, result *replicate.ReplicatePredictionRequest)
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "GPT5_Structured_With_InputItemList",
|
|
input: &schemas.BifrostResponsesRequest{
|
|
Model: "openai/gpt-5-structured",
|
|
Input: []schemas.ResponsesMessage{
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr("Hello, how are you?"),
|
|
},
|
|
},
|
|
},
|
|
Params: &schemas.ResponsesParameters{
|
|
Instructions: &instructions,
|
|
MaxOutputTokens: &maxOutputTokens,
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
// GPT-5 structured models should use InputItemList
|
|
assert.NotNil(t, result.Input.InputItemList)
|
|
assert.Len(t, result.Input.InputItemList, 1)
|
|
assert.Equal(t, schemas.ResponsesInputMessageRoleUser, *result.Input.InputItemList[0].Role)
|
|
assert.Equal(t, "Hello, how are you?", *result.Input.InputItemList[0].Content.ContentStr)
|
|
// Check parameters
|
|
assert.Equal(t, &instructions, result.Input.Instructions)
|
|
assert.Equal(t, &maxOutputTokens, result.Input.MaxOutputTokens)
|
|
},
|
|
},
|
|
{
|
|
name: "GPT5_Structured_With_Tools",
|
|
input: &schemas.BifrostResponsesRequest{
|
|
Model: "openai/gpt-5-structured-preview",
|
|
Input: []schemas.ResponsesMessage{
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr("What's the weather?"),
|
|
},
|
|
},
|
|
},
|
|
Params: &schemas.ResponsesParameters{
|
|
Tools: []schemas.ResponsesTool{
|
|
{
|
|
Type: schemas.ResponsesToolTypeFunction,
|
|
Name: schemas.Ptr("get_weather"),
|
|
Description: schemas.Ptr("Get weather information"),
|
|
ResponsesToolFunction: &schemas.ResponsesToolFunction{
|
|
Parameters: &schemas.ToolFunctionParameters{
|
|
Type: "object",
|
|
Properties: schemas.NewOrderedMapFromPairs(
|
|
schemas.KV("location", map[string]interface{}{
|
|
"type": "string",
|
|
}),
|
|
),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
assert.NotNil(t, result.Input.Tools)
|
|
assert.Len(t, result.Input.Tools, 1)
|
|
assert.Equal(t, "get_weather", *result.Input.Tools[0].Name)
|
|
assert.NotNil(t, result.Input.InputItemList)
|
|
},
|
|
},
|
|
{
|
|
name: "OpenAI_Family_With_Messages",
|
|
input: &schemas.BifrostResponsesRequest{
|
|
Model: "openai/gpt-4o",
|
|
Input: []schemas.ResponsesMessage{
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleSystem),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr("You are helpful."),
|
|
},
|
|
},
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr("Hello!"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
// OpenAI family (non-gpt5-structured) should use Messages
|
|
assert.NotNil(t, result.Input.Messages)
|
|
assert.Len(t, result.Input.Messages, 2)
|
|
assert.Equal(t, schemas.ChatMessageRoleSystem, result.Input.Messages[0].Role)
|
|
assert.Equal(t, schemas.ChatMessageRoleUser, result.Input.Messages[1].Role)
|
|
assert.Nil(t, result.Input.InputItemList)
|
|
},
|
|
},
|
|
{
|
|
name: "NonOpenAI_Model_With_SystemPrompt",
|
|
input: &schemas.BifrostResponsesRequest{
|
|
Model: "meta/llama-3.1-70b-instruct",
|
|
Input: []schemas.ResponsesMessage{
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleSystem),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr("Be helpful."),
|
|
},
|
|
},
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr("Hi there!"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
// Non-OpenAI models that support system_prompt
|
|
assert.NotNil(t, result.Input.SystemPrompt)
|
|
assert.Equal(t, "Be helpful.", *result.Input.SystemPrompt)
|
|
assert.NotNil(t, result.Input.Prompt)
|
|
assert.Equal(t, "Hi there!", *result.Input.Prompt)
|
|
},
|
|
},
|
|
{
|
|
name: "NonOpenAI_Model_NoSystemPrompt_Support",
|
|
input: &schemas.BifrostResponsesRequest{
|
|
Model: "deepseek-ai/deepseek-coder-33b-instruct",
|
|
Input: []schemas.ResponsesMessage{
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleSystem),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr("You are a code assistant."),
|
|
},
|
|
},
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr("Write a function."),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
// Deepseek models don't support system_prompt, should be prepended to prompt
|
|
assert.Nil(t, result.Input.SystemPrompt)
|
|
assert.NotNil(t, result.Input.Prompt)
|
|
assert.Contains(t, *result.Input.Prompt, "You are a code assistant.")
|
|
assert.Contains(t, *result.Input.Prompt, "Write a function.")
|
|
},
|
|
},
|
|
{
|
|
name: "ContentBlocks_With_Text",
|
|
input: &schemas.BifrostResponsesRequest{
|
|
Model: "meta/llama-3.1-70b-instruct",
|
|
Input: []schemas.ResponsesMessage{
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentBlocks: []schemas.ResponsesMessageContentBlock{
|
|
{
|
|
Text: schemas.Ptr("Part 1"),
|
|
},
|
|
{
|
|
Text: schemas.Ptr("Part 2"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
assert.NotNil(t, result.Input.Prompt)
|
|
// Text parts should be joined with newline
|
|
assert.Equal(t, "Part 1\nPart 2", *result.Input.Prompt)
|
|
},
|
|
},
|
|
{
|
|
name: "ContentBlocks_With_ImageURL",
|
|
input: &schemas.BifrostResponsesRequest{
|
|
Model: "meta/llama-3.1-70b-instruct",
|
|
Input: []schemas.ResponsesMessage{
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentBlocks: []schemas.ResponsesMessageContentBlock{
|
|
{
|
|
Text: schemas.Ptr("Describe this"),
|
|
},
|
|
{
|
|
ResponsesInputMessageContentBlockImage: &schemas.ResponsesInputMessageContentBlockImage{
|
|
ImageURL: schemas.Ptr("https://example.com/image.jpg"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
// Non-base64 image URLs should be added to ImageInput
|
|
assert.NotNil(t, result.Input.ImageInput)
|
|
assert.Len(t, result.Input.ImageInput, 1)
|
|
assert.Equal(t, "https://example.com/image.jpg", result.Input.ImageInput[0])
|
|
assert.NotNil(t, result.Input.Prompt)
|
|
assert.Equal(t, "Describe this", *result.Input.Prompt)
|
|
},
|
|
},
|
|
{
|
|
name: "ContentBlocks_Base64_Images",
|
|
input: &schemas.BifrostResponsesRequest{
|
|
Model: "meta/llama-3.1-70b-instruct",
|
|
Input: []schemas.ResponsesMessage{
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentBlocks: []schemas.ResponsesMessageContentBlock{
|
|
{
|
|
Text: schemas.Ptr("Test"),
|
|
},
|
|
{
|
|
ResponsesInputMessageContentBlockImage: &schemas.ResponsesInputMessageContentBlockImage{
|
|
ImageURL: schemas.Ptr("data:image/png;base64,abc123"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
assert.NotNil(t, result.Input.ImageInput)
|
|
},
|
|
},
|
|
{
|
|
name: "MultipleMessages_Assistant_User",
|
|
input: &schemas.BifrostResponsesRequest{
|
|
Model: "meta/llama-3.1-70b-instruct",
|
|
Input: []schemas.ResponsesMessage{
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr("What is AI?"),
|
|
},
|
|
},
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr("AI is artificial intelligence."),
|
|
},
|
|
},
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr("Tell me more."),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
assert.NotNil(t, result.Input.Prompt)
|
|
// All conversation parts should be joined
|
|
assert.Contains(t, *result.Input.Prompt, "What is AI?")
|
|
assert.Contains(t, *result.Input.Prompt, "AI is artificial intelligence.")
|
|
assert.Contains(t, *result.Input.Prompt, "Tell me more.")
|
|
},
|
|
},
|
|
{
|
|
name: "Parameters_OpenAI_MaxCompletionTokens",
|
|
input: &schemas.BifrostResponsesRequest{
|
|
Model: "openai/gpt-4o",
|
|
Input: []schemas.ResponsesMessage{
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr("Test"),
|
|
},
|
|
},
|
|
},
|
|
Params: &schemas.ResponsesParameters{
|
|
MaxOutputTokens: &maxOutputTokens,
|
|
Temperature: &temp,
|
|
TopP: &topP,
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
// OpenAI models should use MaxCompletionTokens
|
|
assert.NotNil(t, result.Input.MaxCompletionTokens)
|
|
assert.Equal(t, maxOutputTokens, *result.Input.MaxCompletionTokens)
|
|
assert.Nil(t, result.Input.MaxTokens)
|
|
// Check other parameters
|
|
assert.Equal(t, temp, *result.Input.Temperature)
|
|
assert.Equal(t, topP, *result.Input.TopP)
|
|
},
|
|
},
|
|
{
|
|
name: "Parameters_NonOpenAI_MaxTokens",
|
|
input: &schemas.BifrostResponsesRequest{
|
|
Model: "meta/llama-3.1-70b-instruct",
|
|
Input: []schemas.ResponsesMessage{
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr("Test"),
|
|
},
|
|
},
|
|
},
|
|
Params: &schemas.ResponsesParameters{
|
|
MaxOutputTokens: &maxOutputTokens,
|
|
Temperature: &temp,
|
|
TopP: &topP,
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
// Non-OpenAI models should use MaxTokens
|
|
assert.NotNil(t, result.Input.MaxTokens)
|
|
assert.Equal(t, maxOutputTokens, *result.Input.MaxTokens)
|
|
assert.Nil(t, result.Input.MaxCompletionTokens)
|
|
},
|
|
},
|
|
{
|
|
name: "ReasoningEffort",
|
|
input: &schemas.BifrostResponsesRequest{
|
|
Model: "meta/llama-3.1-70b-instruct",
|
|
Input: []schemas.ResponsesMessage{
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr("Test reasoning"),
|
|
},
|
|
},
|
|
},
|
|
Params: &schemas.ResponsesParameters{
|
|
Reasoning: &schemas.ResponsesParametersReasoning{
|
|
Effort: &reasoningEffort,
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
assert.NotNil(t, result.Input.ReasoningEffort)
|
|
assert.Equal(t, reasoningEffort, *result.Input.ReasoningEffort)
|
|
},
|
|
},
|
|
{
|
|
name: "Instructions_With_SystemPrompt_Support",
|
|
input: &schemas.BifrostResponsesRequest{
|
|
Model: "meta/llama-3.1-70b-instruct",
|
|
Input: []schemas.ResponsesMessage{
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr("Hello"),
|
|
},
|
|
},
|
|
},
|
|
Params: &schemas.ResponsesParameters{
|
|
Instructions: &instructions,
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
// Models that support system_prompt should use it for instructions
|
|
assert.NotNil(t, result.Input.SystemPrompt)
|
|
assert.Equal(t, instructions, *result.Input.SystemPrompt)
|
|
},
|
|
},
|
|
{
|
|
name: "Instructions_Without_SystemPrompt_Support",
|
|
input: &schemas.BifrostResponsesRequest{
|
|
Model: "deepseek-ai/deepseek-coder-33b-instruct",
|
|
Input: []schemas.ResponsesMessage{
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr("Hello"),
|
|
},
|
|
},
|
|
},
|
|
Params: &schemas.ResponsesParameters{
|
|
Instructions: &instructions,
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
// Models that don't support system_prompt should prepend instructions to prompt
|
|
assert.Nil(t, result.Input.SystemPrompt)
|
|
assert.NotNil(t, result.Input.Prompt)
|
|
assert.Contains(t, *result.Input.Prompt, instructions)
|
|
assert.Contains(t, *result.Input.Prompt, "Hello")
|
|
},
|
|
},
|
|
{
|
|
name: "Instructions_Prepended_To_Existing_Prompt",
|
|
input: &schemas.BifrostResponsesRequest{
|
|
Model: "deepseek-ai/deepseek-coder-33b-instruct",
|
|
Input: []schemas.ResponsesMessage{
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr("Existing content"),
|
|
},
|
|
},
|
|
},
|
|
Params: &schemas.ResponsesParameters{
|
|
Instructions: &instructions,
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
assert.NotNil(t, result.Input.Prompt)
|
|
// Instructions should be prepended before existing content
|
|
promptStr := *result.Input.Prompt
|
|
instrIdx := strings.Index(promptStr, instructions)
|
|
contentIdx := strings.Index(promptStr, "Existing content")
|
|
assert.Less(t, instrIdx, contentIdx, "Instructions should come before content")
|
|
},
|
|
},
|
|
{
|
|
name: "Version_ID",
|
|
input: &schemas.BifrostResponsesRequest{
|
|
Model: "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef",
|
|
Input: []schemas.ResponsesMessage{
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr("Test version"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
// Version ID should be set in Version field
|
|
assert.NotNil(t, result.Version)
|
|
assert.Equal(t, "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", *result.Version)
|
|
},
|
|
},
|
|
{
|
|
name: "EmptyContent_Messages_Skipped",
|
|
input: &schemas.BifrostResponsesRequest{
|
|
Model: "meta/llama-3.1-70b-instruct",
|
|
Input: []schemas.ResponsesMessage{
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr(""),
|
|
},
|
|
},
|
|
{
|
|
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
|
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentStr: schemas.Ptr("Valid message"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Input)
|
|
assert.NotNil(t, result.Input.Prompt)
|
|
// Only valid message should be present
|
|
assert.Equal(t, "Valid message", *result.Input.Prompt)
|
|
},
|
|
},
|
|
{
|
|
name: "NilRequest",
|
|
input: nil,
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
actual, err := replicate.ToReplicateResponsesRequest(tt.input)
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, actual)
|
|
} else {
|
|
require.NoError(t, err)
|
|
require.NotNil(t, actual)
|
|
if tt.validate != nil {
|
|
tt.validate(t, actual)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestReplicateToBifrostResponsesResponse tests the conversion from Replicate prediction response to Bifrost responses format
|
|
func TestReplicateToBifrostResponsesResponse(t *testing.T) {
|
|
predictionID := "test-prediction-123"
|
|
model := "openai/gpt-4o"
|
|
createdAt := "2024-01-25T12:00:00.000000Z"
|
|
completedAt := "2024-01-25T12:00:05.000000Z"
|
|
errorMsg := "Something went wrong"
|
|
|
|
tests := []struct {
|
|
name string
|
|
input *replicate.ReplicatePredictionResponse
|
|
validate func(t *testing.T, result *schemas.BifrostResponsesResponse)
|
|
}{
|
|
{
|
|
name: "Successful_Response_OutputStr",
|
|
input: &replicate.ReplicatePredictionResponse{
|
|
ID: predictionID,
|
|
Model: model,
|
|
CreatedAt: createdAt,
|
|
Status: replicate.ReplicatePredictionStatusSucceeded,
|
|
Output: &replicate.ReplicateOutput{
|
|
OutputStr: schemas.Ptr("This is the response text."),
|
|
},
|
|
Logs: schemas.Ptr("Input token count: 10\nOutput token count: 20\nTotal token count: 30"),
|
|
},
|
|
validate: func(t *testing.T, result *schemas.BifrostResponsesResponse) {
|
|
require.NotNil(t, result)
|
|
assert.Equal(t, predictionID, *result.ID)
|
|
assert.Equal(t, model, result.Model)
|
|
assert.NotNil(t, result.Status)
|
|
assert.Equal(t, "completed", *result.Status)
|
|
// Check output messages
|
|
require.NotNil(t, result.Output)
|
|
require.Len(t, result.Output, 1)
|
|
assert.Equal(t, schemas.ResponsesMessageTypeMessage, *result.Output[0].Type)
|
|
assert.Equal(t, schemas.ResponsesInputMessageRoleAssistant, *result.Output[0].Role)
|
|
assert.Equal(t, "This is the response text.", *result.Output[0].Content.ContentStr)
|
|
// Check usage
|
|
require.NotNil(t, result.Usage)
|
|
assert.Equal(t, 10, result.Usage.InputTokens)
|
|
assert.Equal(t, 20, result.Usage.OutputTokens)
|
|
assert.Equal(t, 30, result.Usage.TotalTokens)
|
|
},
|
|
},
|
|
{
|
|
name: "Successful_Response_OutputArray",
|
|
input: &replicate.ReplicatePredictionResponse{
|
|
ID: predictionID,
|
|
Model: model,
|
|
CreatedAt: createdAt,
|
|
Status: replicate.ReplicatePredictionStatusSucceeded,
|
|
Output: &replicate.ReplicateOutput{
|
|
OutputArray: []string{"Part 1", " Part 2", " Part 3"},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *schemas.BifrostResponsesResponse) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Output)
|
|
require.Len(t, result.Output, 1)
|
|
// Array should be joined into a single string
|
|
assert.Equal(t, "Part 1 Part 2 Part 3", *result.Output[0].Content.ContentStr)
|
|
},
|
|
},
|
|
{
|
|
name: "Successful_Response_OutputObject",
|
|
input: &replicate.ReplicatePredictionResponse{
|
|
ID: predictionID,
|
|
Model: model,
|
|
CreatedAt: createdAt,
|
|
Status: replicate.ReplicatePredictionStatusSucceeded,
|
|
Output: &replicate.ReplicateOutput{
|
|
OutputObject: &replicate.ReplicateOutputText{
|
|
Text: schemas.Ptr("Object text content"),
|
|
},
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *schemas.BifrostResponsesResponse) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Output)
|
|
require.Len(t, result.Output, 1)
|
|
assert.Equal(t, "Object text content", *result.Output[0].Content.ContentStr)
|
|
},
|
|
},
|
|
{
|
|
name: "Failed_Response_With_Error",
|
|
input: &replicate.ReplicatePredictionResponse{
|
|
ID: predictionID,
|
|
Model: model,
|
|
CreatedAt: createdAt,
|
|
Status: replicate.ReplicatePredictionStatusFailed,
|
|
Error: &errorMsg,
|
|
},
|
|
validate: func(t *testing.T, result *schemas.BifrostResponsesResponse) {
|
|
require.NotNil(t, result)
|
|
assert.NotNil(t, result.Status)
|
|
assert.Equal(t, "failed", *result.Status)
|
|
// Check error
|
|
require.NotNil(t, result.Error)
|
|
assert.Equal(t, "provider_error", result.Error.Code)
|
|
assert.Equal(t, errorMsg, result.Error.Message)
|
|
},
|
|
},
|
|
{
|
|
name: "Cancelled_Response",
|
|
input: &replicate.ReplicatePredictionResponse{
|
|
ID: predictionID,
|
|
Model: model,
|
|
CreatedAt: createdAt,
|
|
Status: replicate.ReplicatePredictionStatusCanceled,
|
|
},
|
|
validate: func(t *testing.T, result *schemas.BifrostResponsesResponse) {
|
|
require.NotNil(t, result)
|
|
assert.NotNil(t, result.Status)
|
|
assert.Equal(t, "cancelled", *result.Status)
|
|
},
|
|
},
|
|
{
|
|
name: "InProgress_Response",
|
|
input: &replicate.ReplicatePredictionResponse{
|
|
ID: predictionID,
|
|
Model: model,
|
|
CreatedAt: createdAt,
|
|
Status: replicate.ReplicatePredictionStatusProcessing,
|
|
},
|
|
validate: func(t *testing.T, result *schemas.BifrostResponsesResponse) {
|
|
require.NotNil(t, result)
|
|
assert.NotNil(t, result.Status)
|
|
assert.Equal(t, "in_progress", *result.Status)
|
|
},
|
|
},
|
|
{
|
|
name: "Queued_Response",
|
|
input: &replicate.ReplicatePredictionResponse{
|
|
ID: predictionID,
|
|
Model: model,
|
|
CreatedAt: createdAt,
|
|
Status: replicate.ReplicatePredictionStatusStarting,
|
|
},
|
|
validate: func(t *testing.T, result *schemas.BifrostResponsesResponse) {
|
|
require.NotNil(t, result)
|
|
assert.NotNil(t, result.Status)
|
|
assert.Equal(t, "queued", *result.Status)
|
|
},
|
|
},
|
|
{
|
|
name: "Response_With_CompletedAt",
|
|
input: &replicate.ReplicatePredictionResponse{
|
|
ID: predictionID,
|
|
Model: model,
|
|
CreatedAt: createdAt,
|
|
CompletedAt: &completedAt,
|
|
Status: replicate.ReplicatePredictionStatusSucceeded,
|
|
Output: &replicate.ReplicateOutput{
|
|
OutputStr: schemas.Ptr("Done"),
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *schemas.BifrostResponsesResponse) {
|
|
require.NotNil(t, result)
|
|
assert.NotZero(t, result.CreatedAt)
|
|
assert.NotNil(t, result.CompletedAt)
|
|
assert.NotZero(t, *result.CompletedAt)
|
|
},
|
|
},
|
|
{
|
|
name: "Response_With_Partial_Usage",
|
|
input: &replicate.ReplicatePredictionResponse{
|
|
ID: predictionID,
|
|
Model: model,
|
|
CreatedAt: createdAt,
|
|
Status: replicate.ReplicatePredictionStatusSucceeded,
|
|
Output: &replicate.ReplicateOutput{
|
|
OutputStr: schemas.Ptr("Response"),
|
|
},
|
|
Logs: schemas.Ptr("Input token count: 15\nOutput token count: 0"),
|
|
},
|
|
validate: func(t *testing.T, result *schemas.BifrostResponsesResponse) {
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.Usage)
|
|
assert.Equal(t, 15, result.Usage.InputTokens)
|
|
assert.Equal(t, 0, result.Usage.OutputTokens)
|
|
assert.Equal(t, 15, result.Usage.TotalTokens)
|
|
},
|
|
},
|
|
{
|
|
name: "Empty_Output_Content",
|
|
input: &replicate.ReplicatePredictionResponse{
|
|
ID: predictionID,
|
|
Model: model,
|
|
CreatedAt: createdAt,
|
|
Status: replicate.ReplicatePredictionStatusSucceeded,
|
|
Output: &replicate.ReplicateOutput{
|
|
OutputStr: schemas.Ptr(""),
|
|
},
|
|
},
|
|
validate: func(t *testing.T, result *schemas.BifrostResponsesResponse) {
|
|
require.NotNil(t, result)
|
|
// Empty content should not create output messages
|
|
assert.Empty(t, result.Output)
|
|
},
|
|
},
|
|
{
|
|
name: "No_Output",
|
|
input: &replicate.ReplicatePredictionResponse{
|
|
ID: predictionID,
|
|
Model: model,
|
|
CreatedAt: createdAt,
|
|
Status: replicate.ReplicatePredictionStatusProcessing,
|
|
Output: nil,
|
|
},
|
|
validate: func(t *testing.T, result *schemas.BifrostResponsesResponse) {
|
|
require.NotNil(t, result)
|
|
assert.Empty(t, result.Output)
|
|
},
|
|
},
|
|
{
|
|
name: "Empty_Error_Not_Set",
|
|
input: &replicate.ReplicatePredictionResponse{
|
|
ID: predictionID,
|
|
Model: model,
|
|
CreatedAt: createdAt,
|
|
Status: replicate.ReplicatePredictionStatusFailed,
|
|
Error: schemas.Ptr(""),
|
|
},
|
|
validate: func(t *testing.T, result *schemas.BifrostResponsesResponse) {
|
|
require.NotNil(t, result)
|
|
// Empty error string should not set error field
|
|
assert.Nil(t, result.Error)
|
|
},
|
|
},
|
|
{
|
|
name: "Nil_Response",
|
|
input: nil,
|
|
validate: func(t *testing.T, result *schemas.BifrostResponsesResponse) {
|
|
assert.Nil(t, result)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
actual := tt.input.ToBifrostResponsesResponse()
|
|
if tt.validate != nil {
|
|
tt.validate(t, actual)
|
|
}
|
|
})
|
|
}
|
|
}
|