first commit
This commit is contained in:
38
core/providers/runway/errors.go
Normal file
38
core/providers/runway/errors.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package runway
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// parseRunwayError parses Runway API error responses and converts them to BifrostError.
|
||||
func parseRunwayError(resp *fasthttp.Response) *schemas.BifrostError {
|
||||
// Parse as RunwayAPIError
|
||||
var errorResp RunwayAPIError
|
||||
bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)
|
||||
|
||||
// Set error message if available
|
||||
if errorResp.Error != "" {
|
||||
if bifrostErr.Error == nil {
|
||||
bifrostErr.Error = &schemas.ErrorField{}
|
||||
}
|
||||
bifrostErr.Error.Message = errorResp.Error
|
||||
} else if bifrostErr.Error != nil && bifrostErr.Error.Message == "" {
|
||||
// If no error message was extracted, use a generic one
|
||||
bifrostErr.Error.Message = "Runway API request failed"
|
||||
} else if bifrostErr.Error == nil {
|
||||
bifrostErr.Error = &schemas.ErrorField{
|
||||
Message: "Runway API request failed",
|
||||
}
|
||||
}
|
||||
|
||||
// Remove trailing newlines
|
||||
if bifrostErr.Error != nil && bifrostErr.Error.Message != "" {
|
||||
bifrostErr.Error.Message = strings.TrimRight(bifrostErr.Error.Message, "\n")
|
||||
}
|
||||
|
||||
return bifrostErr
|
||||
}
|
||||
566
core/providers/runway/runway.go
Normal file
566
core/providers/runway/runway.go
Normal file
@@ -0,0 +1,566 @@
|
||||
// Package providers implements various LLM providers and their utility functions.
|
||||
// This file contains the Runway provider implementation.
|
||||
package runway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// RunwayProvider implements the Provider interface for Runway's API.
|
||||
type RunwayProvider struct {
|
||||
logger schemas.Logger // Logger for provider operations
|
||||
client *fasthttp.Client // HTTP client for API requests
|
||||
networkConfig schemas.NetworkConfig // Network configuration including extra headers
|
||||
sendBackRawRequest bool // Whether to include raw request in BifrostResponse
|
||||
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
|
||||
}
|
||||
|
||||
// NewRunwayProvider creates a new Runway provider instance.
|
||||
// It initializes the HTTP client with the provided configuration and sets up response pools.
|
||||
// The client is configured with timeouts, concurrency limits, and optional proxy settings.
|
||||
func NewRunwayProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*RunwayProvider, error) {
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
requestTimeout := time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)
|
||||
client := &fasthttp.Client{
|
||||
ReadTimeout: requestTimeout,
|
||||
WriteTimeout: requestTimeout,
|
||||
MaxConnsPerHost: config.NetworkConfig.MaxConnsPerHost,
|
||||
MaxIdleConnDuration: 60 * time.Second, // Video provider — longer idle duration to accommodate slower video generation responses
|
||||
MaxConnWaitTimeout: requestTimeout,
|
||||
MaxConnDuration: time.Second * time.Duration(schemas.DefaultMaxConnDurationInSeconds),
|
||||
ConnPoolStrategy: fasthttp.FIFO,
|
||||
}
|
||||
|
||||
// Configure proxy if provided
|
||||
client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger)
|
||||
client = providerUtils.ConfigureDialer(client)
|
||||
client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger)
|
||||
|
||||
// Set default BaseURL if not provided
|
||||
if config.NetworkConfig.BaseURL == "" {
|
||||
config.NetworkConfig.BaseURL = "https://api.dev.runwayml.com"
|
||||
}
|
||||
config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/")
|
||||
|
||||
return &RunwayProvider{
|
||||
logger: logger,
|
||||
client: client,
|
||||
networkConfig: config.NetworkConfig,
|
||||
sendBackRawRequest: config.SendBackRawRequest,
|
||||
sendBackRawResponse: config.SendBackRawResponse,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetProviderKey returns the provider identifier for Runway.
|
||||
func (provider *RunwayProvider) GetProviderKey() schemas.ModelProvider {
|
||||
return schemas.Runway
|
||||
}
|
||||
|
||||
// ListModels is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ListModelsRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// TextCompletion is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// TextCompletionStream is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ChatCompletion is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ChatCompletionRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ChatCompletionStream is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ChatCompletionStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Responses is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ResponsesRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ResponsesStream is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ResponsesStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Embedding is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Speech is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// SpeechStream is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Transcription is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// TranscriptionStream is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Rerank is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.RerankRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// OCR is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.OCRRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageGeneration is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageGenerationStream is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageEdit is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageEditStream is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageVariation is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) ImageVariation(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageVariationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoGeneration performs a video generation request to Runway's API.
|
||||
func (provider *RunwayProvider) VideoGeneration(ctx *schemas.BifrostContext, key schemas.Key, bifrostReq *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
providerName := provider.GetProviderKey()
|
||||
model := bifrostReq.Model
|
||||
|
||||
// Convert Bifrost request to Runway format
|
||||
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
||||
ctx,
|
||||
bifrostReq,
|
||||
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
||||
return ToRunwayVideoGenerationRequest(bifrostReq)
|
||||
})
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
// Determine the endpoint based on request type
|
||||
endpoint := getRunwayEndpoint(bifrostReq)
|
||||
|
||||
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
||||
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
||||
|
||||
// Create HTTP request
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
// Set any extra headers from network config
|
||||
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
||||
|
||||
// Set request URI and headers
|
||||
req.SetRequestURI(provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, endpoint))
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.Header.SetContentType("application/json")
|
||||
req.Header.Set("X-Runway-Version", "2024-11-06")
|
||||
if key.Value.GetValue() != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
||||
}
|
||||
|
||||
req.SetBody(jsonData)
|
||||
|
||||
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
||||
defer wait()
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
// Handle error response
|
||||
if resp.StatusCode() != fasthttp.StatusOK {
|
||||
return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
||||
}
|
||||
|
||||
// Decode response body
|
||||
body, err := providerUtils.CheckAndDecodeBody(resp)
|
||||
if err != nil {
|
||||
rawErrBody := append([]byte(nil), resp.Body()...)
|
||||
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, rawErrBody, sendBackRawRequest, sendBackRawResponse)
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var taskResp RunwayTaskCreationResponse
|
||||
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &taskResp, jsonData, sendBackRawRequest, sendBackRawResponse)
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
// Convert to Bifrost response
|
||||
bifrostResp := &schemas.BifrostVideoGenerationResponse{
|
||||
ID: providerUtils.AddVideoIDProviderSuffix(taskResp.ID, providerName),
|
||||
Model: model,
|
||||
Object: "video",
|
||||
Status: schemas.VideoStatusQueued,
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Latency: latency.Milliseconds(),
|
||||
},
|
||||
}
|
||||
|
||||
if sendBackRawRequest {
|
||||
bifrostResp.ExtraFields.RawRequest = rawRequest
|
||||
}
|
||||
if sendBackRawResponse {
|
||||
bifrostResp.ExtraFields.RawResponse = rawResponse
|
||||
}
|
||||
|
||||
return bifrostResp, nil
|
||||
}
|
||||
|
||||
// VideoRetrieve retrieves the status of a video generation task from Runway's API.
|
||||
func (provider *RunwayProvider) VideoRetrieve(ctx *schemas.BifrostContext, key schemas.Key, bifrostReq *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
providerName := provider.GetProviderKey()
|
||||
taskID := providerUtils.StripVideoIDProviderSuffix(bifrostReq.ID, providerName)
|
||||
|
||||
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
||||
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
||||
|
||||
// Create HTTP request
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
// Set any extra headers from network config
|
||||
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
||||
|
||||
// Set request URI and headers
|
||||
req.SetRequestURI(provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/tasks/"+taskID))
|
||||
req.Header.SetMethod("GET")
|
||||
req.Header.Set("X-Runway-Version", "2024-11-06")
|
||||
if key.Value.GetValue() != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
||||
}
|
||||
|
||||
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
||||
defer wait()
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
// Handle error response
|
||||
if resp.StatusCode() != fasthttp.StatusOK {
|
||||
return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp), nil, nil, sendBackRawRequest, sendBackRawResponse)
|
||||
}
|
||||
|
||||
// Decode response body
|
||||
body, err := providerUtils.CheckAndDecodeBody(resp)
|
||||
if err != nil {
|
||||
rawErrBody := append([]byte(nil), resp.Body()...)
|
||||
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), nil, rawErrBody, sendBackRawRequest, sendBackRawResponse)
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var taskDetails RunwayTaskDetailsResponse
|
||||
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &taskDetails, nil, sendBackRawRequest, sendBackRawResponse)
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
// Convert to Bifrost response
|
||||
bifrostResp, bifrostErr := ToBifrostVideoGenerationResponse(&taskDetails)
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName)
|
||||
bifrostResp.ExtraFields.Latency = latency.Milliseconds()
|
||||
|
||||
if sendBackRawRequest {
|
||||
bifrostResp.ExtraFields.RawRequest = rawRequest
|
||||
}
|
||||
if sendBackRawResponse {
|
||||
bifrostResp.ExtraFields.RawResponse = rawResponse
|
||||
}
|
||||
|
||||
return bifrostResp, nil
|
||||
}
|
||||
|
||||
// VideoDownload retrieves a video from Runway's API.
|
||||
func (provider *RunwayProvider) VideoDownload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
|
||||
// Retrieve task status to get the video URL
|
||||
bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{
|
||||
Provider: request.Provider,
|
||||
ID: request.ID,
|
||||
}
|
||||
taskDetails, bifrostErr := provider.VideoRetrieve(ctx, key, bifrostVideoRetrieveRequest)
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
// Check if video is ready
|
||||
if taskDetails.Status != schemas.VideoStatusCompleted {
|
||||
return nil, providerUtils.NewBifrostOperationError(
|
||||
fmt.Sprintf("video not ready, current status: %s", taskDetails.Status),
|
||||
nil)
|
||||
}
|
||||
if len(taskDetails.Videos) == 0 {
|
||||
return nil, providerUtils.NewBifrostOperationError("video URL not available", nil)
|
||||
}
|
||||
var videoUrl string
|
||||
if taskDetails.Videos[0].URL != nil {
|
||||
videoUrl = *taskDetails.Videos[0].URL
|
||||
}
|
||||
if videoUrl == "" {
|
||||
return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil)
|
||||
}
|
||||
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
||||
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
||||
|
||||
// Download video from Runway's URL
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
req.SetRequestURI(videoUrl)
|
||||
req.Header.SetMethod("GET")
|
||||
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
||||
defer wait()
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
if resp.StatusCode() != fasthttp.StatusOK {
|
||||
return nil, providerUtils.NewBifrostOperationError(
|
||||
fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()),
|
||||
nil)
|
||||
}
|
||||
// Get content and content type
|
||||
body, err := providerUtils.CheckAndDecodeBody(resp)
|
||||
if err != nil {
|
||||
rawErrBody := append([]byte(nil), resp.Body()...)
|
||||
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), nil, rawErrBody, sendBackRawRequest, sendBackRawResponse)
|
||||
}
|
||||
contentType := string(resp.Header.ContentType())
|
||||
if contentType == "" {
|
||||
contentType = "video/mp4" // Default for Runway
|
||||
}
|
||||
// Copy the binary content
|
||||
content := append([]byte(nil), body...)
|
||||
bifrostResp := &schemas.BifrostVideoDownloadResponse{
|
||||
VideoID: request.ID,
|
||||
Content: content,
|
||||
ContentType: contentType,
|
||||
}
|
||||
|
||||
bifrostResp.ExtraFields.Latency = latency.Milliseconds()
|
||||
|
||||
return bifrostResp, nil
|
||||
}
|
||||
|
||||
// VideoDelete cancels or deletes a task in Runway.
|
||||
// Tasks that are running, pending, or throttled can be canceled by invoking this method.
|
||||
// Invoking this method for other tasks will delete them.
|
||||
func (provider *RunwayProvider) VideoDelete(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
|
||||
providerName := provider.GetProviderKey()
|
||||
|
||||
if request.ID == "" {
|
||||
return nil, providerUtils.NewBifrostOperationError("task_id is required", nil)
|
||||
}
|
||||
|
||||
taskID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName)
|
||||
|
||||
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
||||
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
||||
|
||||
// Create request
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
||||
|
||||
req.SetRequestURI(provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/tasks/"+taskID))
|
||||
req.Header.SetMethod(http.MethodDelete)
|
||||
req.Header.Set("X-Runway-Version", "2024-11-06")
|
||||
if key.Value.GetValue() != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
||||
}
|
||||
|
||||
// Make request
|
||||
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
||||
defer wait()
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
// Handle error response - Runway returns 204 No Content on success
|
||||
if resp.StatusCode() != fasthttp.StatusNoContent {
|
||||
return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp), nil, nil, sendBackRawRequest, sendBackRawResponse)
|
||||
}
|
||||
|
||||
// Build response - Runway returns empty body on 204
|
||||
response := &schemas.BifrostVideoDeleteResponse{
|
||||
ID: request.ID, // Return with provider prefix
|
||||
Object: "video.deleted",
|
||||
Deleted: true,
|
||||
}
|
||||
|
||||
response.ExtraFields.Latency = latency.Milliseconds()
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// VideoList is not supported by Runway provider.
|
||||
func (provider *RunwayProvider) VideoList(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoRemix is not supported by Runway provider.
|
||||
func (provider *RunwayProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileUpload is not supported by Runway provider.
|
||||
func (provider *RunwayProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileList is not supported by Runway provider.
|
||||
func (provider *RunwayProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileRetrieve is not supported by Runway provider.
|
||||
func (provider *RunwayProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileDelete is not supported by Runway provider.
|
||||
func (provider *RunwayProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileContent is not supported by Runway provider.
|
||||
func (provider *RunwayProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchCreate is not supported by Runway provider.
|
||||
func (provider *RunwayProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchList is not supported by Runway provider.
|
||||
func (provider *RunwayProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchRetrieve is not supported by Runway provider.
|
||||
func (provider *RunwayProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchCancel is not supported by Runway provider.
|
||||
func (provider *RunwayProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchDelete is not supported by Runway provider.
|
||||
func (provider *RunwayProvider) BatchDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchResults is not supported by Runway provider.
|
||||
func (provider *RunwayProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// CountTokens is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerCreate is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) ContainerCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerList is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) ContainerList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerRetrieve is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) ContainerRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerDelete is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) ContainerDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileCreate is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) ContainerFileCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileList is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) ContainerFileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileRetrieve is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) ContainerFileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileContent is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) ContainerFileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileContentRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileDelete is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) ContainerFileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Passthrough is not supported by the Runway provider.
|
||||
func (provider *RunwayProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
func (provider *RunwayProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
41
core/providers/runway/runway_test.go
Normal file
41
core/providers/runway/runway_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package runway_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/internal/llmtests"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestRunway(t *testing.T) {
|
||||
t.Parallel()
|
||||
if strings.TrimSpace(os.Getenv("RUNWAY_API_KEY")) == "" {
|
||||
t.Skip("Skipping Runway tests because RUNWAY_API_KEY is not set")
|
||||
}
|
||||
|
||||
client, ctx, cancel, err := llmtests.SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
testConfig := llmtests.ComprehensiveTestConfig{
|
||||
Provider: schemas.Runway,
|
||||
VideoGenerationModel: "gen4.5",
|
||||
Scenarios: llmtests.TestScenarios{
|
||||
VideoGeneration: false, // disabled for now because of long running operations
|
||||
VideoRetrieve: false,
|
||||
VideoRemix: false,
|
||||
VideoDownload: false,
|
||||
VideoList: false,
|
||||
VideoDelete: false,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("RunwayTests", func(t *testing.T) {
|
||||
llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
|
||||
})
|
||||
}
|
||||
119
core/providers/runway/types.go
Normal file
119
core/providers/runway/types.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package runway
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
)
|
||||
|
||||
type Reference struct {
|
||||
Type string `json:"type"` // always image
|
||||
URI string `json:"uri"`
|
||||
}
|
||||
|
||||
type ReferenceImage struct {
|
||||
URI string `json:"uri"`
|
||||
Tag string `json:"tag"`
|
||||
}
|
||||
|
||||
type PromptImageObject struct {
|
||||
URI string `json:"uri"`
|
||||
Position string `json:"position"`
|
||||
}
|
||||
|
||||
type PromptImage struct {
|
||||
PromptImageStr *string
|
||||
PromptImageObject []PromptImageObject
|
||||
}
|
||||
|
||||
// custom marshal for PromptImage
|
||||
// MarshalJSON implements custom JSON marshalling for PromptImage.
|
||||
// It marshals either PromptImageStr or PromptImageObject directly without wrapping.
|
||||
func (pi PromptImage) MarshalJSON() ([]byte, error) {
|
||||
// Validation: ensure only one field is set at a time
|
||||
if pi.PromptImageStr != nil && pi.PromptImageObject != nil {
|
||||
return nil, fmt.Errorf("both PromptImageStr and PromptImageObject are set; only one should be non-nil")
|
||||
}
|
||||
|
||||
if pi.PromptImageStr != nil {
|
||||
return providerUtils.MarshalSorted(*pi.PromptImageStr)
|
||||
}
|
||||
if pi.PromptImageObject != nil {
|
||||
return providerUtils.MarshalSorted(pi.PromptImageObject)
|
||||
}
|
||||
// If both are nil, return null
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshalling for PromptImage.
|
||||
// It determines whether "content" is a string or array and assigns to the appropriate field.
|
||||
func (pi *PromptImage) UnmarshalJSON(data []byte) error {
|
||||
// First, try to unmarshal as a direct string
|
||||
pi.PromptImageStr = nil
|
||||
pi.PromptImageObject = nil
|
||||
|
||||
var stringContent string
|
||||
if err := sonic.Unmarshal(data, &stringContent); err == nil {
|
||||
pi.PromptImageStr = &stringContent
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as a direct array of PromptImageObject
|
||||
var arrayContent []PromptImageObject
|
||||
if err := sonic.Unmarshal(data, &arrayContent); err == nil {
|
||||
pi.PromptImageObject = arrayContent
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("promptImage field is neither a string nor an array of PromptImageObject")
|
||||
}
|
||||
|
||||
type RunwayVideoGenerationRequest struct {
|
||||
Model string `json:"model"`
|
||||
PromptText *string `json:"promptText,omitempty"`
|
||||
PromptImage *PromptImage `json:"promptImage,omitempty"`
|
||||
VideoURI *string `json:"videoUri,omitempty"`
|
||||
References []Reference `json:"references,omitempty"` // for video to video generation
|
||||
ReferenceImages []ReferenceImage `json:"referenceImages,omitempty"` // for text to video generation
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
Ratio *string `json:"ratio,omitempty"`
|
||||
Duration *int `json:"duration,omitempty"`
|
||||
Audio *bool `json:"audio,omitempty"` // for veo models
|
||||
ContentModeration *ContentModeration `json:"contentModeration,omitempty"`
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
func (r *RunwayVideoGenerationRequest) GetExtraParams() map[string]interface{} {
|
||||
return r.ExtraParams
|
||||
}
|
||||
|
||||
type ContentModeration struct {
|
||||
PublicFigureThreshold *string `json:"public_figure_threshold,omitempty"`
|
||||
}
|
||||
|
||||
type RunwayTaskCreationResponse struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
type RunwayTaskStatus string
|
||||
|
||||
const (
|
||||
RunwayTaskStatusPending RunwayTaskStatus = "PENDING"
|
||||
RunwayTaskStatusThrottled RunwayTaskStatus = "THROTTLED"
|
||||
RunwayTaskStatusCancelled RunwayTaskStatus = "CANCELLED"
|
||||
RunwayTaskStatusRunning RunwayTaskStatus = "RUNNING"
|
||||
RunwayTaskStatusFailed RunwayTaskStatus = "FAILED"
|
||||
RunwayTaskStatusSucceeded RunwayTaskStatus = "SUCCEEDED"
|
||||
)
|
||||
|
||||
type RunwayTaskDetailsResponse struct {
|
||||
Status RunwayTaskStatus `json:"status"`
|
||||
ID string `json:"id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Output []string `json:"output,omitempty"`
|
||||
}
|
||||
|
||||
type RunwayAPIError struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
34
core/providers/runway/utils.go
Normal file
34
core/providers/runway/utils.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package runway
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// getRunwayEndpoint determines which Runway API endpoint to use based on the request parameters.
|
||||
// Returns the appropriate endpoint path:
|
||||
// - /v1/text_to_video: when only text prompt is provided
|
||||
// - /v1/video_to_video: when video URI is provided
|
||||
// - /v1/image_to_video: when image input reference is provided
|
||||
func getRunwayEndpoint(req *schemas.BifrostVideoGenerationRequest) string {
|
||||
if req.Params != nil && req.Params.VideoURI != nil && *req.Params.VideoURI != "" {
|
||||
return "/v1/video_to_video"
|
||||
}
|
||||
if req.Input != nil && req.Input.InputReference != nil && *req.Input.InputReference != "" {
|
||||
return "/v1/image_to_video"
|
||||
}
|
||||
return "/v1/text_to_video"
|
||||
}
|
||||
|
||||
func isRunwayGenModel(model string) bool {
|
||||
return strings.Contains(model, "gen")
|
||||
}
|
||||
|
||||
func isRunwayVeoModel(model string) bool {
|
||||
return strings.Contains(model, "veo")
|
||||
}
|
||||
|
||||
func supportsVideoToVideo(model string) bool {
|
||||
return model == "gen4_aleph"
|
||||
}
|
||||
172
core/providers/runway/videos.go
Normal file
172
core/providers/runway/videos.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package runway
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func ToRunwayVideoGenerationRequest(bifrostReq *schemas.BifrostVideoGenerationRequest) (*RunwayVideoGenerationRequest, error) {
|
||||
// three types of video generation requests in runway api
|
||||
// 1. image to video
|
||||
// 2. text to video
|
||||
// 3. video to video
|
||||
if bifrostReq.Input == nil {
|
||||
return nil, fmt.Errorf("input is required")
|
||||
}
|
||||
|
||||
request := &RunwayVideoGenerationRequest{
|
||||
Model: bifrostReq.Model,
|
||||
Ratio: schemas.Ptr("1280:720"),
|
||||
}
|
||||
|
||||
if isRunwayVeoModel(bifrostReq.Model) {
|
||||
request.Duration = schemas.Ptr(4)
|
||||
} else if isRunwayGenModel(bifrostReq.Model) {
|
||||
request.Duration = schemas.Ptr(2)
|
||||
}
|
||||
|
||||
if bifrostReq.Input.Prompt != "" {
|
||||
request.PromptText = &bifrostReq.Input.Prompt
|
||||
}
|
||||
if bifrostReq.Input.InputReference != nil {
|
||||
sanitizedURL, err := schemas.SanitizeImageURL(*bifrostReq.Input.InputReference)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid input reference: %w", err)
|
||||
}
|
||||
request.PromptImage = &PromptImage{
|
||||
PromptImageStr: schemas.Ptr(sanitizedURL),
|
||||
}
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
if bifrostReq.Params.Seconds != nil {
|
||||
seconds, err := strconv.Atoi(*bifrostReq.Params.Seconds)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid seconds value: %w", err)
|
||||
}
|
||||
request.Duration = &seconds
|
||||
}
|
||||
|
||||
if bifrostReq.Params.Size != "" {
|
||||
// convert 1280x720 to 1280:720
|
||||
request.Ratio = schemas.Ptr(strings.Replace(bifrostReq.Params.Size, "x", ":", 1))
|
||||
}
|
||||
|
||||
if isRunwayVeoModel(bifrostReq.Model) {
|
||||
if bifrostReq.Params.Audio != nil {
|
||||
request.Audio = bifrostReq.Params.Audio
|
||||
}
|
||||
}
|
||||
|
||||
if isRunwayGenModel(bifrostReq.Model) {
|
||||
if bifrostReq.Params.Seed != nil {
|
||||
request.Seed = bifrostReq.Params.Seed
|
||||
}
|
||||
}
|
||||
|
||||
if bifrostReq.Params.VideoURI != nil {
|
||||
if !supportsVideoToVideo(bifrostReq.Model) {
|
||||
return nil, fmt.Errorf("video_uri is not supported for model %s", bifrostReq.Model)
|
||||
}
|
||||
request.VideoURI = bifrostReq.Params.VideoURI
|
||||
}
|
||||
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
request.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
// Handle references for video-to-video generation
|
||||
if refsVal := bifrostReq.Params.ExtraParams["references"]; refsVal != nil {
|
||||
if refs, ok := refsVal.([]Reference); ok && refs != nil {
|
||||
request.References = refs
|
||||
delete(request.ExtraParams, "references")
|
||||
} else if refs, err := schemas.ConvertViaJSON[[]Reference](refsVal); err == nil {
|
||||
request.References = refs
|
||||
delete(request.ExtraParams, "references")
|
||||
}
|
||||
}
|
||||
|
||||
// Handle reference images for video generation
|
||||
if refImagesVal := bifrostReq.Params.ExtraParams["reference_images"]; refImagesVal != nil {
|
||||
if refImages, ok := refImagesVal.([]ReferenceImage); ok && refImages != nil {
|
||||
delete(request.ExtraParams, "reference_images")
|
||||
request.ReferenceImages = refImages
|
||||
} else if refImages, err := schemas.ConvertViaJSON[[]ReferenceImage](refImagesVal); err == nil {
|
||||
delete(request.ExtraParams, "reference_images")
|
||||
request.ReferenceImages = refImages
|
||||
}
|
||||
}
|
||||
|
||||
// add content moderation
|
||||
if isRunwayVeoModel(bifrostReq.Model) {
|
||||
if cmVal := bifrostReq.Params.ExtraParams["content_moderation"]; cmVal != nil {
|
||||
if cm, ok := cmVal.(*ContentModeration); ok && cm != nil {
|
||||
delete(request.ExtraParams, "content_moderation")
|
||||
request.ContentModeration = cm
|
||||
} else if cm, err := schemas.ConvertViaJSON[ContentModeration](cmVal); err == nil {
|
||||
delete(request.ExtraParams, "content_moderation")
|
||||
request.ContentModeration = &cm
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
// ToBifrostVideoGenerationResponse converts Runway task details to Bifrost video generation response format.
|
||||
func ToBifrostVideoGenerationResponse(taskDetails *RunwayTaskDetailsResponse) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
if taskDetails == nil {
|
||||
return nil, providerUtils.NewBifrostOperationError("task details is nil", nil)
|
||||
}
|
||||
|
||||
response := &schemas.BifrostVideoGenerationResponse{
|
||||
ID: taskDetails.ID,
|
||||
Object: "video",
|
||||
CreatedAt: time.Now().Unix(),
|
||||
}
|
||||
|
||||
// Map Runway task status to Bifrost video status
|
||||
switch taskDetails.Status {
|
||||
case RunwayTaskStatusPending, RunwayTaskStatusThrottled:
|
||||
response.Status = schemas.VideoStatusQueued
|
||||
case RunwayTaskStatusRunning:
|
||||
response.Status = schemas.VideoStatusInProgress
|
||||
case RunwayTaskStatusSucceeded:
|
||||
response.Status = schemas.VideoStatusCompleted
|
||||
case RunwayTaskStatusFailed, RunwayTaskStatusCancelled:
|
||||
response.Status = schemas.VideoStatusFailed
|
||||
// Set error message for failed tasks
|
||||
errorMsg := fmt.Sprintf("Task %s", taskDetails.Status)
|
||||
response.Error = &schemas.VideoCreateError{
|
||||
Code: string(taskDetails.Status),
|
||||
Message: errorMsg,
|
||||
}
|
||||
default:
|
||||
response.Status = schemas.VideoStatusQueued
|
||||
}
|
||||
|
||||
if len(taskDetails.Output) > 0 {
|
||||
response.Videos = make([]schemas.VideoOutput, len(taskDetails.Output))
|
||||
for i, url := range taskDetails.Output {
|
||||
response.Videos[i] = schemas.VideoOutput{
|
||||
Type: schemas.VideoOutputTypeURL,
|
||||
URL: schemas.Ptr(url),
|
||||
ContentType: "video/mp4",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse created_at timestamp if available
|
||||
if taskDetails.CreatedAt != "" {
|
||||
if t, err := time.Parse(time.RFC3339, taskDetails.CreatedAt); err == nil {
|
||||
response.CreatedAt = t.Unix()
|
||||
}
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
116
core/providers/runway/videos_test.go
Normal file
116
core/providers/runway/videos_test.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package runway
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func makeVideoReq(model string, extraParams map[string]interface{}) *schemas.BifrostVideoGenerationRequest {
|
||||
return &schemas.BifrostVideoGenerationRequest{
|
||||
Model: model,
|
||||
Input: &schemas.VideoGenerationInput{
|
||||
Prompt: "test prompt",
|
||||
},
|
||||
Params: &schemas.VideoGenerationParameters{
|
||||
ExtraParams: extraParams,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestToRunwayVideoGenerationRequest_References(t *testing.T) {
|
||||
t.Run("direct_typed_references", func(t *testing.T) {
|
||||
refs := []Reference{{Type: "image", URI: "https://example.com/img.jpg"}}
|
||||
req := makeVideoReq("gen3", map[string]interface{}{
|
||||
"references": refs,
|
||||
})
|
||||
|
||||
result, err := ToRunwayVideoGenerationRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.References, 1)
|
||||
assert.Equal(t, "image", result.References[0].Type)
|
||||
assert.Equal(t, "https://example.com/img.jpg", result.References[0].URI)
|
||||
assert.NotContains(t, result.ExtraParams, "references")
|
||||
})
|
||||
|
||||
t.Run("map_fallback_references", func(t *testing.T) {
|
||||
// Simulates what happens when references arrive via JSON deserialization
|
||||
req := makeVideoReq("gen3", map[string]interface{}{
|
||||
"references": []interface{}{
|
||||
map[string]interface{}{"type": "image", "uri": "https://example.com/img.jpg"},
|
||||
},
|
||||
})
|
||||
|
||||
result, err := ToRunwayVideoGenerationRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.References, 1, "ConvertViaJSON fallback should convert map-based references")
|
||||
assert.Equal(t, "image", result.References[0].Type)
|
||||
assert.Equal(t, "https://example.com/img.jpg", result.References[0].URI)
|
||||
assert.NotContains(t, result.ExtraParams, "references")
|
||||
})
|
||||
}
|
||||
|
||||
func TestToRunwayVideoGenerationRequest_ReferenceImages(t *testing.T) {
|
||||
t.Run("direct_typed_reference_images", func(t *testing.T) {
|
||||
refImages := []ReferenceImage{{URI: "https://example.com/ref.jpg", Tag: "style"}}
|
||||
req := makeVideoReq("gen3", map[string]interface{}{
|
||||
"reference_images": refImages,
|
||||
})
|
||||
|
||||
result, err := ToRunwayVideoGenerationRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.ReferenceImages, 1)
|
||||
assert.Equal(t, "https://example.com/ref.jpg", result.ReferenceImages[0].URI)
|
||||
assert.Equal(t, "style", result.ReferenceImages[0].Tag)
|
||||
assert.NotContains(t, result.ExtraParams, "reference_images")
|
||||
})
|
||||
|
||||
t.Run("map_fallback_reference_images", func(t *testing.T) {
|
||||
req := makeVideoReq("gen3", map[string]interface{}{
|
||||
"reference_images": []interface{}{
|
||||
map[string]interface{}{"uri": "https://example.com/ref.jpg", "tag": "style"},
|
||||
},
|
||||
})
|
||||
|
||||
result, err := ToRunwayVideoGenerationRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.ReferenceImages, 1, "ConvertViaJSON fallback should convert map-based reference images")
|
||||
assert.Equal(t, "https://example.com/ref.jpg", result.ReferenceImages[0].URI)
|
||||
assert.Equal(t, "style", result.ReferenceImages[0].Tag)
|
||||
assert.NotContains(t, result.ExtraParams, "reference_images")
|
||||
})
|
||||
}
|
||||
|
||||
func TestToRunwayVideoGenerationRequest_ContentModeration(t *testing.T) {
|
||||
// ContentModeration handling only applies to veo models
|
||||
t.Run("pointer_content_moderation", func(t *testing.T) {
|
||||
cm := &ContentModeration{PublicFigureThreshold: schemas.Ptr("high")}
|
||||
req := makeVideoReq("veo-model", map[string]interface{}{
|
||||
"content_moderation": cm,
|
||||
})
|
||||
|
||||
result, err := ToRunwayVideoGenerationRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result.ContentModeration)
|
||||
require.NotNil(t, result.ContentModeration.PublicFigureThreshold)
|
||||
assert.Equal(t, "high", *result.ContentModeration.PublicFigureThreshold)
|
||||
assert.NotContains(t, result.ExtraParams, "content_moderation")
|
||||
})
|
||||
|
||||
t.Run("map_fallback_content_moderation", func(t *testing.T) {
|
||||
req := makeVideoReq("veo-model", map[string]interface{}{
|
||||
"content_moderation": map[string]interface{}{
|
||||
"public_figure_threshold": "high",
|
||||
},
|
||||
})
|
||||
|
||||
result, err := ToRunwayVideoGenerationRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result.ContentModeration, "ConvertViaJSON fallback should convert map-based content moderation")
|
||||
require.NotNil(t, result.ContentModeration.PublicFigureThreshold)
|
||||
assert.Equal(t, "high", *result.ContentModeration.PublicFigureThreshold)
|
||||
assert.NotContains(t, result.ExtraParams, "content_moderation")
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user