first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,280 @@
package replicate
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// isTerminalStatus reports whether status is one of the three end states
// recognised here: succeeded, failed, or canceled. Any other status yields
// false.
func isTerminalStatus(status ReplicatePredictionStatus) bool {
	switch status {
	case ReplicatePredictionStatusSucceeded,
		ReplicatePredictionStatusFailed,
		ReplicatePredictionStatusCanceled:
		return true
	default:
		return false
	}
}
// checkForErrorStatus converts a failed or canceled prediction into a
// BifrostError and returns nil for every other status.
//
// For a failed prediction the wrapped error carries the prediction's own
// error text when it is present and non-empty; otherwise it falls back to the
// generic "prediction failed" message.
func checkForErrorStatus(prediction *ReplicatePredictionResponse) *schemas.BifrostError {
	switch prediction.Status {
	case ReplicatePredictionStatusFailed:
		reason := "prediction failed"
		if detail := prediction.Error; detail != nil && *detail != "" {
			reason = *detail
		}
		return providerUtils.NewBifrostOperationError(
			"prediction failed",
			fmt.Errorf("%s", reason))
	case ReplicatePredictionStatusCanceled:
		return providerUtils.NewBifrostOperationError(
			"prediction was canceled",
			fmt.Errorf("prediction was canceled"))
	}
	return nil
}
// parsePreferHeader reports whether sync mode was requested through the
// Prefer header.
//
// It returns true only when extraHeaders contains the exact key "Prefer" and
// that value starts with "wait" (for example "wait", "wait=30", "wait=60").
// The key lookup is case-sensitive, and the wait duration itself is not
// extracted. A nil map, a missing key, or any other value yields false.
func parsePreferHeader(extraHeaders map[string]string) bool {
	preferValue, present := extraHeaders["Prefer"]
	return present && strings.HasPrefix(preferValue, "wait")
}
// stripPreferHeader returns extraHeaders without its "Prefer" entry.
//
// Streaming requests should always be async and must not wait for completion,
// so the Prefer header (which enables sync mode) has to be left out.
//
// When the map is nil or has no "Prefer" key it is returned unchanged (the
// same map, not a copy). Otherwise a new map is built and the input is left
// unmodified.
func stripPreferHeader(extraHeaders map[string]string) map[string]string {
	// A lookup on a nil map reports "absent", so nil input is returned as nil here.
	if _, hasPrefer := extraHeaders["Prefer"]; !hasPrefer {
		return extraHeaders
	}

	remaining := make(map[string]string, len(extraHeaders)-1)
	for name, val := range extraHeaders {
		if name == "Prefer" {
			continue
		}
		remaining[name] = val
	}
	return remaining
}
// listenToReplicateStreamURL opens a GET connection to a Replicate stream URL,
// asking for a server-sent-events response ("Accept: text/event-stream").
// It is a reusable utility for any Replicate streaming endpoint. The function
// does not read or parse events itself; it hands back the body stream.
//
// Parameters:
//   - ctx: when non-nil, receives the provider response headers under
//     schemas.BifrostContextKeyProviderResponseHeaders (set for error statuses too).
//   - client: fasthttp client used to perform the request.
//   - streamURL: URL to connect to.
//   - key: when its value is non-empty it is sent as a Bearer Authorization header.
//
// Returns the body stream, the response it belongs to, and a nil error on an
// HTTP 200. On any failure both the reader and the response are nil and the
// response has already been released inside this function.
// NOTE(review): on success the response is returned without being released;
// presumably the caller must release it after consuming the stream — confirm.
func listenToReplicateStreamURL(
	ctx *schemas.BifrostContext,
	client *fasthttp.Client,
	streamURL string,
	key schemas.Key,
) (io.Reader, *fasthttp.Response, *schemas.BifrostError) {
	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	// Stream the body instead of buffering it fully.
	resp.StreamBody = true

	req.SetRequestURI(streamURL)
	req.Header.SetMethod(http.MethodGet)
	req.Header.Set("Accept", "text/event-stream")
	req.Header.Set("Cache-Control", "no-cache")
	if token := key.Value.GetValue(); token != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}

	// The request object is no longer needed once Do returns, whatever the outcome.
	doErr := client.Do(req, resp)
	fasthttp.ReleaseRequest(req)

	if doErr != nil {
		providerUtils.ReleaseStreamingResponse(resp)
		switch {
		case errors.Is(doErr, context.Canceled):
			return nil, nil, &schemas.BifrostError{
				IsBifrostError: false,
				Error: &schemas.ErrorField{
					Type:    schemas.Ptr(schemas.RequestCancelled),
					Message: schemas.ErrRequestCancelled,
					Error:   doErr,
				},
			}
		case errors.Is(doErr, fasthttp.ErrTimeout) || errors.Is(doErr, context.DeadlineExceeded):
			return nil, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, doErr)
		default:
			return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, doErr)
		}
	}

	// Record provider headers before the status check so error responses forward them too.
	if ctx != nil {
		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
	}

	if resp.StatusCode() != fasthttp.StatusOK {
		// Deferred so the body is read for the error before the response is released.
		defer providerUtils.ReleaseStreamingResponse(resp)
		return nil, nil, parseReplicateError(resp.Body(), resp.StatusCode())
	}

	return resp.BodyStream(), resp, nil
}
// parseDataURIImage splits a data URI into its payload and MIME type.
//
// Example: "data:image/webp;base64,UklGRmSu..." -> ("UklGRmSu...", "image/webp")
//
// The payload is everything after the first comma; the MIME type is the part
// of the metadata before the first ';' (or the whole metadata when there is no
// ';'). The function does not check that the metadata declares base64, nor
// does it decode the payload.
//
// If the input does not start with "data:" or contains no comma after the
// prefix, it is returned unchanged with an empty MIME type.
func parseDataURIImage(dataURI string) (base64Data string, mimeType string) {
	const scheme = "data:"
	if !strings.HasPrefix(dataURI, scheme) {
		return dataURI, ""
	}

	// Layout after the scheme: <metadata>,<payload>
	rest := dataURI[len(scheme):]
	comma := strings.IndexByte(rest, ',')
	if comma < 0 {
		return dataURI, ""
	}
	header, payload := rest[:comma], rest[comma+1:]

	// Metadata looks like "image/webp;base64"; keep only the leading media type.
	if semi := strings.IndexByte(header, ';'); semi >= 0 {
		header = header[:semi]
	}
	return payload, header
}
// versionIDPattern matches a string made of exactly 64 lowercase hexadecimal
// characters (the Replicate version ID format).
var versionIDPattern = regexp.MustCompile(`^[a-f0-9]{64}$`)

