first commit
This commit is contained in:
435
core/providers/fireworks/fireworks.go
Normal file
435
core/providers/fireworks/fireworks.go
Normal file
@@ -0,0 +1,435 @@
|
||||
// Package fireworks implements the Fireworks AI provider and its utility functions.
|
||||
package fireworks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/openai"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// FireworksProvider implements the Provider interface for Fireworks AI's API.
|
||||
type FireworksProvider struct {
|
||||
logger schemas.Logger // Logger for provider operations
|
||||
client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response)
|
||||
streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader)
|
||||
networkConfig schemas.NetworkConfig // Network configuration including extra headers
|
||||
sendBackRawRequest bool // Whether to include raw request in BifrostResponse
|
||||
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
|
||||
}
|
||||
|
||||
// NewFireworksProvider creates a new Fireworks AI provider instance.
|
||||
// It initializes the HTTP client with the provided configuration and sets up response pools.
|
||||
// The client is configured with timeouts, concurrency limits, and optional proxy settings.
|
||||
func NewFireworksProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*FireworksProvider, error) {
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
requestTimeout := time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)
|
||||
client := &fasthttp.Client{
|
||||
ReadTimeout: requestTimeout,
|
||||
WriteTimeout: requestTimeout,
|
||||
MaxConnsPerHost: config.NetworkConfig.MaxConnsPerHost,
|
||||
MaxIdleConnDuration: 30 * time.Second,
|
||||
MaxConnWaitTimeout: requestTimeout,
|
||||
MaxConnDuration: time.Second * time.Duration(schemas.DefaultMaxConnDurationInSeconds),
|
||||
ConnPoolStrategy: fasthttp.FIFO,
|
||||
}
|
||||
|
||||
// Configure proxy and retry policy
|
||||
client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger)
|
||||
client = providerUtils.ConfigureDialer(client)
|
||||
client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger)
|
||||
streamingClient := providerUtils.BuildStreamingClient(client)
|
||||
// Set default BaseURL if not provided
|
||||
if config.NetworkConfig.BaseURL == "" {
|
||||
config.NetworkConfig.BaseURL = "https://api.fireworks.ai/inference"
|
||||
}
|
||||
config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/")
|
||||
|
||||
return &FireworksProvider{
|
||||
logger: logger,
|
||||
client: client,
|
||||
streamingClient: streamingClient,
|
||||
networkConfig: config.NetworkConfig,
|
||||
sendBackRawRequest: config.SendBackRawRequest,
|
||||
sendBackRawResponse: config.SendBackRawResponse,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetProviderKey returns the provider identifier for Fireworks AI.
|
||||
func (provider *FireworksProvider) GetProviderKey() schemas.ModelProvider {
|
||||
return schemas.Fireworks
|
||||
}
|
||||
|
||||
// ListModels performs a list models request to Fireworks AI's API.
|
||||
func (provider *FireworksProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
||||
return openai.HandleOpenAIListModelsRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
request,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/models"),
|
||||
keys,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
schemas.Fireworks,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
// TextCompletion performs a text completion request to the Fireworks AI API.
|
||||
func (provider *FireworksProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
|
||||
return openai.HandleOpenAITextCompletionRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"),
|
||||
request,
|
||||
key,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
provider.GetProviderKey(),
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
)
|
||||
}
|
||||
|
||||
// TextCompletionStream performs a streaming text completion request to the Fireworks AI API.
|
||||
func (provider *FireworksProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
var authHeader map[string]string
|
||||
if v := key.Value.GetValue(); v != "" {
|
||||
authHeader = map[string]string{"Authorization": "Bearer " + v}
|
||||
}
|
||||
return openai.HandleOpenAITextCompletionStreaming(
|
||||
ctx,
|
||||
provider.streamingClient,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"),
|
||||
request,
|
||||
authHeader,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
provider.GetProviderKey(),
|
||||
nil,
|
||||
postHookRunner,
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
postHookSpanFinalizer,
|
||||
)
|
||||
}
|
||||
|
||||
// ChatCompletion performs a chat completion request to the Fireworks AI API.
|
||||
func (provider *FireworksProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
return openai.HandleOpenAIChatCompletionRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
|
||||
request,
|
||||
key,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
provider.GetProviderKey(),
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
)
|
||||
}
|
||||
|
||||
// ChatCompletionStream performs a streaming chat completion request to the Fireworks AI API.
|
||||
// It supports real-time streaming of responses using Server-Sent Events (SSE).
|
||||
// Uses Fireworks AI's OpenAI-compatible streaming format.
|
||||
// Returns a channel containing BifrostStreamChunk objects representing the stream or an error if the request fails.
|
||||
func (provider *FireworksProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
var authHeader map[string]string
|
||||
if v := key.Value.GetValue(); v != "" {
|
||||
authHeader = map[string]string{"Authorization": "Bearer " + v}
|
||||
}
|
||||
// Use shared OpenAI-compatible streaming logic
|
||||
return openai.HandleOpenAIChatCompletionStreaming(
|
||||
ctx,
|
||||
provider.streamingClient,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
|
||||
request,
|
||||
authHeader,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
schemas.Fireworks,
|
||||
postHookRunner,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
postHookSpanFinalizer,
|
||||
)
|
||||
}
|
||||
|
||||
// Responses performs a responses request to the Fireworks AI API.
|
||||
func (provider *FireworksProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
return openai.HandleOpenAIResponsesRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/responses"),
|
||||
request,
|
||||
key,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
provider.GetProviderKey(),
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
)
|
||||
}
|
||||
|
||||
// ResponsesStream performs a streaming responses request to the Fireworks AI API.
|
||||
func (provider *FireworksProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
var authHeader map[string]string
|
||||
if v := key.Value.GetValue(); v != "" {
|
||||
authHeader = map[string]string{"Authorization": "Bearer " + v}
|
||||
}
|
||||
return openai.HandleOpenAIResponsesStreaming(
|
||||
ctx,
|
||||
provider.streamingClient,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/responses"),
|
||||
request,
|
||||
authHeader,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
provider.GetProviderKey(),
|
||||
postHookRunner,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
postHookSpanFinalizer,
|
||||
)
|
||||
}
|
||||
|
||||
// Embedding performs an embedding request to the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
|
||||
return openai.HandleOpenAIEmbeddingRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"),
|
||||
request,
|
||||
key,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
provider.GetProviderKey(),
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
nil,
|
||||
provider.logger,
|
||||
)
|
||||
}
|
||||
|
||||
// Speech is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Rerank is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.RerankRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// OCR is not supported by the Fireworks provider.
|
||||
func (provider *FireworksProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.OCRRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// SpeechStream is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Transcription is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// TranscriptionStream is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageGeneration is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageGenerationStream is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageEdit is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageEditStream is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageVariation is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ImageVariation(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageVariationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoGeneration is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) VideoGeneration(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoGenerationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoRetrieve is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) VideoRetrieve(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoDownload is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) VideoDownload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDownloadRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoDelete is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) VideoDelete(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoList is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) VideoList(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoRemix is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchCreate is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchList is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchRetrieve is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchCancel is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchDelete is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) BatchDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchResults is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileUpload is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileList is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileRetrieve is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileDelete is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileContent is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// CountTokens is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerCreate is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ContainerCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerList is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ContainerList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerRetrieve is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ContainerRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerDelete is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ContainerDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileCreate is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ContainerFileCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileList is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ContainerFileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileRetrieve is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ContainerFileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileContent is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ContainerFileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileContentRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileDelete is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ContainerFileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Passthrough is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// PassthroughStream is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
443
core/providers/fireworks/fireworks_test.go
Normal file
443
core/providers/fireworks/fireworks_test.go
Normal file
@@ -0,0 +1,443 @@
|
||||
package fireworks_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/internal/llmtests"
|
||||
fireworksprovider "github.com/maximhq/bifrost/core/providers/fireworks"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestFireworks(t *testing.T) {
|
||||
t.Parallel()
|
||||
if strings.TrimSpace(os.Getenv("FIREWORKS_API_KEY")) == "" {
|
||||
t.Skip("Skipping Fireworks tests because FIREWORKS_API_KEY is not set")
|
||||
}
|
||||
|
||||
client, ctx, cancel, err := llmtests.SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
chatModel, textModel, embeddingModel := resolveFireworksModels(t, client, ctx)
|
||||
|
||||
testConfig := llmtests.ComprehensiveTestConfig{
|
||||
Provider: schemas.Fireworks,
|
||||
ChatModel: chatModel,
|
||||
Fallbacks: []schemas.Fallback{},
|
||||
TextModel: textModel,
|
||||
TextCompletionFallbacks: []schemas.Fallback{},
|
||||
EmbeddingModel: embeddingModel,
|
||||
ReasoningModel: "",
|
||||
TranscriptionModel: "",
|
||||
SpeechSynthesisModel: "",
|
||||
Scenarios: llmtests.TestScenarios{
|
||||
TextCompletion: textModel != "",
|
||||
TextCompletionStream: textModel != "",
|
||||
SimpleChat: true,
|
||||
CompletionStream: true,
|
||||
MultiTurnConversation: true,
|
||||
ToolCalls: true,
|
||||
ToolCallsStreaming: true,
|
||||
MultipleToolCalls: false,
|
||||
End2EndToolCalling: false,
|
||||
AutomaticFunctionCall: false,
|
||||
ImageURL: false,
|
||||
ImageBase64: false,
|
||||
MultipleImages: false,
|
||||
FileBase64: false,
|
||||
FileURL: false,
|
||||
CompleteEnd2End: true,
|
||||
Embedding: embeddingModel != "",
|
||||
ListModels: true,
|
||||
Reasoning: false,
|
||||
Transcription: false,
|
||||
SpeechSynthesis: false,
|
||||
PromptCaching: false,
|
||||
},
|
||||
}
|
||||
t.Run("FireworksTests", func(t *testing.T) {
|
||||
llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
|
||||
})
|
||||
}
|
||||
|
||||
// resolveFireworksModels discovers live Fireworks models for chat, completions, and embeddings.
|
||||
func resolveFireworksModels(t *testing.T, client *bifrost.Bifrost, ctx context.Context) (string, string, string) {
|
||||
t.Helper()
|
||||
|
||||
requestedChatModel := normalizeFireworksModelID(os.Getenv("FIREWORKS_CHAT_MODEL"))
|
||||
requestedTextModel := normalizeFireworksModelID(os.Getenv("FIREWORKS_TEXT_MODEL"))
|
||||
requestedEmbeddingModel := normalizeFireworksModelID(os.Getenv("FIREWORKS_EMBEDDING_MODEL"))
|
||||
|
||||
chatModel := requestedChatModel
|
||||
textModel := requestedTextModel
|
||||
embeddingModel := requestedEmbeddingModel
|
||||
|
||||
if requestedChatModel != "" {
|
||||
t.Logf("Using FIREWORKS_CHAT_MODEL=%q override", requestedChatModel)
|
||||
}
|
||||
if requestedTextModel != "" {
|
||||
t.Logf("Using FIREWORKS_TEXT_MODEL=%q override", requestedTextModel)
|
||||
}
|
||||
if requestedEmbeddingModel != "" {
|
||||
t.Logf("Using FIREWORKS_EMBEDDING_MODEL=%q override", requestedEmbeddingModel)
|
||||
}
|
||||
|
||||
if chatModel == "" || textModel == "" || embeddingModel == "" {
|
||||
pageToken := ""
|
||||
for page := 0; page < 5; page++ {
|
||||
req := &schemas.BifrostListModelsRequest{
|
||||
Provider: schemas.Fireworks,
|
||||
PageSize: 200,
|
||||
PageToken: pageToken,
|
||||
}
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
resp, bifrostErr := client.ListModelsRequest(bfCtx, req)
|
||||
if bifrostErr != nil {
|
||||
if chatModel == "" {
|
||||
t.Fatalf("Failed to list Fireworks models for test discovery: %v", llmtests.GetErrorMessage(bifrostErr))
|
||||
}
|
||||
t.Logf("Fireworks model discovery failed: %v", llmtests.GetErrorMessage(bifrostErr))
|
||||
break
|
||||
}
|
||||
|
||||
if chatModel == "" {
|
||||
chatModel = pickFireworksChatModel(resp.Data)
|
||||
}
|
||||
if textModel == "" {
|
||||
// Fireworks text completions currently reuse the chat-capable model pool;
|
||||
// a later probe verifies that the selected model accepts /v1/completions.
|
||||
textModel = pickFireworksChatModel(resp.Data)
|
||||
}
|
||||
if embeddingModel == "" {
|
||||
embeddingModel = pickFireworksEmbeddingModel(resp.Data)
|
||||
}
|
||||
|
||||
if chatModel != "" && textModel != "" && embeddingModel != "" {
|
||||
break
|
||||
}
|
||||
if resp.NextPageToken == "" {
|
||||
break
|
||||
}
|
||||
pageToken = resp.NextPageToken
|
||||
}
|
||||
}
|
||||
|
||||
if chatModel == "" {
|
||||
t.Fatal("Unable to discover a Fireworks chat model from /v1/models; set FIREWORKS_CHAT_MODEL to override")
|
||||
}
|
||||
if textModel != "" && !fireworksModelSupportsTextCompletions(t, client, ctx, textModel) {
|
||||
t.Logf("Skipping Fireworks text completion scenarios because model %q did not accept /v1/completions", textModel)
|
||||
textModel = ""
|
||||
}
|
||||
if embeddingModel != "" && !fireworksModelSupportsEmbeddings(t, client, ctx, embeddingModel) {
|
||||
t.Logf("Skipping Fireworks embedding scenario because model %q did not accept /v1/embeddings", embeddingModel)
|
||||
embeddingModel = ""
|
||||
}
|
||||
if textModel == "" {
|
||||
t.Log("No Fireworks completions-capable model discovered from /v1/models; text completion scenarios will be skipped unless FIREWORKS_TEXT_MODEL is set")
|
||||
}
|
||||
if embeddingModel == "" {
|
||||
t.Log("No Fireworks embedding model discovered from /v1/models; embedding scenario will be skipped unless FIREWORKS_EMBEDDING_MODEL is set")
|
||||
}
|
||||
|
||||
return chatModel, textModel, embeddingModel
|
||||
}
|
||||
|
||||
// fireworksModelSupportsTextCompletions validates that the selected model actually accepts Fireworks /v1/completions.
|
||||
func fireworksModelSupportsTextCompletions(t *testing.T, client *bifrost.Bifrost, ctx context.Context, model string) bool {
|
||||
t.Helper()
|
||||
|
||||
prompt := "Say ok"
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
resp, bifrostErr := client.TextCompletionRequest(bfCtx, &schemas.BifrostTextCompletionRequest{
|
||||
Provider: schemas.Fireworks,
|
||||
Model: model,
|
||||
Input: &schemas.TextCompletionInput{
|
||||
PromptStr: &prompt,
|
||||
},
|
||||
Params: &schemas.TextCompletionParameters{
|
||||
MaxTokens: schemas.Ptr(8),
|
||||
},
|
||||
})
|
||||
if bifrostErr != nil {
|
||||
t.Logf("Fireworks /v1/completions probe failed for %q: %v", model, llmtests.GetErrorMessage(bifrostErr))
|
||||
return false
|
||||
}
|
||||
|
||||
return resp != nil && len(resp.Choices) > 0
|
||||
}
|
||||
|
||||
// fireworksModelSupportsEmbeddings validates that the selected model actually accepts Fireworks /v1/embeddings.
|
||||
func fireworksModelSupportsEmbeddings(t *testing.T, client *bifrost.Bifrost, ctx context.Context, model string) bool {
|
||||
t.Helper()
|
||||
|
||||
text := "embedding probe"
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
resp, bifrostErr := client.EmbeddingRequest(bfCtx, &schemas.BifrostEmbeddingRequest{
|
||||
Provider: schemas.Fireworks,
|
||||
Model: model,
|
||||
Input: &schemas.EmbeddingInput{
|
||||
Text: &text,
|
||||
},
|
||||
})
|
||||
if bifrostErr != nil {
|
||||
t.Logf("Fireworks /v1/embeddings probe failed for %q: %v", model, llmtests.GetErrorMessage(bifrostErr))
|
||||
return false
|
||||
}
|
||||
|
||||
return resp != nil && len(resp.Data) > 0
|
||||
}
|
||||
|
||||
// pickFireworksChatModel selects a text-capable Fireworks model from ListModels output.
|
||||
func pickFireworksChatModel(models []schemas.Model) string {
|
||||
for _, model := range models {
|
||||
normalized := normalizeFireworksModelID(model.ID)
|
||||
if isFireworksTextCapable(normalized) {
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// pickFireworksEmbeddingModel selects an embedding-capable Fireworks model from ListModels output.
|
||||
func pickFireworksEmbeddingModel(models []schemas.Model) string {
|
||||
for _, model := range models {
|
||||
normalized := normalizeFireworksModelID(model.ID)
|
||||
lower := strings.ToLower(normalized)
|
||||
if strings.Contains(lower, "embedding") || strings.Contains(lower, "embed") {
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// normalizeFireworksModelID strips any provider prefix so tests can pass raw Fireworks model IDs to Bifrost requests.
|
||||
func normalizeFireworksModelID(modelID string) string {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
return ""
|
||||
}
|
||||
_, normalized := schemas.ParseModelString(modelID, schemas.Fireworks)
|
||||
return normalized
|
||||
}
|
||||
|
||||
// isFireworksTextCapable applies a conservative name-based heuristic for text/chat-capable Fireworks models.
|
||||
func isFireworksTextCapable(modelID string) bool {
|
||||
lower := strings.ToLower(modelID)
|
||||
excludedHints := []string{
|
||||
"flux",
|
||||
"whisper",
|
||||
"audio",
|
||||
"speech",
|
||||
"transcrib",
|
||||
"embedding",
|
||||
"embed",
|
||||
"rerank",
|
||||
}
|
||||
for _, hint := range excludedHints {
|
||||
if strings.Contains(lower, hint) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
preferredHints := []string{
|
||||
"instruct",
|
||||
"chat",
|
||||
"gpt-oss",
|
||||
"deepseek",
|
||||
"qwen",
|
||||
"llama",
|
||||
"glm",
|
||||
"mixtral",
|
||||
"mistral",
|
||||
"cogito",
|
||||
"gemma",
|
||||
}
|
||||
for _, hint := range preferredHints {
|
||||
if strings.Contains(lower, hint) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// TestFireworksProviderUsesNativeEndpoints verifies that the Fireworks provider targets native completions, responses, and embeddings endpoints.
|
||||
func TestFireworksProviderUsesNativeEndpoints(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expectedPath string
|
||||
run func(t *testing.T, provider *fireworksprovider.FireworksProvider, ctx *schemas.BifrostContext, key schemas.Key)
|
||||
}{
|
||||
{
|
||||
name: "TextCompletion",
|
||||
expectedPath: "/v1/completions",
|
||||
run: func(t *testing.T, provider *fireworksprovider.FireworksProvider, ctx *schemas.BifrostContext, key schemas.Key) {
|
||||
prompt := "A is for apple and B is for"
|
||||
resp, err := provider.TextCompletion(ctx, key, &schemas.BifrostTextCompletionRequest{
|
||||
Provider: schemas.Fireworks,
|
||||
Model: "accounts/fireworks/models/deepseek-v3p2",
|
||||
Input: &schemas.TextCompletionInput{
|
||||
PromptStr: &prompt,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("TextCompletion returned error: %v", llmtests.GetErrorMessage(err))
|
||||
}
|
||||
if resp == nil || len(resp.Choices) == 0 || resp.Choices[0].Text == nil || *resp.Choices[0].Text == "" {
|
||||
t.Fatalf("unexpected text completion response: %#v", resp)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Responses",
|
||||
expectedPath: "/v1/responses",
|
||||
run: func(t *testing.T, provider *fireworksprovider.FireworksProvider, ctx *schemas.BifrostContext, key schemas.Key) {
|
||||
resp, err := provider.Responses(ctx, key, &schemas.BifrostResponsesRequest{
|
||||
Provider: schemas.Fireworks,
|
||||
Model: "accounts/fireworks/models/deepseek-v3p2",
|
||||
Input: []schemas.ResponsesMessage{
|
||||
llmtests.CreateBasicResponsesMessage("hello"),
|
||||
},
|
||||
Params: &schemas.ResponsesParameters{
|
||||
PreviousResponseID: schemas.Ptr("resp_previous"),
|
||||
MaxToolCalls: schemas.Ptr(2),
|
||||
Store: schemas.Ptr(true),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Responses returned error: %v", llmtests.GetErrorMessage(err))
|
||||
}
|
||||
if resp == nil || resp.PreviousResponseID == nil || *resp.PreviousResponseID != "resp_previous" {
|
||||
t.Fatalf("unexpected responses response: %#v", resp)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Embedding",
|
||||
expectedPath: "/v1/embeddings",
|
||||
run: func(t *testing.T, provider *fireworksprovider.FireworksProvider, ctx *schemas.BifrostContext, key schemas.Key) {
|
||||
resp, err := provider.Embedding(ctx, key, &schemas.BifrostEmbeddingRequest{
|
||||
Provider: schemas.Fireworks,
|
||||
Model: "accounts/fireworks/models/nomic-embed-text-v1.5",
|
||||
Input: &schemas.EmbeddingInput{
|
||||
Text: schemas.Ptr("embedding test"),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Embedding returned error: %v", llmtests.GetErrorMessage(err))
|
||||
}
|
||||
if resp == nil || len(resp.Data) != 1 || len(resp.Data[0].Embedding.EmbeddingArray) != 3 {
|
||||
t.Fatalf("unexpected embedding response: %#v", resp)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var requestedPath string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestedPath = r.URL.Path
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
switch r.URL.Path {
|
||||
case "/v1/completions":
|
||||
_, _ = fmt.Fprint(w, `{"id":"cmpl_1","object":"text_completion","created":1,"model":"accounts/fireworks/models/deepseek-v3p2","choices":[{"text":" banana","index":0,"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"completion_tokens":1,"total_tokens":5}}`)
|
||||
case "/v1/responses":
|
||||
_, _ = fmt.Fprint(w, `{"id":"resp_1","object":"response","created_at":1,"status":"completed","model":"accounts/fireworks/models/deepseek-v3p2","output":[{"id":"msg_1","type":"message","status":"completed","role":"assistant","content":[{"type":"output_text","text":"hello","annotations":[],"logprobs":[]}]}],"previous_response_id":"resp_previous","max_tool_calls":2,"store":true,"usage":{"input_tokens":1,"input_tokens_details":{"cached_tokens":0,"cached_read_tokens":0,"cached_write_tokens":0},"output_tokens":1,"total_tokens":2}}`)
|
||||
case "/v1/embeddings":
|
||||
_, _ = fmt.Fprint(w, `{"object":"list","model":"accounts/fireworks/models/nomic-embed-text-v1.5","data":[{"object":"embedding","index":0,"embedding":[0.1,0.2,0.3]}],"usage":{"prompt_tokens":2,"total_tokens":2}}`)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := newTestFireworksProvider(t, server.URL)
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
key := schemas.Key{Value: *schemas.NewEnvVar("test-key")}
|
||||
|
||||
tt.run(t, provider, ctx, key)
|
||||
|
||||
if requestedPath != tt.expectedPath {
|
||||
t.Fatalf("expected request path %q, got %q", tt.expectedPath, requestedPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFireworksResponsesStreamUsesNativeResponsesEndpoint verifies that Fireworks responses streaming targets the native responses endpoint.
|
||||
func TestFireworksResponsesStreamUsesNativeResponsesEndpoint(t *testing.T) {
|
||||
var requestedPath string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestedPath = r.URL.Path
|
||||
if r.URL.Path != "/v1/responses" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = fmt.Fprint(w, "data: {\"type\":\"response.completed\",\"sequence_number\":0,\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":1,\"status\":\"completed\",\"model\":\"accounts/fireworks/models/deepseek-v3p2\",\"output\":[{\"id\":\"msg_1\",\"type\":\"message\",\"status\":\"completed\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello\",\"annotations\":[],\"logprobs\":[]}]}],\"usage\":{\"input_tokens\":1,\"input_tokens_details\":{\"cached_tokens\":0,\"cached_read_tokens\":0,\"cached_write_tokens\":0},\"output_tokens\":1,\"total_tokens\":2}}}\n\n")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := newTestFireworksProvider(t, server.URL)
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
key := schemas.Key{Value: *schemas.NewEnvVar("test-key")}
|
||||
postHookRunner := func(_ *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
|
||||
return result, err
|
||||
}
|
||||
|
||||
stream, err := provider.ResponsesStream(ctx, postHookRunner, nil, key, &schemas.BifrostResponsesRequest{
|
||||
Provider: schemas.Fireworks,
|
||||
Model: "accounts/fireworks/models/deepseek-v3p2",
|
||||
Input: []schemas.ResponsesMessage{
|
||||
llmtests.CreateBasicResponsesMessage("hello"),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ResponsesStream returned error: %v", llmtests.GetErrorMessage(err))
|
||||
}
|
||||
|
||||
sawCompleted := false
|
||||
for chunk := range stream {
|
||||
if chunk != nil && chunk.BifrostResponsesStreamResponse != nil &&
|
||||
chunk.BifrostResponsesStreamResponse.Type == schemas.ResponsesStreamResponseTypeCompleted {
|
||||
sawCompleted = true
|
||||
}
|
||||
}
|
||||
|
||||
if requestedPath != "/v1/responses" {
|
||||
t.Fatalf("expected responses stream to hit /v1/responses, got %q", requestedPath)
|
||||
}
|
||||
if !sawCompleted {
|
||||
t.Fatal("expected a completed responses stream chunk")
|
||||
}
|
||||
}
|
||||
|
||||
// newTestFireworksProvider creates a Fireworks provider configured to hit a local test server.
|
||||
func newTestFireworksProvider(t *testing.T, baseURL string) *fireworksprovider.FireworksProvider {
|
||||
t.Helper()
|
||||
|
||||
provider, err := fireworksprovider.NewFireworksProvider(&schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
BaseURL: baseURL,
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
},
|
||||
}, bifrost.NewNoOpLogger())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create Fireworks provider: %v", err)
|
||||
}
|
||||
return provider
|
||||
}
|
||||
Reference in New Issue
Block a user