3293 lines
126 KiB
Go
3293 lines
126 KiB
Go
// Package providers implements various LLM providers and their utility functions.
|
|
// This file contains the replicate provider implementation.
|
|
package replicate
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/textproto"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/bytedance/sonic"
|
|
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
|
schemas "github.com/maximhq/bifrost/core/schemas"
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
// ReplicateProvider implements the Provider interface for Replicate's API.
|
|
type ReplicateProvider struct {
|
|
logger schemas.Logger // Logger for provider operations
|
|
client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response)
|
|
streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader)
|
|
networkConfig schemas.NetworkConfig // Network configuration including extra headers
|
|
sendBackRawRequest bool // Whether to include raw request in BifrostResponse
|
|
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
|
|
customProviderConfig *schemas.CustomProviderConfig
|
|
}
|
|
|
|
// NewReplicateProvider creates a new Replicate provider instance.
|
|
// It initializes the HTTP client with the provided configuration and sets up response pools.
|
|
// The client is configured with timeouts, concurrency limits, and optional proxy settings.
|
|
func NewReplicateProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*ReplicateProvider, error) {
|
|
config.CheckAndSetDefaults()
|
|
|
|
requestTimeout := time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)
|
|
client := &fasthttp.Client{
|
|
ReadTimeout: requestTimeout,
|
|
WriteTimeout: requestTimeout,
|
|
MaxConnsPerHost: config.NetworkConfig.MaxConnsPerHost,
|
|
MaxIdleConnDuration: 30 * time.Second,
|
|
MaxConnWaitTimeout: requestTimeout,
|
|
MaxConnDuration: time.Second * time.Duration(schemas.DefaultMaxConnDurationInSeconds),
|
|
ConnPoolStrategy: fasthttp.FIFO,
|
|
}
|
|
|
|
// Configure proxy and retry policy
|
|
client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger)
|
|
client = providerUtils.ConfigureDialer(client)
|
|
client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger)
|
|
streamingClient := providerUtils.BuildStreamingClient(client)
|
|
config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/")
|
|
|
|
if config.NetworkConfig.BaseURL == "" {
|
|
config.NetworkConfig.BaseURL = replicateAPIBaseURL
|
|
}
|
|
|
|
return &ReplicateProvider{
|
|
logger: logger,
|
|
client: client,
|
|
streamingClient: streamingClient,
|
|
networkConfig: config.NetworkConfig,
|
|
sendBackRawRequest: config.SendBackRawRequest,
|
|
sendBackRawResponse: config.SendBackRawResponse,
|
|
customProviderConfig: config.CustomProviderConfig,
|
|
}, nil
|
|
}
|
|
|
|
// GetProviderKey returns the provider identifier for Replicate.
|
|
func (provider *ReplicateProvider) GetProviderKey() schemas.ModelProvider {
|
|
return schemas.Replicate
|
|
}
|
|
|
|
// buildRequestURL builds the request URL with custom provider config support
|
|
func (provider *ReplicateProvider) buildRequestURL(ctx *schemas.BifrostContext, defaultPath string, requestType schemas.RequestType) string {
|
|
path, isCompleteURL := providerUtils.GetRequestPath(ctx, defaultPath, provider.customProviderConfig, requestType)
|
|
if isCompleteURL {
|
|
return path
|
|
}
|
|
return provider.networkConfig.BaseURL + path
|
|
}
|
|
|
|
const (
|
|
replicateAPIBaseURL = "https://api.replicate.com"
|
|
pollingInterval = 2 * time.Second
|
|
)
|
|
|
|
// useDeploymentsEndpoint returns whether the key uses the deployments endpoint.
|
|
// Nil ReplicateKeyConfig is treated as false (default models/predictions behavior).
|
|
func useDeploymentsEndpoint(key schemas.Key) bool {
|
|
return key.ReplicateKeyConfig != nil && key.ReplicateKeyConfig.UseDeploymentsEndpoint
|
|
}
|
|
|
|
// createPrediction creates a new prediction on Replicate API
|
|
// Supports both sync (with Prefer: wait header) and async modes
|
|
// stripPrefer should be true for streaming requests to exclude the Prefer header
|
|
func createPrediction(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
jsonBody []byte,
|
|
key schemas.Key,
|
|
url string,
|
|
extraHeaders map[string]string,
|
|
stripPrefer bool,
|
|
logger schemas.Logger,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
) (*ReplicatePredictionResponse, interface{}, time.Duration, map[string]string, *schemas.BifrostError) {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
// Set URL
|
|
req.SetRequestURI(url)
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
// Set authorization header
|
|
if value := key.Value.GetValue(); value != "" {
|
|
req.Header.Set("Authorization", "Bearer "+value)
|
|
}
|
|
|
|
// Set any extra headers from network config
|
|
// Strip Prefer header for streaming requests to ensure async mode
|
|
headersToUse := extraHeaders
|
|
if stripPrefer {
|
|
headersToUse = stripPreferHeader(extraHeaders)
|
|
}
|
|
providerUtils.SetExtraHeaders(ctx, req, headersToUse, nil)
|
|
|
|
if !providerUtils.ApplyLargePayloadRequestBodyWithModelNormalization(ctx, req, schemas.Replicate) {
|
|
req.SetBody(jsonBody)
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, nil, latency, nil, bifrostErr
|
|
}
|
|
|
|
// Extract provider response headers before releasing the response
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated {
|
|
logger.Debug(fmt.Sprintf("error from replicate provider: %s", string(resp.Body())))
|
|
return nil, nil, latency, providerResponseHeaders, parseReplicateError(resp.Body(), resp.StatusCode())
|
|
}
|
|
|
|
// Parse response
|
|
body, decodeErr := providerUtils.CheckAndDecodeBody(resp)
|
|
if decodeErr != nil {
|
|
return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr)
|
|
}
|
|
|
|
var prediction ReplicatePredictionResponse
|
|
_, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &prediction, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, sendBackRawResponse))
|
|
if bifrostErr != nil {
|
|
return nil, nil, latency, providerResponseHeaders, bifrostErr
|
|
}
|
|
|
|
return &prediction, rawResponse, latency, providerResponseHeaders, nil
|
|
}
|
|
|
|
// getPrediction retrieves the current state of a prediction
|
|
func getPrediction(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
predictionURL string,
|
|
key schemas.Key,
|
|
logger schemas.Logger,
|
|
sendBackRawResponse bool,
|
|
) (*ReplicatePredictionResponse, interface{}, map[string]string, *schemas.BifrostError) {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
// Set URL
|
|
req.SetRequestURI(predictionURL)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
|
|
// Set authorization header
|
|
if value := key.Value.GetValue(); value != "" {
|
|
req.Header.Set("Authorization", "Bearer "+value)
|
|
}
|
|
|
|
// Make request
|
|
_, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, nil, nil, bifrostErr
|
|
}
|
|
|
|
// Extract provider response headers before releasing the response
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
logger.Debug(fmt.Sprintf("error from replicate provider: %s", string(resp.Body())))
|
|
return nil, nil, providerResponseHeaders, parseReplicateError(resp.Body(), resp.StatusCode())
|
|
}
|
|
|
|
// Parse response
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, nil, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
}
|
|
|
|
prediction := &ReplicatePredictionResponse{}
|
|
_, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, prediction, nil, false, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, nil, providerResponseHeaders, bifrostErr
|
|
}
|
|
|
|
return prediction, rawResponse, providerResponseHeaders, nil
|
|
}
|
|
|
|
// pollPrediction polls a prediction URL until it reaches a terminal state or timeout
|
|
func pollPrediction(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
predictionURL string,
|
|
key schemas.Key,
|
|
timeoutSeconds int,
|
|
logger schemas.Logger,
|
|
sendBackRawResponse bool,
|
|
) (*ReplicatePredictionResponse, interface{}, map[string]string, *schemas.BifrostError) {
|
|
// Create context with timeout
|
|
pollCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
|
|
defer cancel()
|
|
|
|
ticker := time.NewTicker(pollingInterval)
|
|
defer ticker.Stop()
|
|
|
|
// Poll immediately first time
|
|
prediction, rawResponse, providerResponseHeaders, err := getPrediction(pollCtx, client, predictionURL, key, logger, sendBackRawResponse)
|
|
if err != nil {
|
|
return nil, nil, providerResponseHeaders, err
|
|
}
|
|
|
|
// If already in terminal state, return immediately
|
|
if isTerminalStatus(prediction.Status) {
|
|
return prediction, rawResponse, providerResponseHeaders, checkForErrorStatus(prediction)
|
|
}
|
|
|
|
logger.Debug(fmt.Sprintf("polling replicate prediction %s, status: %s", prediction.ID, prediction.Status))
|
|
|
|
// Continue polling until terminal state or timeout
|
|
for {
|
|
select {
|
|
case <-pollCtx.Done():
|
|
return nil, nil, providerResponseHeaders, providerUtils.NewBifrostOperationError(
|
|
schemas.ErrProviderRequestTimedOut,
|
|
fmt.Errorf("prediction polling timed out after %d seconds", timeoutSeconds))
|
|
case <-ticker.C:
|
|
prediction, rawResponse, providerResponseHeaders, err = getPrediction(pollCtx, client, predictionURL, key, logger, sendBackRawResponse)
|
|
if err != nil {
|
|
return nil, nil, providerResponseHeaders, err
|
|
}
|
|
|
|
logger.Debug(fmt.Sprintf("prediction %s status: %s", prediction.ID, prediction.Status))
|
|
|
|
if isTerminalStatus(prediction.Status) {
|
|
return prediction, rawResponse, providerResponseHeaders, checkForErrorStatus(prediction)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// listDeploymentsByKey performs a list deployments request for a single key.
|
|
// Deployments are account-specific, so this needs to be called per key.
|
|
func (provider *ReplicateProvider) listDeploymentsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
|
providerName := provider.GetProviderKey()
|
|
client := provider.client
|
|
extraHeaders := provider.networkConfig.ExtraHeaders
|
|
|
|
if !useDeploymentsEndpoint(key) {
|
|
return ToBifrostListModelsResponse(
|
|
&ReplicateDeploymentListResponse{},
|
|
providerName,
|
|
key.Models,
|
|
key.BlacklistedModels,
|
|
key.Aliases,
|
|
request.Unfiltered,
|
|
), nil
|
|
}
|
|
|
|
// Build deployments URL
|
|
deploymentsURL := provider.buildRequestURL(ctx, "/v1/deployments", schemas.ListModelsRequest)
|
|
|
|
// Initialize pagination variables
|
|
currentURL := deploymentsURL
|
|
allDeployments := []ReplicateDeployment{}
|
|
|
|
// Follow pagination until there are no more pages
|
|
for currentURL != "" {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
|
|
// Set up request
|
|
req.SetRequestURI(currentURL)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
// Set authorization header if key is provided
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Set extra headers from network config
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
|
|
// Make request
|
|
_, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, client, req, resp)
|
|
|
|
// Release resources
|
|
wait()
|
|
fasthttp.ReleaseRequest(req)
|
|
|
|
if bifrostErr != nil {
|
|
fasthttp.ReleaseResponse(resp)
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
errorResponse := parseReplicateError(resp.Body(), resp.StatusCode())
|
|
fasthttp.ReleaseResponse(resp)
|
|
return nil, errorResponse
|
|
}
|
|
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
// Make a copy of the response body before releasing
|
|
bodyCopy := make([]byte, len(resp.Body()))
|
|
copy(bodyCopy, resp.Body())
|
|
|
|
fasthttp.ReleaseResponse(resp)
|
|
|
|
// Parse response from the copy
|
|
var pageResponse ReplicateDeploymentListResponse
|
|
if err := sonic.Unmarshal(bodyCopy, &pageResponse); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(
|
|
"failed to parse deployments response",
|
|
err)
|
|
}
|
|
|
|
// Append results from this page
|
|
allDeployments = append(allDeployments, pageResponse.Results...)
|
|
|
|
// Check if there's a next page
|
|
if pageResponse.Next != nil && *pageResponse.Next != "" {
|
|
currentURL = *pageResponse.Next
|
|
} else {
|
|
currentURL = ""
|
|
}
|
|
}
|
|
|
|
// Wrap deployments in response structure
|
|
deploymentsResponse := &ReplicateDeploymentListResponse{
|
|
Results: allDeployments,
|
|
}
|
|
|
|
// Convert deployments to Bifrost response (no public models here)
|
|
response := ToBifrostListModelsResponse(
|
|
deploymentsResponse,
|
|
providerName,
|
|
key.Models,
|
|
key.BlacklistedModels,
|
|
key.Aliases,
|
|
request.Unfiltered,
|
|
)
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// ListModels performs a list models request to Replicate's API.
|
|
func (provider *ReplicateProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.ListModelsRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if provider.networkConfig.BaseURL == "" {
|
|
return nil, providerUtils.NewConfigurationError("base_url is not set")
|
|
}
|
|
|
|
startTime := time.Now()
|
|
|
|
response, err := providerUtils.HandleMultipleListModelsRequests(
|
|
ctx,
|
|
keys,
|
|
request,
|
|
provider.listDeploymentsByKey,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Update metadata with total latency
|
|
latency := time.Since(startTime)
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// TextCompletion performs a text completion request to the replicate API.
|
|
func (provider *ReplicateProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.TextCompletionRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// build replicate request
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateTextRequest(request) })
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Check for Prefer: wait header from context for sync mode
|
|
isSync := parsePreferHeader(provider.networkConfig.ExtraHeaders)
|
|
|
|
// Build prediction URL based on model type (version ID or model name)
|
|
predictionURL := buildPredictionURL(
|
|
ctx,
|
|
provider.networkConfig.BaseURL,
|
|
request.Model,
|
|
provider.customProviderConfig,
|
|
schemas.TextCompletionRequest,
|
|
useDeploymentsEndpoint(key),
|
|
)
|
|
|
|
// create prediction
|
|
prediction, rawResponse, latency, providerResponseHeaders, err := createPrediction(
|
|
ctx,
|
|
provider.client,
|
|
jsonData,
|
|
key,
|
|
providerUtils.GetPathFromContext(ctx, predictionURL),
|
|
provider.networkConfig.ExtraHeaders,
|
|
false,
|
|
provider.logger,
|
|
provider.sendBackRawRequest,
|
|
provider.sendBackRawResponse,
|
|
)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// if not sync, poll until done
|
|
if !isSync && !isTerminalStatus(prediction.Status) {
|
|
prediction, rawResponse, providerResponseHeaders, err = pollPrediction(
|
|
ctx,
|
|
provider.client,
|
|
prediction.URLs.Get,
|
|
key,
|
|
provider.networkConfig.DefaultRequestTimeoutInSeconds,
|
|
provider.logger,
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
}
|
|
|
|
// Check for terminal error status (failed/canceled) after sync mode or polling
|
|
if err := checkForErrorStatus(prediction); err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
if providerResponseHeaders != nil {
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
}
|
|
|
|
// Convert to Bifrost response
|
|
bifrostResponse := prediction.ToBifrostTextCompletionResponse()
|
|
|
|
// Set extra fields
|
|
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
|
|
bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData)
|
|
}
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
bifrostResponse.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return bifrostResponse, nil
|
|
}
|
|
|
|
// TextCompletionStream performs a streaming text completion request to replicate's API.
|
|
// It formats the request, sends it to replicate, and processes the response.
|
|
// Returns a channel of BifrostStream objects or an error if the request fails.
|
|
func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.TextCompletionStreamRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Convert Bifrost request to Replicate format with streaming enabled
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
replicateReq, err := ToReplicateTextRequest(request)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
replicateReq.Stream = schemas.Ptr(true)
|
|
return replicateReq, nil
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Build prediction URL based on model type (version ID or model name)
|
|
predictionURL := buildPredictionURL(
|
|
ctx,
|
|
provider.networkConfig.BaseURL,
|
|
request.Model,
|
|
provider.customProviderConfig,
|
|
schemas.TextCompletionStreamRequest,
|
|
useDeploymentsEndpoint(key),
|
|
)
|
|
|
|
// Create prediction
|
|
prediction, _, _, _, err := createPrediction(
|
|
ctx,
|
|
provider.client,
|
|
jsonData,
|
|
key,
|
|
providerUtils.GetPathFromContext(ctx, predictionURL),
|
|
provider.networkConfig.ExtraHeaders,
|
|
true, // Streaming request, strip Prefer header for async mode
|
|
provider.logger,
|
|
provider.sendBackRawRequest,
|
|
provider.sendBackRawResponse,
|
|
)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// Verify stream URL is available
|
|
if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" {
|
|
bifrostErr := providerUtils.NewBifrostOperationError(
|
|
"stream URL not available in prediction response",
|
|
fmt.Errorf("prediction response missing stream URL"))
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
streamURL := *prediction.URLs.Stream
|
|
|
|
// Connect to stream URL
|
|
_, resp, bifrostErr := listenToReplicateStreamURL(ctx, provider.streamingClient, streamURL, key)
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// Store provider response headers in context for transport layer
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
// Large payload streaming passthrough — pipe raw upstream SSE to client
|
|
if providerUtils.SetupStreamingPassthrough(ctx, resp) {
|
|
responseChan := make(chan *schemas.BifrostStreamChunk)
|
|
close(responseChan)
|
|
return responseChan, nil
|
|
}
|
|
|
|
// Create response channel
|
|
responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
|
|
|
|
providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds)
|
|
|
|
// Start streaming in a goroutine
|
|
go func() {
|
|
defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
|
|
defer func() {
|
|
if ctx.Err() == context.Canceled {
|
|
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
|
|
} else if ctx.Err() == context.DeadlineExceeded {
|
|
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
|
|
}
|
|
close(responseChan)
|
|
}()
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
|
|
// Decompress gzip-encoded streams transparently (no-op for non-gzip)
|
|
reader, releaseGzip := providerUtils.DecompressStreamBody(resp)
|
|
defer releaseGzip()
|
|
|
|
// Wrap reader with idle timeout to detect stalled streams.
|
|
reader, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(reader, resp.BodyStream(), providerUtils.GetStreamIdleTimeout(ctx))
|
|
defer stopIdleTimeout()
|
|
|
|
// Setup cancellation handler to close the raw network stream on ctx cancellation,
|
|
// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
|
|
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger)
|
|
defer stopCancellation()
|
|
|
|
startTime := time.Now()
|
|
lastChunkTime := startTime
|
|
chunkIndex := 0
|
|
|
|
// Setup SSE event reader for event+data format
|
|
sseReader := providerUtils.GetSSEEventReader(ctx, reader)
|
|
|
|
messageID := prediction.ID
|
|
|
|
for {
|
|
// Check for context cancellation
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
eventType, eventDataBytes, readErr := sseReader.ReadEvent()
|
|
if readErr != nil {
|
|
if readErr != io.EOF {
|
|
// If context was cancelled/timed out, let defer handle it
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
provider.logger.Warn("Error reading stream: %v", readErr)
|
|
enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger, postHookSpanFinalizer)
|
|
}
|
|
break
|
|
}
|
|
|
|
eventData := string(eventDataBytes)
|
|
if eventType == "" && eventData == "" {
|
|
continue
|
|
}
|
|
|
|
// Process the complete event
|
|
switch eventType {
|
|
case "output":
|
|
// Accumulate content from data field
|
|
if eventData != "" {
|
|
// Create a streaming chunk with text completion response
|
|
text := eventData
|
|
response := &schemas.BifrostTextCompletionResponse{
|
|
ID: messageID,
|
|
Model: request.Model,
|
|
Object: "text_completion",
|
|
Choices: []schemas.BifrostResponseChoice{
|
|
{
|
|
Index: 0,
|
|
TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
|
|
Text: &text,
|
|
},
|
|
},
|
|
},
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
ChunkIndex: chunkIndex,
|
|
Latency: time.Since(lastChunkTime).Milliseconds(),
|
|
},
|
|
}
|
|
|
|
// Set raw response if enabled (per-chunk event as JSON string)
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
rawEvent := ReplicateSSEEvent{Event: eventType, Data: eventData}
|
|
if eventJSON, err := providerUtils.MarshalSorted(rawEvent); err == nil {
|
|
response.ExtraFields.RawResponse = string(eventJSON)
|
|
}
|
|
}
|
|
|
|
lastChunkTime = time.Now()
|
|
chunkIndex++
|
|
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
|
|
providerUtils.GetBifrostResponseForStreamResponse(response, nil, nil, nil, nil, nil),
|
|
responseChan, postHookSpanFinalizer)
|
|
}
|
|
|
|
case "done":
|
|
// Parse done event data
|
|
var doneData ReplicateDoneEvent
|
|
if eventData != "" && eventData != "{}" {
|
|
if err := sonic.Unmarshal(eventDataBytes, &doneData); err != nil {
|
|
provider.logger.Warn(fmt.Sprintf("Failed to parse done event data: %v", err))
|
|
}
|
|
}
|
|
|
|
// Check for cancellation or error
|
|
switch doneData.Reason {
|
|
case "canceled":
|
|
bifrostErr := providerUtils.NewBifrostOperationError(
|
|
"prediction was canceled",
|
|
fmt.Errorf("stream ended: prediction canceled"))
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger, postHookSpanFinalizer)
|
|
// Explicitly close the body stream to terminate connection to Replicate
|
|
resp.CloseBodyStream()
|
|
return
|
|
|
|
case "error":
|
|
errorMsg := "prediction failed"
|
|
if doneData.Output != nil {
|
|
errorMsg = fmt.Sprintf("prediction failed: %v", doneData.Output)
|
|
}
|
|
bifrostErr := providerUtils.NewBifrostOperationError(
|
|
errorMsg,
|
|
fmt.Errorf("stream ended with error"))
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger, postHookSpanFinalizer)
|
|
// Explicitly close the body stream to terminate connection to Replicate
|
|
resp.CloseBodyStream()
|
|
return
|
|
}
|
|
|
|
// Send final chunk with finish reason
|
|
finishReason := schemas.Ptr("stop")
|
|
finalResponse := providerUtils.CreateBifrostTextCompletionChunkResponse(
|
|
messageID,
|
|
nil, // usage - not available in done event
|
|
finishReason,
|
|
chunkIndex,
|
|
schemas.TextCompletionStreamRequest)
|
|
|
|
// Set raw request if enabled
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
providerUtils.ParseAndSetRawRequest(&finalResponse.ExtraFields, jsonData)
|
|
}
|
|
|
|
finalResponse.ExtraFields.Latency = time.Since(startTime).Milliseconds()
|
|
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
|
|
providerUtils.GetBifrostResponseForStreamResponse(finalResponse, nil, nil, nil, nil, nil),
|
|
responseChan, postHookSpanFinalizer)
|
|
resp.CloseBodyStream()
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
return responseChan, nil
|
|
}
|
|
|
|
// ChatCompletion performs a chat completion request to the replicate API.
|
|
func (provider *ReplicateProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// build replicate request
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateChatRequest(request) })
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Check for Prefer: wait header from context for sync mode
|
|
isSync := parsePreferHeader(provider.networkConfig.ExtraHeaders)
|
|
|
|
// Build prediction URL based on model type (version ID or model name)
|
|
predictionURL := buildPredictionURL(
|
|
ctx,
|
|
provider.networkConfig.BaseURL,
|
|
request.Model,
|
|
provider.customProviderConfig,
|
|
schemas.ChatCompletionRequest,
|
|
useDeploymentsEndpoint(key),
|
|
)
|
|
|
|
// create prediction
|
|
prediction, rawResponse, latency, providerResponseHeaders, err := createPrediction(
|
|
ctx,
|
|
provider.client,
|
|
jsonData,
|
|
key,
|
|
providerUtils.GetPathFromContext(ctx, predictionURL),
|
|
provider.networkConfig.ExtraHeaders,
|
|
false,
|
|
provider.logger,
|
|
provider.sendBackRawRequest,
|
|
provider.sendBackRawResponse,
|
|
)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// if not sync, poll until done
|
|
if !isSync && !isTerminalStatus(prediction.Status) {
|
|
prediction, rawResponse, providerResponseHeaders, err = pollPrediction(
|
|
ctx,
|
|
provider.client,
|
|
prediction.URLs.Get,
|
|
key,
|
|
provider.networkConfig.DefaultRequestTimeoutInSeconds,
|
|
provider.logger,
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
}
|
|
|
|
// Check for terminal error status (failed/canceled) after sync mode or polling
|
|
if err := checkForErrorStatus(prediction); err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
if providerResponseHeaders != nil {
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
}
|
|
|
|
// Convert to Bifrost response
|
|
bifrostResponse := prediction.ToBifrostChatResponse()
|
|
|
|
// Set extra fields
|
|
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
|
|
bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData)
|
|
}
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
bifrostResponse.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return bifrostResponse, nil
|
|
}
|
|
|
|
// ChatCompletionStream performs a streaming chat completion request to the replicate API.
|
|
// It supports real-time streaming of responses using Server-Sent Events (SSE).
|
|
// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails.
|
|
func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Convert Bifrost request to Replicate format with streaming enabled
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
replicateReq, err := ToReplicateChatRequest(request)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
replicateReq.Stream = schemas.Ptr(true)
|
|
return replicateReq, nil
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Build prediction URL based on model type (version ID or model name)
|
|
predictionURL := buildPredictionURL(
|
|
ctx,
|
|
provider.networkConfig.BaseURL,
|
|
request.Model,
|
|
provider.customProviderConfig,
|
|
schemas.ChatCompletionStreamRequest,
|
|
useDeploymentsEndpoint(key),
|
|
)
|
|
|
|
// Create prediction
|
|
prediction, _, _, _, err := createPrediction(
|
|
ctx,
|
|
provider.client,
|
|
jsonData,
|
|
key,
|
|
providerUtils.GetPathFromContext(ctx, predictionURL),
|
|
provider.networkConfig.ExtraHeaders,
|
|
true,
|
|
provider.logger,
|
|
provider.sendBackRawRequest,
|
|
provider.sendBackRawResponse,
|
|
)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// Verify stream URL is available
|
|
if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" {
|
|
bifrostErr := providerUtils.NewBifrostOperationError(
|
|
"stream URL not available in prediction response",
|
|
fmt.Errorf("prediction response missing stream URL"))
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
streamURL := *prediction.URLs.Stream
|
|
|
|
// Connect to stream URL
|
|
_, resp, bifrostErr := listenToReplicateStreamURL(ctx, provider.streamingClient, streamURL, key)
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// Store provider response headers in context for transport layer
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
// Large payload streaming passthrough — pipe raw upstream SSE to client
|
|
if providerUtils.SetupStreamingPassthrough(ctx, resp) {
|
|
responseChan := make(chan *schemas.BifrostStreamChunk)
|
|
close(responseChan)
|
|
return responseChan, nil
|
|
}
|
|
|
|
// Create response channel
|
|
responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
|
|
|
|
providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds)
|
|
|
|
// Start streaming in a goroutine
|
|
go func() {
|
|
defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
|
|
defer func() {
|
|
if ctx.Err() == context.Canceled {
|
|
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
|
|
} else if ctx.Err() == context.DeadlineExceeded {
|
|
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
|
|
}
|
|
close(responseChan)
|
|
}()
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
|
|
// Decompress gzip-encoded streams transparently (no-op for non-gzip)
|
|
reader, releaseGzip := providerUtils.DecompressStreamBody(resp)
|
|
defer releaseGzip()
|
|
|
|
// Wrap reader with idle timeout to detect stalled streams.
|
|
reader, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(reader, resp.BodyStream(), providerUtils.GetStreamIdleTimeout(ctx))
|
|
defer stopIdleTimeout()
|
|
|
|
// Setup cancellation handler to close the raw network stream on ctx cancellation,
|
|
// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
|
|
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger)
|
|
defer stopCancellation()
|
|
|
|
startTime := time.Now()
|
|
lastChunkTime := startTime
|
|
chunkIndex := 0
|
|
|
|
// Setup SSE event reader for event+data format
|
|
sseReader := providerUtils.GetSSEEventReader(ctx, reader)
|
|
|
|
messageID := prediction.ID
|
|
|
|
for {
|
|
// Check for context cancellation
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
eventType, eventDataBytes, readErr := sseReader.ReadEvent()
|
|
if readErr != nil {
|
|
if readErr != io.EOF {
|
|
// If context was cancelled/timed out, let defer handle it
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
provider.logger.Warn("Error reading stream: %v", readErr)
|
|
enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger, postHookSpanFinalizer)
|
|
}
|
|
break
|
|
}
|
|
|
|
eventData := string(eventDataBytes)
|
|
if eventType == "" && eventData == "" {
|
|
continue
|
|
}
|
|
|
|
// Process the complete event
|
|
switch eventType {
|
|
case "output":
|
|
// Accumulate content from data field
|
|
if eventData != "" {
|
|
// Create a streaming chunk
|
|
content := eventData
|
|
role := string(schemas.ChatMessageRoleAssistant)
|
|
delta := &schemas.ChatStreamResponseChoiceDelta{
|
|
Content: &content,
|
|
Role: &role,
|
|
}
|
|
|
|
response := &schemas.BifrostChatResponse{
|
|
ID: messageID,
|
|
Model: request.Model,
|
|
Object: "chat.completion.chunk",
|
|
Created: int(time.Now().Unix()),
|
|
Choices: []schemas.BifrostResponseChoice{
|
|
{
|
|
Index: 0,
|
|
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
|
Delta: delta,
|
|
},
|
|
},
|
|
},
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
ChunkIndex: chunkIndex,
|
|
Latency: time.Since(lastChunkTime).Milliseconds(),
|
|
},
|
|
}
|
|
|
|
// Set raw response if enabled (per-chunk event as JSON string)
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
rawEvent := ReplicateSSEEvent{Event: eventType, Data: eventData}
|
|
if eventJSON, err := providerUtils.MarshalSorted(rawEvent); err == nil {
|
|
response.ExtraFields.RawResponse = string(eventJSON)
|
|
}
|
|
}
|
|
|
|
lastChunkTime = time.Now()
|
|
chunkIndex++
|
|
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
|
|
providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil),
|
|
responseChan, postHookSpanFinalizer)
|
|
}
|
|
|
|
case "done":
|
|
// Parse done event data
|
|
var doneData ReplicateDoneEvent
|
|
if eventData != "" && eventData != "{}" {
|
|
if err := sonic.Unmarshal(eventDataBytes, &doneData); err != nil {
|
|
provider.logger.Warn(fmt.Sprintf("Failed to parse done event data: %v", err))
|
|
}
|
|
}
|
|
|
|
// Check for cancellation or error
|
|
switch doneData.Reason {
|
|
case "canceled":
|
|
bifrostErr := providerUtils.NewBifrostOperationError(
|
|
"prediction was canceled",
|
|
fmt.Errorf("stream ended: prediction canceled"))
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger, postHookSpanFinalizer)
|
|
// Explicitly close the body stream to terminate connection to Replicate
|
|
resp.CloseBodyStream()
|
|
return
|
|
|
|
case "error":
|
|
errorMsg := "prediction failed"
|
|
if doneData.Output != nil {
|
|
errorMsg = fmt.Sprintf("prediction failed: %v", doneData.Output)
|
|
}
|
|
bifrostErr := providerUtils.NewBifrostOperationError(
|
|
errorMsg,
|
|
fmt.Errorf("stream ended with error"))
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger, postHookSpanFinalizer)
|
|
// Explicitly close the body stream to terminate connection to Replicate
|
|
resp.CloseBodyStream()
|
|
return
|
|
}
|
|
|
|
// Send final chunk with finish reason
|
|
finishReason := "stop"
|
|
finalResponse := &schemas.BifrostChatResponse{
|
|
ID: messageID,
|
|
Model: request.Model,
|
|
Object: "chat.completion.chunk",
|
|
Created: int(time.Now().Unix()),
|
|
Choices: []schemas.BifrostResponseChoice{
|
|
{
|
|
Index: 0,
|
|
FinishReason: &finishReason,
|
|
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
|
Delta: &schemas.ChatStreamResponseChoiceDelta{},
|
|
},
|
|
},
|
|
},
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
ChunkIndex: chunkIndex,
|
|
Latency: time.Since(startTime).Milliseconds(),
|
|
},
|
|
}
|
|
|
|
// Set raw request if enabled
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
providerUtils.ParseAndSetRawRequest(&finalResponse.ExtraFields, jsonData)
|
|
}
|
|
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
|
|
providerUtils.GetBifrostResponseForStreamResponse(nil, finalResponse, nil, nil, nil, nil),
|
|
responseChan, postHookSpanFinalizer)
|
|
resp.CloseBodyStream()
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
return responseChan, nil
|
|
}
|
|
|
|
// Responses performs a responses request to the replicate API.
|
|
func (provider *ReplicateProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.ResponsesRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// build replicate request
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateResponsesRequest(request) })
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Check for Prefer: wait header from context for sync mode
|
|
isSync := parsePreferHeader(provider.networkConfig.ExtraHeaders)
|
|
|
|
// Build prediction URL based on model type (version ID or model name)
|
|
predictionURL := buildPredictionURL(
|
|
ctx,
|
|
provider.networkConfig.BaseURL,
|
|
request.Model,
|
|
provider.customProviderConfig,
|
|
schemas.ResponsesRequest,
|
|
useDeploymentsEndpoint(key),
|
|
)
|
|
|
|
// create prediction
|
|
prediction, rawResponse, latency, providerResponseHeaders, err := createPrediction(
|
|
ctx,
|
|
provider.client,
|
|
jsonData,
|
|
key,
|
|
providerUtils.GetPathFromContext(ctx, predictionURL),
|
|
provider.networkConfig.ExtraHeaders,
|
|
false,
|
|
provider.logger,
|
|
provider.sendBackRawRequest,
|
|
provider.sendBackRawResponse,
|
|
)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// if not sync, poll until done
|
|
if !isSync && !isTerminalStatus(prediction.Status) {
|
|
prediction, rawResponse, providerResponseHeaders, err = pollPrediction(
|
|
ctx,
|
|
provider.client,
|
|
prediction.URLs.Get,
|
|
key,
|
|
provider.networkConfig.DefaultRequestTimeoutInSeconds,
|
|
provider.logger,
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
}
|
|
|
|
// Check for terminal error status (failed/canceled) after sync mode or polling
|
|
if err := checkForErrorStatus(prediction); err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
if providerResponseHeaders != nil {
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
}
|
|
|
|
// Convert to Bifrost response
|
|
response := prediction.ToBifrostResponsesResponse()
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonData)
|
|
}
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
return response, nil
|
|
}
|
|
|
|
// ResponsesStream performs a streaming responses request to the replicate API.
|
|
func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Build replicate request
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateResponsesRequest(request) })
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Enable streaming (using sjson to set field directly, preserving key order)
|
|
if updatedData, err := providerUtils.SetJSONField(jsonData, "stream", true); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to set stream field", err)
|
|
} else {
|
|
jsonData = updatedData
|
|
}
|
|
|
|
// Build prediction URL
|
|
predictionURL := buildPredictionURL(
|
|
ctx,
|
|
provider.networkConfig.BaseURL,
|
|
request.Model,
|
|
provider.customProviderConfig,
|
|
schemas.ResponsesStreamRequest,
|
|
useDeploymentsEndpoint(key),
|
|
)
|
|
|
|
// Create prediction
|
|
prediction, _, _, _, err := createPrediction(
|
|
ctx,
|
|
provider.client,
|
|
jsonData,
|
|
key,
|
|
providerUtils.GetPathFromContext(ctx, predictionURL),
|
|
provider.networkConfig.ExtraHeaders,
|
|
true, // Streaming request, strip Prefer header for async mode
|
|
provider.logger,
|
|
provider.sendBackRawRequest,
|
|
provider.sendBackRawResponse,
|
|
)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// Verify stream URL is available
|
|
if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" {
|
|
bifrostErr := providerUtils.NewBifrostOperationError(
|
|
"stream URL not available in prediction response",
|
|
fmt.Errorf("prediction response missing stream URL"))
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
streamURL := *prediction.URLs.Stream
|
|
|
|
// Setup request for streaming
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
resp.StreamBody = true
|
|
defer fasthttp.ReleaseRequest(req)
|
|
|
|
req.Header.SetMethod(http.MethodGet)
|
|
req.SetRequestURI(streamURL)
|
|
req.Header.Set("Accept", "text/event-stream")
|
|
|
|
// Set authorization
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Set extra headers
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
// Make the streaming request
|
|
streamErr := provider.streamingClient.Do(req, resp)
|
|
if streamErr != nil {
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
if errors.Is(streamErr, context.Canceled) {
|
|
return nil, providerUtils.EnrichError(ctx, &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{
|
|
Type: schemas.Ptr(schemas.RequestCancelled),
|
|
Message: schemas.ErrRequestCancelled,
|
|
Error: streamErr,
|
|
},
|
|
}, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
if errors.Is(streamErr, fasthttp.ErrTimeout) || errors.Is(streamErr, context.DeadlineExceeded) {
|
|
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, streamErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, streamErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// Extract provider response headers before status check so error responses also forward them
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
// Check for HTTP errors
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
body := resp.Body()
|
|
return nil, providerUtils.EnrichError(ctx, parseReplicateError(body, resp.StatusCode()), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// Large payload streaming passthrough — pipe raw upstream SSE to client
|
|
if providerUtils.SetupStreamingPassthrough(ctx, resp) {
|
|
responseChan := make(chan *schemas.BifrostStreamChunk)
|
|
close(responseChan)
|
|
return responseChan, nil
|
|
}
|
|
|
|
// Create response channel
|
|
responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
|
|
|
|
providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds)
|
|
|
|
// Start streaming in a goroutine
|
|
go func() {
|
|
// Registered first so the post-hook span finalizer runs on every exit
|
|
// path — including the empty-reader early return below, which would
|
|
// otherwise skip any finalizer declared later in this goroutine.
|
|
defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
|
|
defer func() {
|
|
if ctx.Err() == context.Canceled {
|
|
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
|
|
} else if ctx.Err() == context.DeadlineExceeded {
|
|
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
|
|
}
|
|
close(responseChan)
|
|
}()
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
|
|
// Decompress gzip-encoded streams transparently (no-op for non-gzip)
|
|
reader, releaseGzip := providerUtils.DecompressStreamBody(resp)
|
|
defer releaseGzip()
|
|
|
|
if reader == nil {
|
|
bifrostErr := providerUtils.NewBifrostOperationError(
|
|
"provider returned an empty response",
|
|
fmt.Errorf("provider returned an empty response"))
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse), responseChan, provider.logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
|
|
// Wrap reader with idle timeout to detect stalled streams.
|
|
reader, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(reader, resp.BodyStream(), providerUtils.GetStreamIdleTimeout(ctx))
|
|
defer stopIdleTimeout()
|
|
|
|
// Setup cancellation handler to close the raw network stream on ctx cancellation,
|
|
// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
|
|
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger)
|
|
defer stopCancellation()
|
|
|
|
sseReader := providerUtils.GetSSEEventReader(ctx, reader)
|
|
startTime := time.Now()
|
|
sequenceNumber := 0
|
|
messageID := prediction.ID
|
|
// Generate a unique item ID for the message (needed for accumulator to track deltas)
|
|
itemID := "msg_" + messageID
|
|
|
|
// Track lifecycle state
|
|
var hasEmittedCreated, hasEmittedInProgress bool
|
|
var hasEmittedOutputItemAdded, hasEmittedContentPartAdded bool
|
|
var hasReceivedContent bool
|
|
outputIndex := 0
|
|
contentIndex := 0
|
|
|
|
// Accumulate raw responses for debugging
|
|
var rawResponseChunks []interface{}
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
|
|
for {
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
|
|
eventType, eventDataBytes, readErr := sseReader.ReadEvent()
|
|
if readErr != nil {
|
|
if readErr != io.EOF {
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
provider.logger.Warn("Error reading stream: %v", readErr)
|
|
bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr)
|
|
|
|
// Include accumulated raw responses in error
|
|
if sendBackRawResponse && len(rawResponseChunks) > 0 {
|
|
bifrostErr.ExtraFields.RawResponse = rawResponseChunks
|
|
}
|
|
|
|
enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
break
|
|
}
|
|
|
|
currentEvent := ReplicateSSEEvent{
|
|
Event: eventType,
|
|
Data: string(eventDataBytes),
|
|
}
|
|
|
|
if currentEvent.Event != "" {
|
|
// Process the event
|
|
switch currentEvent.Event {
|
|
case "output":
|
|
// Text chunk received
|
|
if currentEvent.Data != "" {
|
|
// Accumulate raw response if enabled
|
|
if sendBackRawResponse {
|
|
rawResponseChunks = append(rawResponseChunks, currentEvent)
|
|
}
|
|
|
|
// Emit lifecycle events on first content
|
|
if !hasEmittedCreated {
|
|
// response.created
|
|
createdResp := &schemas.BifrostResponsesStreamResponse{
|
|
Type: schemas.ResponsesStreamResponseTypeCreated,
|
|
SequenceNumber: sequenceNumber,
|
|
Response: &schemas.BifrostResponsesResponse{
|
|
ID: schemas.Ptr(messageID),
|
|
Model: request.Model,
|
|
CreatedAt: int(startTime.Unix()),
|
|
},
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: time.Since(startTime).Milliseconds(),
|
|
ChunkIndex: sequenceNumber,
|
|
},
|
|
}
|
|
if sendBackRawRequest {
|
|
providerUtils.ParseAndSetRawRequest(&createdResp.ExtraFields, jsonData)
|
|
}
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
|
|
providerUtils.GetBifrostResponseForStreamResponse(nil, nil, createdResp, nil, nil, nil),
|
|
responseChan, postHookSpanFinalizer)
|
|
sequenceNumber++
|
|
hasEmittedCreated = true
|
|
}
|
|
|
|
if !hasEmittedInProgress {
|
|
// response.in_progress
|
|
inProgressResp := &schemas.BifrostResponsesStreamResponse{
|
|
Type: schemas.ResponsesStreamResponseTypeInProgress,
|
|
SequenceNumber: sequenceNumber,
|
|
Response: &schemas.BifrostResponsesResponse{
|
|
ID: schemas.Ptr(messageID),
|
|
CreatedAt: int(startTime.Unix()),
|
|
},
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
ChunkIndex: sequenceNumber,
|
|
},
|
|
}
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
|
|
providerUtils.GetBifrostResponseForStreamResponse(nil, nil, inProgressResp, nil, nil, nil),
|
|
responseChan, postHookSpanFinalizer)
|
|
sequenceNumber++
|
|
hasEmittedInProgress = true
|
|
}
|
|
|
|
if !hasEmittedOutputItemAdded {
|
|
// response.output_item.added
|
|
messageType := schemas.ResponsesMessageTypeMessage
|
|
role := schemas.ResponsesInputMessageRoleAssistant
|
|
status := "in_progress"
|
|
itemAddedResp := &schemas.BifrostResponsesStreamResponse{
|
|
Type: schemas.ResponsesStreamResponseTypeOutputItemAdded,
|
|
SequenceNumber: sequenceNumber,
|
|
OutputIndex: schemas.Ptr(outputIndex),
|
|
Item: &schemas.ResponsesMessage{
|
|
ID: schemas.Ptr(itemID),
|
|
Type: &messageType,
|
|
Role: &role,
|
|
Status: &status,
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentBlocks: []schemas.ResponsesMessageContentBlock{},
|
|
},
|
|
},
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
ChunkIndex: sequenceNumber,
|
|
},
|
|
}
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
|
|
providerUtils.GetBifrostResponseForStreamResponse(nil, nil, itemAddedResp, nil, nil, nil),
|
|
responseChan, postHookSpanFinalizer)
|
|
sequenceNumber++
|
|
hasEmittedOutputItemAdded = true
|
|
}
|
|
|
|
if !hasEmittedContentPartAdded {
|
|
// response.content_part.added
|
|
emptyText := ""
|
|
partAddedResp := &schemas.BifrostResponsesStreamResponse{
|
|
Type: schemas.ResponsesStreamResponseTypeContentPartAdded,
|
|
SequenceNumber: sequenceNumber,
|
|
OutputIndex: schemas.Ptr(outputIndex),
|
|
ContentIndex: schemas.Ptr(contentIndex),
|
|
ItemID: schemas.Ptr(itemID),
|
|
Part: &schemas.ResponsesMessageContentBlock{
|
|
Type: schemas.ResponsesOutputMessageContentTypeText,
|
|
Text: &emptyText,
|
|
ResponsesOutputMessageContentText: &schemas.ResponsesOutputMessageContentText{
|
|
Annotations: []schemas.ResponsesOutputMessageContentTextAnnotation{},
|
|
LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{},
|
|
},
|
|
},
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
ChunkIndex: sequenceNumber,
|
|
},
|
|
}
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
|
|
providerUtils.GetBifrostResponseForStreamResponse(nil, nil, partAddedResp, nil, nil, nil),
|
|
responseChan, postHookSpanFinalizer)
|
|
sequenceNumber++
|
|
hasEmittedContentPartAdded = true
|
|
}
|
|
|
|
// response.output_text.delta
|
|
deltaResp := &schemas.BifrostResponsesStreamResponse{
|
|
Type: schemas.ResponsesStreamResponseTypeOutputTextDelta,
|
|
SequenceNumber: sequenceNumber,
|
|
OutputIndex: schemas.Ptr(outputIndex),
|
|
ContentIndex: schemas.Ptr(contentIndex),
|
|
ItemID: schemas.Ptr(itemID),
|
|
Delta: schemas.Ptr(currentEvent.Data),
|
|
LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{},
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
ChunkIndex: sequenceNumber,
|
|
},
|
|
}
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
|
|
providerUtils.GetBifrostResponseForStreamResponse(nil, nil, deltaResp, nil, nil, nil),
|
|
responseChan, postHookSpanFinalizer)
|
|
sequenceNumber++
|
|
hasReceivedContent = true
|
|
}
|
|
case "done":
|
|
// Accumulate done event in raw responses if enabled
|
|
if sendBackRawResponse {
|
|
rawResponseChunks = append(rawResponseChunks, currentEvent)
|
|
}
|
|
|
|
// Stream completed
|
|
if hasReceivedContent {
|
|
// response.output_text.done
|
|
textDoneResp := &schemas.BifrostResponsesStreamResponse{
|
|
Type: schemas.ResponsesStreamResponseTypeOutputTextDone,
|
|
SequenceNumber: sequenceNumber,
|
|
OutputIndex: schemas.Ptr(outputIndex),
|
|
ContentIndex: schemas.Ptr(contentIndex),
|
|
ItemID: schemas.Ptr(itemID),
|
|
LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{},
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
ChunkIndex: sequenceNumber,
|
|
},
|
|
}
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
|
|
providerUtils.GetBifrostResponseForStreamResponse(nil, nil, textDoneResp, nil, nil, nil),
|
|
responseChan, postHookSpanFinalizer)
|
|
sequenceNumber++
|
|
|
|
// response.content_part.done
|
|
partDoneResp := &schemas.BifrostResponsesStreamResponse{
|
|
Type: schemas.ResponsesStreamResponseTypeContentPartDone,
|
|
SequenceNumber: sequenceNumber,
|
|
OutputIndex: schemas.Ptr(outputIndex),
|
|
ContentIndex: schemas.Ptr(contentIndex),
|
|
ItemID: schemas.Ptr(itemID),
|
|
Part: &schemas.ResponsesMessageContentBlock{
|
|
Type: schemas.ResponsesOutputMessageContentTypeText,
|
|
ResponsesOutputMessageContentText: &schemas.ResponsesOutputMessageContentText{
|
|
Annotations: []schemas.ResponsesOutputMessageContentTextAnnotation{},
|
|
LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{},
|
|
},
|
|
},
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
ChunkIndex: sequenceNumber,
|
|
},
|
|
}
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
|
|
providerUtils.GetBifrostResponseForStreamResponse(nil, nil, partDoneResp, nil, nil, nil),
|
|
responseChan, postHookSpanFinalizer)
|
|
sequenceNumber++
|
|
|
|
// response.output_item.done
|
|
messageType := schemas.ResponsesMessageTypeMessage
|
|
role := schemas.ResponsesInputMessageRoleAssistant
|
|
status := "completed"
|
|
itemDoneResp := &schemas.BifrostResponsesStreamResponse{
|
|
Type: schemas.ResponsesStreamResponseTypeOutputItemDone,
|
|
SequenceNumber: sequenceNumber,
|
|
OutputIndex: schemas.Ptr(outputIndex),
|
|
Item: &schemas.ResponsesMessage{
|
|
ID: schemas.Ptr(itemID),
|
|
Type: &messageType,
|
|
Role: &role,
|
|
Status: &status,
|
|
Content: &schemas.ResponsesMessageContent{
|
|
ContentBlocks: []schemas.ResponsesMessageContentBlock{
|
|
{
|
|
Type: schemas.ResponsesOutputMessageContentTypeText,
|
|
ResponsesOutputMessageContentText: &schemas.ResponsesOutputMessageContentText{
|
|
Annotations: []schemas.ResponsesOutputMessageContentTextAnnotation{},
|
|
LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
ChunkIndex: sequenceNumber,
|
|
},
|
|
}
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
|
|
providerUtils.GetBifrostResponseForStreamResponse(nil, nil, itemDoneResp, nil, nil, nil),
|
|
responseChan, postHookSpanFinalizer)
|
|
sequenceNumber++
|
|
}
|
|
|
|
// response.completed
|
|
completedResp := &schemas.BifrostResponsesStreamResponse{
|
|
Type: schemas.ResponsesStreamResponseTypeCompleted,
|
|
SequenceNumber: sequenceNumber,
|
|
Response: &schemas.BifrostResponsesResponse{
|
|
ID: schemas.Ptr(messageID),
|
|
Model: request.Model,
|
|
CreatedAt: int(startTime.Unix()),
|
|
CompletedAt: schemas.Ptr(int(time.Now().Unix())),
|
|
},
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: time.Since(startTime).Milliseconds(),
|
|
ChunkIndex: sequenceNumber,
|
|
},
|
|
}
|
|
|
|
// Set raw request if enabled (on final chunk only)
|
|
if sendBackRawRequest {
|
|
providerUtils.ParseAndSetRawRequest(&completedResp.ExtraFields, jsonData)
|
|
}
|
|
|
|
// Set raw response if enabled
|
|
if sendBackRawResponse && len(rawResponseChunks) > 0 {
|
|
completedResp.ExtraFields.RawResponse = rawResponseChunks
|
|
}
|
|
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
|
|
providerUtils.GetBifrostResponseForStreamResponse(nil, nil, completedResp, nil, nil, nil),
|
|
responseChan, postHookSpanFinalizer)
|
|
resp.CloseBodyStream()
|
|
return
|
|
case "error":
|
|
// Accumulate error event in raw responses if enabled
|
|
if sendBackRawResponse {
|
|
rawResponseChunks = append(rawResponseChunks, currentEvent)
|
|
}
|
|
|
|
// Handle error
|
|
errorMsg := "stream error"
|
|
if currentEvent.Data != "" {
|
|
errorMsg = currentEvent.Data
|
|
}
|
|
bifrostErr := providerUtils.NewBifrostOperationError(
|
|
errorMsg,
|
|
fmt.Errorf("stream error: %s", errorMsg))
|
|
|
|
// Include accumulated raw responses in error
|
|
if sendBackRawResponse && len(rawResponseChunks) > 0 {
|
|
bifrostErr.ExtraFields.RawResponse = rawResponseChunks
|
|
}
|
|
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger, postHookSpanFinalizer)
|
|
resp.CloseBodyStream()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
return responseChan, nil
|
|
}
|
|
|
|
// Embedding is not supported by the replicate provider.
|
|
func (provider *ReplicateProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// Speech is not supported by the replicate provider.
|
|
func (provider *ReplicateProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// Rerank is not supported by the Replicate provider.
|
|
func (provider *ReplicateProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.RerankRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// OCR is not supported by the Replicate provider.
|
|
func (provider *ReplicateProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.OCRRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// SpeechStream is not supported by the replicate provider.
|
|
func (provider *ReplicateProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// Transcription is not supported by the replicate provider.
|
|
func (provider *ReplicateProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// TranscriptionStream is not supported by the replicate provider.
|
|
func (provider *ReplicateProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ImageGeneration performs an image generation request to the replicate API using predictions.
|
|
func (provider *ReplicateProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.ImageGenerationRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Convert Bifrost request to Replicate format
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
return ToReplicateImageGenerationInput(request), nil
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Check for Prefer: wait header from context for sync mode
|
|
isSync := parsePreferHeader(provider.networkConfig.ExtraHeaders)
|
|
|
|
// Build prediction URL based on model type (version ID or model name)
|
|
predictionURL := buildPredictionURL(
|
|
ctx,
|
|
provider.networkConfig.BaseURL,
|
|
request.Model,
|
|
provider.customProviderConfig,
|
|
schemas.ImageGenerationRequest,
|
|
useDeploymentsEndpoint(key),
|
|
)
|
|
|
|
// Create prediction with appropriate mode
|
|
prediction, rawResponse, latency, providerResponseHeaders, err := createPrediction(
|
|
ctx,
|
|
provider.client,
|
|
jsonData,
|
|
key,
|
|
providerUtils.GetPathFromContext(ctx, predictionURL),
|
|
provider.networkConfig.ExtraHeaders,
|
|
false,
|
|
provider.logger,
|
|
provider.sendBackRawRequest,
|
|
provider.sendBackRawResponse,
|
|
)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// If async mode and not complete, poll until done
|
|
if !isSync && !isTerminalStatus(prediction.Status) {
|
|
prediction, rawResponse, providerResponseHeaders, err = pollPrediction(
|
|
ctx,
|
|
provider.client,
|
|
prediction.URLs.Get,
|
|
key,
|
|
provider.networkConfig.DefaultRequestTimeoutInSeconds,
|
|
provider.logger,
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
}
|
|
|
|
// Check for terminal error status (failed/canceled) after sync mode or polling
|
|
if err := checkForErrorStatus(prediction); err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
if providerResponseHeaders != nil {
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
}
|
|
|
|
// Convert to Bifrost response
|
|
bifrostResponse, err := ToBifrostImageGenerationResponse(prediction)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// Set extra fields
|
|
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
|
|
bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData)
|
|
}
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
bifrostResponse.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return bifrostResponse, nil
|
|
}
|
|
|
|
// ImageGenerationStream performs a streaming image generation request to the replicate API.
|
|
// It creates a prediction with streaming enabled and listens to the stream URL for progressive updates.
|
|
func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.ImageGenerationStreamRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
|
|
// Convert Bifrost request to Replicate format with streaming enabled
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
replicateReq := ToReplicateImageGenerationInput(request)
|
|
replicateReq.Stream = schemas.Ptr(true)
|
|
return replicateReq, nil
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Build prediction URL based on model type (version ID or model name)
|
|
predictionURL := buildPredictionURL(
|
|
ctx,
|
|
provider.networkConfig.BaseURL,
|
|
request.Model,
|
|
provider.customProviderConfig,
|
|
schemas.ImageGenerationStreamRequest,
|
|
useDeploymentsEndpoint(key),
|
|
)
|
|
// Create prediction
|
|
prediction, _, _, _, err := createPrediction(
|
|
ctx,
|
|
provider.client,
|
|
jsonData,
|
|
key,
|
|
providerUtils.GetPathFromContext(ctx, predictionURL),
|
|
provider.networkConfig.ExtraHeaders,
|
|
true, // Streaming request, strip Prefer header for async mode
|
|
provider.logger,
|
|
provider.sendBackRawRequest,
|
|
provider.sendBackRawResponse,
|
|
)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// Verify stream URL is available
|
|
if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" {
|
|
return nil, providerUtils.EnrichError(
|
|
ctx,
|
|
providerUtils.NewBifrostOperationError(
|
|
"stream URL not available in prediction response",
|
|
fmt.Errorf("prediction response missing stream URL"),
|
|
),
|
|
jsonData,
|
|
nil,
|
|
sendBackRawRequest,
|
|
sendBackRawResponse,
|
|
)
|
|
}
|
|
|
|
streamURL := *prediction.URLs.Stream
|
|
|
|
// Connect to stream URL
|
|
_, resp, bifrostErr := listenToReplicateStreamURL(ctx, provider.streamingClient, streamURL, key)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Store provider response headers in context for transport layer
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
// Large payload streaming passthrough — pipe raw upstream SSE to client
|
|
if providerUtils.SetupStreamingPassthrough(ctx, resp) {
|
|
responseChan := make(chan *schemas.BifrostStreamChunk)
|
|
close(responseChan)
|
|
return responseChan, nil
|
|
}
|
|
|
|
// Create response channel
|
|
responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
|
|
|
|
providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds)
|
|
|
|
// Start streaming in a goroutine
|
|
go func() {
|
|
defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
|
|
defer func() {
|
|
if ctx.Err() == context.Canceled {
|
|
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
|
|
} else if ctx.Err() == context.DeadlineExceeded {
|
|
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
|
|
}
|
|
close(responseChan)
|
|
}()
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
|
|
// Decompress gzip-encoded streams transparently (no-op for non-gzip)
|
|
reader, releaseGzip := providerUtils.DecompressStreamBody(resp)
|
|
defer releaseGzip()
|
|
|
|
// Wrap reader with idle timeout to detect stalled streams.
|
|
reader, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(reader, resp.BodyStream(), providerUtils.GetStreamIdleTimeout(ctx))
|
|
defer stopIdleTimeout()
|
|
|
|
// Setup cancellation handler to close the raw network stream on ctx cancellation,
|
|
// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
|
|
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger)
|
|
defer stopCancellation()
|
|
|
|
startTime := time.Now()
|
|
lastChunkTime := startTime
|
|
chunkIndex := 0
|
|
|
|
// Setup SSE event reader for event+data format
|
|
sseReader := providerUtils.GetSSEEventReader(ctx, reader)
|
|
|
|
// Track last image data for final chunk
|
|
var lastB64Data string
|
|
var lastOutputFormat string
|
|
// Accumulate all raw response chunks for complete stream history
|
|
var rawResponseChunks []interface{}
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
eventType, eventDataBytes, readErr := sseReader.ReadEvent()
|
|
if readErr != nil {
|
|
if readErr != io.EOF {
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
provider.logger.Warn(fmt.Sprintf("Error reading SSE stream: %v", readErr))
|
|
enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger, postHookSpanFinalizer)
|
|
}
|
|
break
|
|
}
|
|
|
|
eventData := string(eventDataBytes)
|
|
if eventType == "" && eventData == "" {
|
|
continue
|
|
}
|
|
|
|
// Process the complete event
|
|
switch eventType {
|
|
case "output":
|
|
// Check if data is a data URI (image) or plain text
|
|
var b64Data, outputFormat string
|
|
if strings.HasPrefix(eventData, "data:") {
|
|
// Parse image data from data URI
|
|
var mimeType string
|
|
b64Data, mimeType = parseDataURIImage(eventData)
|
|
|
|
// Extract output format from MIME type
|
|
if mimeType != "" {
|
|
// Convert "image/webp" to "webp"
|
|
parts := strings.Split(mimeType, "/")
|
|
if len(parts) == 2 {
|
|
outputFormat = parts[1]
|
|
}
|
|
}
|
|
} else {
|
|
// For non-data-URI output (e.g., text), store as-is
|
|
// This shouldn't happen for image generation but handle it gracefully
|
|
provider.logger.Debug(fmt.Sprintf("Received non-data-URI output: %s", eventData[:min(100, len(eventData))]))
|
|
// Skip non-image output for image generation
|
|
continue
|
|
}
|
|
|
|
// Create chunk
|
|
chunk := &schemas.BifrostImageGenerationStreamResponse{
|
|
Type: schemas.ImageGenerationEventTypePartial,
|
|
Index: 0, // Single image for now
|
|
ChunkIndex: chunkIndex,
|
|
B64JSON: b64Data,
|
|
CreatedAt: time.Now().Unix(),
|
|
OutputFormat: outputFormat,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
ChunkIndex: chunkIndex,
|
|
Latency: time.Since(lastChunkTime).Milliseconds(),
|
|
},
|
|
}
|
|
|
|
// Accumulate raw response chunks if enabled
|
|
if sendBackRawResponse {
|
|
rawResponseChunks = append(rawResponseChunks, ReplicateSSEEvent{Event: eventType, Data: eventData})
|
|
}
|
|
|
|
// Track last image data for final chunk
|
|
lastB64Data = b64Data
|
|
lastOutputFormat = outputFormat
|
|
|
|
lastChunkTime = time.Now()
|
|
chunkIndex++
|
|
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
|
|
providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, nil, chunk),
|
|
responseChan, postHookSpanFinalizer)
|
|
|
|
case "done":
|
|
// Parse done event data
|
|
var doneData ReplicateDoneEvent
|
|
if eventData != "" && eventData != "{}" {
|
|
if err := sonic.Unmarshal(eventDataBytes, &doneData); err != nil {
|
|
provider.logger.Warn(fmt.Sprintf("Failed to parse done event data: %v", err))
|
|
}
|
|
}
|
|
|
|
// Check for cancellation or error
|
|
switch doneData.Reason {
|
|
case "canceled":
|
|
bifrostErr := providerUtils.NewBifrostOperationError(
|
|
"prediction was canceled",
|
|
fmt.Errorf("stream ended: prediction canceled"))
|
|
// Include accumulated raw responses in error
|
|
if sendBackRawResponse && len(rawResponseChunks) > 0 {
|
|
bifrostErr.ExtraFields.RawResponse = rawResponseChunks
|
|
}
|
|
bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer)
|
|
return
|
|
case "error":
|
|
bifrostErr := providerUtils.NewBifrostOperationError(
|
|
"prediction failed",
|
|
fmt.Errorf("stream ended with error"))
|
|
// Include accumulated raw responses in error
|
|
if sendBackRawResponse && len(rawResponseChunks) > 0 {
|
|
bifrostErr.ExtraFields.RawResponse = rawResponseChunks
|
|
}
|
|
bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
|
|
// Send completion chunk (success case when reason is empty or not present)
|
|
finalChunk := &schemas.BifrostImageGenerationStreamResponse{
|
|
Type: schemas.ImageGenerationEventTypeCompleted,
|
|
Index: 0,
|
|
ChunkIndex: chunkIndex,
|
|
B64JSON: lastB64Data, // Include last image data
|
|
OutputFormat: lastOutputFormat, // Include output format
|
|
CreatedAt: time.Now().Unix(),
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
ChunkIndex: chunkIndex,
|
|
Latency: time.Since(startTime).Milliseconds(),
|
|
},
|
|
}
|
|
|
|
// Set raw request only on final chunk if enabled
|
|
if sendBackRawRequest {
|
|
providerUtils.ParseAndSetRawRequest(&finalChunk.ExtraFields, jsonData)
|
|
}
|
|
|
|
// Set accumulated raw responses on final chunk if enabled
|
|
if sendBackRawResponse {
|
|
// Append the final done event to the accumulated chunks
|
|
rawResponseChunks = append(rawResponseChunks, ReplicateSSEEvent{Event: eventType, Data: eventData})
|
|
finalChunk.ExtraFields.RawResponse = rawResponseChunks
|
|
}
|
|
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
|
|
providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, nil, finalChunk),
|
|
responseChan, postHookSpanFinalizer)
|
|
return
|
|
|
|
case "error":
|
|
// Parse error event data
|
|
var errorData ReplicateErrorEvent
|
|
errorMsg := "stream error"
|
|
|
|
if eventData != "" {
|
|
if err := sonic.Unmarshal(eventDataBytes, &errorData); err != nil {
|
|
provider.logger.Warn(fmt.Sprintf("Failed to parse error event data: %v", err))
|
|
// Fallback to raw data
|
|
errorMsg = eventData
|
|
} else if errorData.Detail != "" {
|
|
errorMsg = errorData.Detail
|
|
}
|
|
}
|
|
|
|
bifrostErr := &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{
|
|
Message: errorMsg,
|
|
},
|
|
}
|
|
// Include accumulated raw responses in error
|
|
if sendBackRawResponse {
|
|
rawResponseChunks = append(rawResponseChunks, ReplicateSSEEvent{Event: eventType, Data: eventData})
|
|
bifrostErr.ExtraFields.RawResponse = rawResponseChunks
|
|
}
|
|
bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
return responseChan, nil
|
|
}
|
|
|
|
// ImageEdit is not supported by the Replicate provider.
|
|
func (provider *ReplicateProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.ImageEditRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Convert Bifrost request to Replicate format
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
return ToReplicateImageEditInput(request), nil
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Check for Prefer: wait header from context for sync mode
|
|
isSync := parsePreferHeader(provider.networkConfig.ExtraHeaders)
|
|
|
|
// Build prediction URL based on model type (version ID or model name)
|
|
predictionURL := buildPredictionURL(
|
|
ctx,
|
|
provider.networkConfig.BaseURL,
|
|
request.Model,
|
|
provider.customProviderConfig,
|
|
schemas.ImageEditRequest,
|
|
useDeploymentsEndpoint(key),
|
|
)
|
|
|
|
// Create prediction with appropriate mode
|
|
prediction, rawResponse, latency, providerResponseHeaders, err := createPrediction(
|
|
ctx,
|
|
provider.client,
|
|
jsonData,
|
|
key,
|
|
providerUtils.GetPathFromContext(ctx, predictionURL),
|
|
provider.networkConfig.ExtraHeaders,
|
|
false,
|
|
provider.logger,
|
|
provider.sendBackRawRequest,
|
|
provider.sendBackRawResponse,
|
|
)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// If async mode and not complete, poll until done
|
|
if !isSync && !isTerminalStatus(prediction.Status) {
|
|
prediction, rawResponse, providerResponseHeaders, err = pollPrediction(
|
|
ctx,
|
|
provider.client,
|
|
prediction.URLs.Get,
|
|
key,
|
|
provider.networkConfig.DefaultRequestTimeoutInSeconds,
|
|
provider.logger,
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
}
|
|
|
|
// Check for terminal error status (failed/canceled) after sync mode or polling
|
|
if err := checkForErrorStatus(prediction); err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
if providerResponseHeaders != nil {
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
}
|
|
|
|
// Convert to Bifrost response (reuse image generation response format)
|
|
bifrostResponse, err := ToBifrostImageGenerationResponse(prediction)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// Set extra fields
|
|
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
|
|
bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData)
|
|
}
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
bifrostResponse.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return bifrostResponse, nil
|
|
}
|
|
|
|
// ImageEditStream performs a streaming image edit request to the replicate API.
|
|
// It creates a prediction with streaming enabled and listens to the stream URL for progressive updates.
|
|
func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.ImageEditStreamRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
|
|
// Convert Bifrost request to Replicate format with streaming enabled
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
replicateReq := ToReplicateImageEditInput(request)
|
|
replicateReq.Stream = schemas.Ptr(true)
|
|
return replicateReq, nil
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Build prediction URL based on model type (version ID or model name)
|
|
predictionURL := buildPredictionURL(
|
|
ctx,
|
|
provider.networkConfig.BaseURL,
|
|
request.Model,
|
|
provider.customProviderConfig,
|
|
schemas.ImageEditStreamRequest,
|
|
useDeploymentsEndpoint(key),
|
|
)
|
|
|
|
// Create prediction
|
|
prediction, _, _, _, err := createPrediction(
|
|
ctx,
|
|
provider.client,
|
|
jsonData,
|
|
key,
|
|
providerUtils.GetPathFromContext(ctx, predictionURL),
|
|
provider.networkConfig.ExtraHeaders,
|
|
true, // Streaming request, strip Prefer header for async mode
|
|
provider.logger,
|
|
provider.sendBackRawRequest,
|
|
provider.sendBackRawResponse,
|
|
)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// Verify stream URL is available
|
|
if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" {
|
|
return nil, providerUtils.EnrichError(
|
|
ctx,
|
|
providerUtils.NewBifrostOperationError(
|
|
"stream URL not available in prediction response",
|
|
fmt.Errorf("prediction response missing stream URL"),
|
|
),
|
|
jsonData,
|
|
nil,
|
|
sendBackRawRequest,
|
|
sendBackRawResponse,
|
|
)
|
|
}
|
|
|
|
streamURL := *prediction.URLs.Stream
|
|
|
|
// Connect to stream URL
|
|
_, resp, bifrostErr := listenToReplicateStreamURL(ctx, provider.streamingClient, streamURL, key)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Store provider response headers in context for transport layer
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
// Large payload streaming passthrough — pipe raw upstream SSE to client
|
|
if providerUtils.SetupStreamingPassthrough(ctx, resp) {
|
|
responseChan := make(chan *schemas.BifrostStreamChunk)
|
|
close(responseChan)
|
|
return responseChan, nil
|
|
}
|
|
|
|
// Create response channel
|
|
responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
|
|
|
|
providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds)
|
|
|
|
// Start streaming in a goroutine
|
|
go func() {
|
|
defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
|
|
defer func() {
|
|
if ctx.Err() == context.Canceled {
|
|
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
|
|
} else if ctx.Err() == context.DeadlineExceeded {
|
|
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
|
|
}
|
|
close(responseChan)
|
|
}()
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
|
|
// Decompress gzip-encoded streams transparently (no-op for non-gzip)
|
|
reader, releaseGzip := providerUtils.DecompressStreamBody(resp)
|
|
defer releaseGzip()
|
|
|
|
// Wrap reader with idle timeout to detect stalled streams.
|
|
reader, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(reader, resp.BodyStream(), providerUtils.GetStreamIdleTimeout(ctx))
|
|
defer stopIdleTimeout()
|
|
|
|
// Setup cancellation handler to close the raw network stream on ctx cancellation,
|
|
// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
|
|
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger)
|
|
defer stopCancellation()
|
|
|
|
startTime := time.Now()
|
|
lastChunkTime := startTime
|
|
chunkIndex := 0
|
|
|
|
// Setup SSE event reader for event+data format
|
|
sseReader := providerUtils.GetSSEEventReader(ctx, reader)
|
|
|
|
// Track last image data for final chunk
|
|
var lastB64Data string
|
|
var lastOutputFormat string
|
|
// Accumulate all raw response chunks for complete stream history
|
|
var rawResponseChunks []interface{}
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
eventType, eventDataBytes, readErr := sseReader.ReadEvent()
|
|
if readErr != nil {
|
|
if readErr != io.EOF {
|
|
if errors.Is(readErr, context.Canceled) {
|
|
return
|
|
}
|
|
enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("stream read error", readErr), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger, postHookSpanFinalizer)
|
|
}
|
|
break
|
|
}
|
|
|
|
eventData := string(eventDataBytes)
|
|
if eventType == "" && eventData == "" {
|
|
continue
|
|
}
|
|
|
|
// Process the complete event
|
|
switch eventType {
|
|
case "output":
|
|
// Check if data is a data URI (image) or plain text
|
|
var b64Data, outputFormat string
|
|
if strings.HasPrefix(eventData, "data:") {
|
|
// Parse image data from data URI
|
|
var mimeType string
|
|
b64Data, mimeType = parseDataURIImage(eventData)
|
|
|
|
// Extract output format from MIME type
|
|
if mimeType != "" {
|
|
// Convert "image/webp" to "webp"
|
|
parts := strings.Split(mimeType, "/")
|
|
if len(parts) == 2 {
|
|
outputFormat = parts[1]
|
|
}
|
|
}
|
|
} else {
|
|
// For non-data-URI output, skip for image edit
|
|
provider.logger.Debug(fmt.Sprintf("Received non-data-URI output: %s", eventData[:min(100, len(eventData))]))
|
|
continue
|
|
}
|
|
|
|
// Create chunk (use ImageEditEventTypePartial)
|
|
chunk := &schemas.BifrostImageGenerationStreamResponse{
|
|
Type: schemas.ImageEditEventTypePartial,
|
|
Index: 0,
|
|
ChunkIndex: chunkIndex,
|
|
B64JSON: b64Data,
|
|
CreatedAt: time.Now().Unix(),
|
|
OutputFormat: outputFormat,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
ChunkIndex: chunkIndex,
|
|
Latency: time.Since(lastChunkTime).Milliseconds(),
|
|
},
|
|
}
|
|
|
|
// Accumulate raw response chunks if enabled
|
|
if sendBackRawResponse {
|
|
rawResponseChunks = append(rawResponseChunks, ReplicateSSEEvent{Event: eventType, Data: eventData})
|
|
}
|
|
|
|
// Track last image data for final chunk
|
|
lastB64Data = b64Data
|
|
lastOutputFormat = outputFormat
|
|
|
|
lastChunkTime = time.Now()
|
|
chunkIndex++
|
|
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
|
|
providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, nil, chunk),
|
|
responseChan, postHookSpanFinalizer)
|
|
|
|
case "done":
|
|
// Parse done event data
|
|
var doneData ReplicateDoneEvent
|
|
if eventData != "" && eventData != "{}" {
|
|
if err := sonic.Unmarshal(eventDataBytes, &doneData); err != nil {
|
|
provider.logger.Warn(fmt.Sprintf("Failed to parse done event data: %v", err))
|
|
}
|
|
}
|
|
|
|
// Check for cancellation or error
|
|
switch doneData.Reason {
|
|
case "canceled":
|
|
bifrostErr := providerUtils.NewBifrostOperationError(
|
|
"prediction was canceled",
|
|
fmt.Errorf("stream ended: prediction canceled"))
|
|
if sendBackRawResponse && len(rawResponseChunks) > 0 {
|
|
bifrostErr.ExtraFields.RawResponse = rawResponseChunks
|
|
}
|
|
bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer)
|
|
return
|
|
case "error":
|
|
bifrostErr := providerUtils.NewBifrostOperationError(
|
|
"prediction failed",
|
|
fmt.Errorf("stream ended with error"))
|
|
if sendBackRawResponse && len(rawResponseChunks) > 0 {
|
|
bifrostErr.ExtraFields.RawResponse = rawResponseChunks
|
|
}
|
|
bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
|
|
// Send completion chunk (success case)
|
|
finalChunk := &schemas.BifrostImageGenerationStreamResponse{
|
|
Type: schemas.ImageEditEventTypeCompleted,
|
|
Index: 0,
|
|
ChunkIndex: chunkIndex,
|
|
B64JSON: lastB64Data,
|
|
CreatedAt: time.Now().Unix(),
|
|
OutputFormat: lastOutputFormat,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
ChunkIndex: chunkIndex,
|
|
Latency: time.Since(startTime).Milliseconds(),
|
|
},
|
|
}
|
|
|
|
if sendBackRawRequest {
|
|
providerUtils.ParseAndSetRawRequest(&finalChunk.ExtraFields, jsonData)
|
|
}
|
|
if sendBackRawResponse {
|
|
rawResponseChunks = append(rawResponseChunks, ReplicateSSEEvent{Event: eventType, Data: eventData})
|
|
finalChunk.ExtraFields.RawResponse = rawResponseChunks
|
|
}
|
|
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
|
|
providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, nil, finalChunk),
|
|
responseChan, postHookSpanFinalizer)
|
|
return
|
|
|
|
case "error":
|
|
// Parse error event
|
|
var errorData ReplicateErrorEvent
|
|
if err := sonic.Unmarshal(eventDataBytes, &errorData); err != nil {
|
|
provider.logger.Warn(fmt.Sprintf("Failed to parse error event: %v", err))
|
|
errorData.Detail = eventData
|
|
}
|
|
|
|
bifrostErr := providerUtils.NewBifrostOperationError(
|
|
"stream error",
|
|
fmt.Errorf("%s", errorData.Detail))
|
|
if sendBackRawResponse {
|
|
rawResponseChunks = append(rawResponseChunks, ReplicateSSEEvent{Event: eventType, Data: eventData})
|
|
bifrostErr.ExtraFields.RawResponse = rawResponseChunks
|
|
}
|
|
bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
return responseChan, nil
|
|
}
|
|
|
|
// ImageVariation is not supported by the Replicate provider.
|
|
func (provider *ReplicateProvider) ImageVariation(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageVariationRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// VideoGeneration performs a video generation request to Replicate's API.
|
|
func (provider *ReplicateProvider) VideoGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.VideoGenerationRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Convert Bifrost request to Replicate format
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
return ToReplicateVideoGenerationInput(request)
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Create prediction asynchronously and return job ID without polling.
|
|
predictionURL := buildPredictionURL(
|
|
ctx,
|
|
provider.networkConfig.BaseURL,
|
|
request.Model,
|
|
provider.customProviderConfig,
|
|
schemas.VideoGenerationRequest,
|
|
useDeploymentsEndpoint(key),
|
|
)
|
|
|
|
// Create prediction with appropriate mode
|
|
prediction, rawResponse, latency, providerResponseHeaders, err := createPrediction(
|
|
ctx,
|
|
provider.client,
|
|
jsonData,
|
|
key,
|
|
providerUtils.GetPathFromContext(ctx, predictionURL),
|
|
provider.networkConfig.ExtraHeaders,
|
|
false,
|
|
provider.logger,
|
|
provider.sendBackRawRequest,
|
|
provider.sendBackRawResponse,
|
|
)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
if providerResponseHeaders != nil {
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
}
|
|
|
|
// Convert to Bifrost response
|
|
bifrostResponse, err := ToBifrostVideoGenerationResponse(prediction)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
bifrostResponse.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResponse.ID, schemas.Replicate)
|
|
|
|
// Set extra fields
|
|
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
|
|
bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData)
|
|
}
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
bifrostResponse.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return bifrostResponse, nil
|
|
}
|
|
|
|
// VideoRetrieve fetches the status/output of a Replicate video generation job.
|
|
func (provider *ReplicateProvider) VideoRetrieve(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.VideoRetrieveRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
providerName := provider.GetProviderKey()
|
|
if request.ID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
|
|
}
|
|
|
|
videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName)
|
|
// Build URL to fetch the prediction by ID.
|
|
predictionURL := provider.buildRequestURL(ctx, "/v1/predictions/"+videoID, schemas.VideoRetrieveRequest)
|
|
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
req.SetRequestURI(predictionURL)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
req.Header.SetContentType("application/json")
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
return nil, providerUtils.EnrichError(
|
|
ctx,
|
|
parseReplicateError(resp.Body(), resp.StatusCode()),
|
|
nil,
|
|
nil,
|
|
provider.sendBackRawRequest,
|
|
provider.sendBackRawResponse,
|
|
)
|
|
}
|
|
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
}
|
|
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
var prediction ReplicatePredictionResponse
|
|
_, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &prediction, nil, false, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
bifrostResponse, convertErr := ToBifrostVideoGenerationResponse(&prediction)
|
|
if convertErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, convertErr, nil, body, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
bifrostResponse.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResponse.ID, providerName)
|
|
|
|
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
|
|
bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
if sendBackRawResponse {
|
|
bifrostResponse.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return bifrostResponse, nil
|
|
}
|
|
|
|
// VideoDownload is not supported by the Replicate provider.
|
|
func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.VideoDownloadRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
if request.ID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
|
|
}
|
|
// Retrieve latest status/output first.
|
|
bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{
|
|
Provider: request.Provider,
|
|
ID: request.ID,
|
|
}
|
|
videoResp, bifrostErr := provider.VideoRetrieve(ctx, key, bifrostVideoRetrieveRequest)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
if videoResp.Status != schemas.VideoStatusCompleted {
|
|
return nil, providerUtils.NewBifrostOperationError(
|
|
fmt.Sprintf("video not ready, current status: %s", videoResp.Status),
|
|
nil)
|
|
}
|
|
if len(videoResp.Videos) == 0 {
|
|
return nil, providerUtils.NewBifrostOperationError("video URL not available", nil)
|
|
}
|
|
var videoUrl string
|
|
if videoResp.Videos[0].URL != nil {
|
|
videoUrl = *videoResp.Videos[0].URL
|
|
}
|
|
if videoUrl == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil)
|
|
}
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
req.SetRequestURI(videoUrl)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
// Some output URLs are signed and don't need auth, but keep auth if present.
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
return nil, providerUtils.NewBifrostOperationError(
|
|
fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()),
|
|
nil)
|
|
}
|
|
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
}
|
|
contentType := string(resp.Header.ContentType())
|
|
if contentType == "" {
|
|
contentType = "video/mp4"
|
|
}
|
|
content := append([]byte(nil), body...)
|
|
bifrostResp := &schemas.BifrostVideoDownloadResponse{
|
|
VideoID: request.ID,
|
|
Content: content,
|
|
ContentType: contentType,
|
|
}
|
|
|
|
bifrostResp.ExtraFields.Latency = latency.Milliseconds()
|
|
bifrostResp.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
|
|
return bifrostResp, nil
|
|
}
|
|
|
|
// VideoDelete is not supported by replicate provider.
|
|
func (provider *ReplicateProvider) VideoDelete(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDeleteRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// VideoList is not supported by replicate provider.
|
|
func (provider *ReplicateProvider) VideoList(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoListRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// VideoRemix is not supported by replicate provider.
|
|
func (provider *ReplicateProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// BatchCreate is not supported by replicate provider.
|
|
func (provider *ReplicateProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// BatchList is not supported by replicate provider.
|
|
func (provider *ReplicateProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// BatchRetrieve is not supported by replicate provider.
|
|
func (provider *ReplicateProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// BatchCancel is not supported by replicate provider.
|
|
func (provider *ReplicateProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// BatchDelete is not supported by replicate provider.
|
|
func (provider *ReplicateProvider) BatchDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchDeleteRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// BatchResults is not supported by replicate provider.
|
|
func (provider *ReplicateProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// FileUpload uploads a file to Replicate's Files API.
|
|
func (provider *ReplicateProvider) FileUpload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
|
|
providerName := provider.GetProviderKey()
|
|
|
|
if len(request.File) == 0 {
|
|
return nil, providerUtils.NewBifrostOperationError("file content is required", nil)
|
|
}
|
|
|
|
// Create multipart form data
|
|
var buf bytes.Buffer
|
|
writer := multipart.NewWriter(&buf)
|
|
|
|
// Add file field (content)
|
|
filename := request.Filename
|
|
if filename == "" {
|
|
filename = "file"
|
|
}
|
|
|
|
// Determine content type - use from request or infer from filename
|
|
contentType := "application/octet-stream"
|
|
if request.ContentType != nil && *request.ContentType != "" {
|
|
contentType = *request.ContentType
|
|
} else {
|
|
// Try to infer from filename extension
|
|
if strings.HasSuffix(filename, ".json") {
|
|
contentType = "application/json"
|
|
} else if strings.HasSuffix(filename, ".jsonl") {
|
|
contentType = "application/x-ndjson"
|
|
} else if strings.HasSuffix(filename, ".txt") {
|
|
contentType = "text/plain"
|
|
} else if strings.HasSuffix(filename, ".zip") {
|
|
contentType = "application/zip"
|
|
}
|
|
}
|
|
|
|
// Add filename field if provided
|
|
if filename != "" {
|
|
if err := writer.WriteField("filename", filename); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to write filename field", err)
|
|
}
|
|
}
|
|
|
|
// Add type field (content type)
|
|
if err := writer.WriteField("type", contentType); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to write type field", err)
|
|
}
|
|
|
|
// Add metadata field if provided
|
|
if request.ExtraParams != nil {
|
|
if metadata, ok := request.ExtraParams["metadata"].(map[string]interface{}); ok {
|
|
if len(metadata) > 0 {
|
|
metadataJSON, err := providerUtils.MarshalSorted(metadata)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to marshal metadata", err)
|
|
}
|
|
h := make(textproto.MIMEHeader)
|
|
h.Set("Content-Disposition", `form-data; name="metadata"`)
|
|
h.Set("Content-Type", "application/json")
|
|
metadataPart, err := writer.CreatePart(h)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to create metadata part", err)
|
|
}
|
|
if _, err := metadataPart.Write(metadataJSON); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to write metadata", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Create form file with proper headers
|
|
h := make(textproto.MIMEHeader)
|
|
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="content"; filename="%s"`, filename))
|
|
h.Set("Content-Type", contentType)
|
|
|
|
part, err := writer.CreatePart(h)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to create form file", err)
|
|
}
|
|
if _, err := part.Write(request.File); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to write file content", err)
|
|
}
|
|
|
|
if err := writer.Close(); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err)
|
|
}
|
|
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
// Set headers
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/files")
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType(writer.FormDataContentType())
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
req.SetBody(buf.Bytes())
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated {
|
|
provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
|
|
return nil, parseReplicateError(resp.Body(), resp.StatusCode())
|
|
}
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
}
|
|
|
|
var replicateResp ReplicateFileResponse
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &replicateResp, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
fileResponse := replicateResp.ToBifrostFileUploadResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
fileResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
return fileResponse, nil
|
|
}
|
|
|
|
// FileList lists files using serial pagination across keys.
|
|
// Exhausts all pages from one key before moving to the next.
|
|
func (provider *ReplicateProvider) FileList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
|
|
providerName := provider.GetProviderKey()
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
|
|
// Initialize serial pagination helper (Replicate uses cursor-based pagination)
|
|
helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err)
|
|
}
|
|
|
|
// Get current key to query
|
|
key, nativeCursor, ok := helper.GetCurrentKey()
|
|
if !ok {
|
|
// All keys exhausted
|
|
return &schemas.BifrostFileListResponse{
|
|
Object: "list",
|
|
Data: []schemas.FileObject{},
|
|
HasMore: false,
|
|
}, nil
|
|
}
|
|
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
// Build URL with query params
|
|
requestURL := provider.networkConfig.BaseURL + "/v1/files"
|
|
values := url.Values{}
|
|
if request.Limit > 0 {
|
|
values.Set("limit", fmt.Sprintf("%d", request.Limit))
|
|
}
|
|
// Use native cursor from serial helper (Replicate pagination URL)
|
|
if nativeCursor != "" {
|
|
// For Replicate, the cursor is actually the full next URL
|
|
requestURL = nativeCursor
|
|
} else if encodedValues := values.Encode(); encodedValues != "" {
|
|
requestURL += "?" + encodedValues
|
|
}
|
|
|
|
// Set headers
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
req.SetRequestURI(requestURL)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
|
|
return nil, parseReplicateError(resp.Body(), resp.StatusCode())
|
|
}
|
|
|
|
body, decodeErr := providerUtils.CheckAndDecodeBody(resp)
|
|
if decodeErr != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr)
|
|
}
|
|
|
|
var replicateResp ReplicateFileListResponse
|
|
_, _, bifrostErr = providerUtils.HandleProviderResponse(body, &replicateResp, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Convert files to Bifrost format
|
|
files := make([]schemas.FileObject, 0, len(replicateResp.Results))
|
|
for _, file := range replicateResp.Results {
|
|
files = append(files, schemas.FileObject{
|
|
ID: file.ID,
|
|
Object: "file",
|
|
Bytes: file.Size,
|
|
CreatedAt: ParseReplicateTimestamp(file.CreatedAt),
|
|
Filename: file.Name,
|
|
Purpose: schemas.FilePurposeBatch,
|
|
Status: ToBifrostFileStatus(&file),
|
|
})
|
|
}
|
|
|
|
// Build cursor for next request
|
|
// Replicate uses full URL for pagination
|
|
var nextCursor string
|
|
hasMore := false
|
|
if replicateResp.Next != nil && *replicateResp.Next != "" {
|
|
nextCursor = *replicateResp.Next
|
|
hasMore = true
|
|
}
|
|
|
|
// Use helper to build proper cursor with key index
|
|
finalCursor, finalHasMore := helper.BuildNextCursor(hasMore, nextCursor)
|
|
|
|
// Convert to Bifrost response
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
bifrostResp := &schemas.BifrostFileListResponse{
|
|
Object: "list",
|
|
Data: files,
|
|
HasMore: finalHasMore,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerResponseHeaders,
|
|
},
|
|
}
|
|
if finalCursor != "" {
|
|
bifrostResp.After = &finalCursor
|
|
}
|
|
|
|
return bifrostResp, nil
|
|
}
|
|
|
|
// FileRetrieve retrieves file metadata from Replicate's Files API by trying each key until found.
|
|
func (provider *ReplicateProvider) FileRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
|
|
providerName := provider.GetProviderKey()
|
|
|
|
if request.FileID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("file_id is required", nil)
|
|
}
|
|
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
|
|
var lastErr *schemas.BifrostError
|
|
for _, key := range keys {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
|
|
// Set headers
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/files/" + url.PathEscape(request.FileID))
|
|
req.Header.SetMethod(http.MethodGet)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
wait()
|
|
if bifrostErr != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = bifrostErr
|
|
continue
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
|
|
lastErr = parseReplicateError(resp.Body(), resp.StatusCode())
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
continue
|
|
}
|
|
|
|
var replicateResp ReplicateFileResponse
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &replicateResp, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = bifrostErr
|
|
continue
|
|
}
|
|
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
fileRetrieveResponse := replicateResp.ToBifrostFileRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)
|
|
fileRetrieveResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
return fileRetrieveResponse, nil
|
|
}
|
|
|
|
return nil, lastErr
|
|
}
|
|
|
|
// FileDelete deletes a file from Replicate's Files API by trying each key until successful.
|
|
func (provider *ReplicateProvider) FileDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
|
|
providerName := provider.GetProviderKey()
|
|
|
|
if request.FileID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("file_id is required", nil)
|
|
}
|
|
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
|
|
var lastErr *schemas.BifrostError
|
|
for _, key := range keys {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
|
|
// Set headers
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/files/" + url.PathEscape(request.FileID))
|
|
req.Header.SetMethod(http.MethodDelete)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
wait()
|
|
if bifrostErr != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = bifrostErr
|
|
continue
|
|
}
|
|
|
|
// Handle success response (204 No Content is expected for DELETE)
|
|
if resp.StatusCode() == fasthttp.StatusNoContent {
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
return &schemas.BifrostFileDeleteResponse{
|
|
ID: request.FileID,
|
|
Object: "file",
|
|
Deleted: true,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerResponseHeaders,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
|
|
lastErr = parseReplicateError(resp.Body(), resp.StatusCode())
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
|
|
// Some APIs return 200 with body, parse it
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
continue
|
|
}
|
|
|
|
// Try to parse response body if present
|
|
var deleteResp map[string]interface{}
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &deleteResp, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = bifrostErr
|
|
continue
|
|
}
|
|
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
result := &schemas.BifrostFileDeleteResponse{
|
|
ID: request.FileID,
|
|
Object: "file",
|
|
Deleted: true,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerResponseHeaders,
|
|
},
|
|
}
|
|
|
|
if sendBackRawRequest {
|
|
result.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
if sendBackRawResponse {
|
|
result.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
return nil, lastErr
|
|
}
|
|
|
|
// FileContent is not supported by replicate provider.
|
|
func (provider *ReplicateProvider) FileContent(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
func (provider *ReplicateProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ContainerCreate is not supported by replicate provider.
|
|
func (provider *ReplicateProvider) ContainerCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerCreateRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ContainerList is not supported by replicate provider.
|
|
func (provider *ReplicateProvider) ContainerList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerListRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ContainerRetrieve is not supported by replicate provider.
|
|
func (provider *ReplicateProvider) ContainerRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerRetrieveRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ContainerDelete is not supported by replicate provider.
|
|
func (provider *ReplicateProvider) ContainerDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerDeleteRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ContainerFileCreate is not supported by replicate provider.
|
|
func (provider *ReplicateProvider) ContainerFileCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileCreateRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ContainerFileList is not supported by replicate provider.
|
|
func (provider *ReplicateProvider) ContainerFileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileListRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ContainerFileRetrieve is not supported by replicate provider.
|
|
func (provider *ReplicateProvider) ContainerFileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileRetrieveRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ContainerFileContent is not supported by replicate provider.
|
|
func (provider *ReplicateProvider) ContainerFileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileContentRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ContainerFileDelete is not supported by replicate provider.
|
|
func (provider *ReplicateProvider) ContainerFileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileDeleteRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// Passthrough is not supported by the Replicate provider.
|
|
func (provider *ReplicateProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
func (provider *ReplicateProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey())
|
|
}
|