// isVersionID reports whether s has the Replicate version ID shape: exactly
// 64 characters, each one of 0-9 or a-f. Uppercase hex digits do not qualify.
func isVersionID(s string) bool {
	// The pattern can only match 64-byte inputs, so other lengths are rejected up front.
	if len(s) != 64 {
		return false
	}
	return versionIDPattern.MatchString(s)
}
// buildPredictionURL returns the URL to use for creating a prediction.
//
// The default path is chosen as follows:
//   - useDeploymentsEndpoint set: /v1/deployments/<model>/predictions
//   - model is a version ID:      /v1/predictions
//   - otherwise (owner/name):     /v1/models/<model>/predictions
//
// That default is passed to providerUtils.GetRequestPath together with ctx,
// customProviderConfig and requestType. If the returned path is flagged as a
// complete URL it is returned as-is; otherwise it is appended to baseURL.
func buildPredictionURL(ctx *schemas.BifrostContext, baseURL, model string, customProviderConfig *schemas.CustomProviderConfig, requestType schemas.RequestType, useDeploymentsEndpoint bool) string {
	var defaultPath string
	switch {
	case useDeploymentsEndpoint:
		defaultPath = "/v1/deployments/" + model + "/predictions"
	case isVersionID(model):
		// A bare version ID goes to the base predictions endpoint.
		defaultPath = "/v1/predictions"
	default:
		// A model name (owner/name) goes to the model-specific endpoint.
		defaultPath = "/v1/models/" + model + "/predictions"
	}

	path, isCompleteURL := providerUtils.GetRequestPath(ctx, defaultPath, customProviderConfig, requestType)
	if !isCompleteURL {
		return baseURL + path
	}
	return path
}
// The token-count patterns below are compiled once at package initialisation
// rather than on every parseTokenUsageFromLogs call.

// replicateInputTokenPatterns are tried in order; the first one that yields a
// parsable count supplies the text input token count.
//
//	"Input token count: 20"
//	"Input text token count: 15"
var replicateInputTokenPatterns = []*regexp.Regexp{
	regexp.MustCompile(`Input token count:\s*(\d+)`),
	regexp.MustCompile(`Input text token count:\s*(\d+)`),
}

// replicateInputImageTokenPattern matches "Input image token count: 0" (image generation).
var replicateInputImageTokenPattern = regexp.MustCompile(`Input image token count:\s*(\d+)`)

// replicateOutputTokenPattern matches "Output token count: 28".
var replicateOutputTokenPattern = regexp.MustCompile(`Output token count:\s*(\d+)`)

// replicateTotalTokenPattern matches "Total token count: 48".
var replicateTotalTokenPattern = regexp.MustCompile(`Total token count:\s*(\d+)`)

// replicateSimpleTokenPattern matches the ambiguous short form "Tokens: X".
var replicateSimpleTokenPattern = regexp.MustCompile(`Tokens:\s*(\d+)`)

// matchReplicateTokenCount returns the integer captured by pattern's first
// group in logText. ok is false when the pattern does not match or the
// captured digits do not fit in an int.
func matchReplicateTokenCount(pattern *regexp.Regexp, logText string) (count int, ok bool) {
	matches := pattern.FindStringSubmatch(logText)
	if len(matches) < 2 {
		return 0, false
	}
	value, err := strconv.Atoi(matches[1])
	if err != nil {
		// Only reachable when the digit run overflows int; treated as "no match".
		return 0, false
	}
	return value, true
}

// parseTokenUsageFromLogs extracts token counts from Replicate's logs field.
// It handles multiple log formats with varying levels of detail.
//
// Parameters:
//   - logs: the prediction's log text; nil or empty yields all zeros and found=false.
//   - requestType: only consulted for the ambiguous "Tokens: X" form.
//
// Returns the input, output and total token counts plus found, which is true
// when at least one recognised count was parsed.
//
// Detailed format: input text tokens (first matching input pattern) plus
// input image tokens make up inputTokens; output and total are read from
// their own lines. Simple format ("Tokens: X") is used only when none of the
// detailed lines were found: the value becomes totalTokens, and for image
// generation requests it is also reported as outputTokens. When something was
// found but the total is still zero, the total is computed as input + output.
func parseTokenUsageFromLogs(logs *string, requestType schemas.RequestType) (inputTokens, outputTokens, totalTokens int, found bool) {
	if logs == nil || *logs == "" {
		return 0, 0, 0, false
	}
	logText := *logs

	// Pattern 1: detailed format with an input/output breakdown.
	for _, pattern := range replicateInputTokenPatterns {
		if count, ok := matchReplicateTokenCount(pattern, logText); ok {
			inputTokens = count
			found = true
			break
		}
	}
	if count, ok := matchReplicateTokenCount(replicateInputImageTokenPattern, logText); ok {
		inputTokens += count // image tokens are added on top of text input tokens
		found = true
	}
	if count, ok := matchReplicateTokenCount(replicateOutputTokenPattern, logText); ok {
		outputTokens = count
		found = true
	}
	if count, ok := matchReplicateTokenCount(replicateTotalTokenPattern, logText); ok {
		totalTokens = count
		found = true
	}

	// Pattern 2: simple "Tokens: X" format (ambiguous, needs a heuristic).
	// Only used when the detailed format was not found.
	if !found {
		if count, ok := matchReplicateTokenCount(replicateSimpleTokenPattern, logText); ok {
			// For text and unknown request types it is unclear whether the value
			// is total or output, so it is conservatively treated as the total.
			totalTokens = count
			if requestType == schemas.ImageGenerationRequest {
				// For image generation, "Tokens: X" typically means output tokens.
				outputTokens = count
			}
			found = true
		}
	}

	// If input/output were found but no total, compute it.
	if found && totalTokens == 0 {
		totalTokens = inputTokens + outputTokens
	}

	return inputTokens, outputTokens, totalTokens, found
}