Files
bifrost/core/providers/cohere/cohere.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

1267 lines
57 KiB
Go

package cohere
import (
"context"
"errors"
"fmt"
"io"
"strconv"
"strings"
"sync"
"time"
"net/http"
"net/url"
"github.com/bytedance/sonic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// cohereResponsePool recycles CohereChatResponse values that are used to
// decode Cohere v2 chat replies, cutting per-request allocations.
var cohereResponsePool = sync.Pool{
	New: func() any { return new(CohereChatResponse) },
}
// cohereEmbeddingResponsePool recycles CohereEmbeddingResponse values that are
// used to decode Cohere embedding replies.
var cohereEmbeddingResponsePool = sync.Pool{
	New: func() any { return new(CohereEmbeddingResponse) },
}

// acquireCohereEmbeddingResponse takes a CohereEmbeddingResponse from the pool
// and zeroes it, so nothing from a previous decode carries over.
func acquireCohereEmbeddingResponse() *CohereEmbeddingResponse {
	r := cohereEmbeddingResponsePool.Get().(*CohereEmbeddingResponse)
	*r = CohereEmbeddingResponse{}
	return r
}

// releaseCohereEmbeddingResponse hands resp back to the pool; a nil resp is
// ignored. The value is not cleared here; acquireCohereEmbeddingResponse
// resets it on the next Get.
func releaseCohereEmbeddingResponse(resp *CohereEmbeddingResponse) {
	if resp == nil {
		return
	}
	cohereEmbeddingResponsePool.Put(resp)
}
// cohereRerankResponsePool recycles CohereRerankResponse values that are used
// to decode Cohere rerank replies.
var cohereRerankResponsePool = sync.Pool{
	New: func() any { return new(CohereRerankResponse) },
}

// acquireCohereRerankResponse takes a CohereRerankResponse from the pool and
// zeroes it, so nothing from a previous decode carries over.
func acquireCohereRerankResponse() *CohereRerankResponse {
	r := cohereRerankResponsePool.Get().(*CohereRerankResponse)
	*r = CohereRerankResponse{}
	return r
}

// releaseCohereRerankResponse hands resp back to the pool; a nil resp is
// ignored. The value is not cleared here; acquireCohereRerankResponse resets
// it on the next Get.
func releaseCohereRerankResponse(resp *CohereRerankResponse) {
	if resp == nil {
		return
	}
	cohereRerankResponsePool.Put(resp)
}
// acquireCohereResponse takes a CohereChatResponse from cohereResponsePool and
// zeroes it, so nothing from a previous decode carries over.
func acquireCohereResponse() *CohereChatResponse {
	r := cohereResponsePool.Get().(*CohereChatResponse)
	*r = CohereChatResponse{}
	return r
}

// releaseCohereResponse hands resp back to cohereResponsePool; a nil resp is
// ignored. The value is not cleared here; acquireCohereResponse resets it on
// the next Get.
func releaseCohereResponse(resp *CohereChatResponse) {
	if resp == nil {
		return
	}
	cohereResponsePool.Put(resp)
}
// CohereProvider implements the Provider interface for Cohere.
//
// It holds two fasthttp clients: one for unary calls and one for SSE streams.
// It also holds the network configuration and the flags that decide whether
// raw request/response payloads are attached to response ExtraFields.
// NewCohereProvider builds instances of this type.
type CohereProvider struct {
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response)
streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader)
networkConfig schemas.NetworkConfig // Network configuration including extra headers
sendBackRawRequest bool // Whether to include raw request in BifrostResponse
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
customProviderConfig *schemas.CustomProviderConfig // Custom provider config; may be nil
}
// NewCohereProvider creates a new Cohere provider instance.
//
// Steps:
//   - Apply config defaults.
//   - Build a unary fasthttp client. Its read, write and connection-wait
//     timeouts all equal the configured default request timeout.
//   - Apply proxy, dialer and TLS settings to that client.
//   - Derive a separate streaming client from the configured unary client.
//   - Pre-warm the shared response pools with one object per configured
//     concurrency slot.
//
// An empty config.NetworkConfig.BaseURL defaults to "https://api.cohere.ai",
// and trailing slashes are trimmed. Both adjustments are written back into the
// caller's config.
//
// The returned error is currently always nil.
func NewCohereProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*CohereProvider, error) {
	config.CheckAndSetDefaults()

	requestTimeout := time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)
	client := &fasthttp.Client{
		ReadTimeout:         requestTimeout,
		WriteTimeout:        requestTimeout,
		MaxConnsPerHost:     config.NetworkConfig.MaxConnsPerHost,
		MaxIdleConnDuration: 30 * time.Second,
		MaxConnWaitTimeout:  requestTimeout,
		MaxConnDuration:     time.Second * time.Duration(schemas.DefaultMaxConnDurationInSeconds),
		ConnPoolStrategy:    fasthttp.FIFO,
	}

	// Apply proxy, dialer and TLS settings to the unary client.
	// No retry policy is configured here.
	client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger)
	client = providerUtils.ConfigureDialer(client)
	client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger)

	// Pre-warm the shared response pools with one object per concurrency slot.
	for i := 0; i < config.ConcurrencyAndBufferSize.Concurrency; i++ {
		cohereResponsePool.Put(&CohereChatResponse{})
		cohereEmbeddingResponsePool.Put(&CohereEmbeddingResponse{})
		cohereRerankResponsePool.Put(&CohereRerankResponse{})
	}

	// The streaming client is derived from the fully configured unary client,
	// so it inherits the proxy, dialer and TLS settings applied above.
	streamingClient := providerUtils.BuildStreamingClient(client)

	// Default and normalise the base URL. This mutates the caller's config.
	if config.NetworkConfig.BaseURL == "" {
		config.NetworkConfig.BaseURL = "https://api.cohere.ai"
	}
	config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/")

	return &CohereProvider{
		logger:               logger,
		client:               client,
		streamingClient:      streamingClient,
		networkConfig:        config.NetworkConfig,
		customProviderConfig: config.CustomProviderConfig,
		sendBackRawRequest:   config.SendBackRawRequest,
		sendBackRawResponse:  config.SendBackRawResponse,
	}, nil
}
// GetProviderKey returns this provider's identifier. It is computed by
// providerUtils.GetProviderName from schemas.Cohere and the (possibly nil)
// custom provider config.
func (provider *CohereProvider) GetProviderKey() schemas.ModelProvider {
	customCfg := provider.customProviderConfig
	return providerUtils.GetProviderName(schemas.Cohere, customCfg)
}
// buildRequestURL resolves the endpoint URL for requestType.
//
// providerUtils.GetRequestPath returns either a path or an already complete
// URL. A path is prefixed with the configured base URL; a complete URL is
// returned unchanged.
func (provider *CohereProvider) buildRequestURL(ctx *schemas.BifrostContext, defaultPath string, requestType schemas.RequestType) string {
	resolved, complete := providerUtils.GetRequestPath(ctx, defaultPath, provider.customProviderConfig, requestType)
	if !complete {
		resolved = provider.networkConfig.BaseURL + resolved
	}
	return resolved
}
// completeRequest POSTs jsonData to url and returns the response body, the
// measured latency and the provider response headers.
//
// Parameters:
//   - url: full request URL (this parameter shadows the net/url package here).
//   - key: API key; when non-empty it is sent as "Authorization: Bearer <key>".
//
// Return cases:
//   - transport failure: (nil, latency, nil, err); no headers are available.
//   - non-200 status: (nil, latency, headers, parsed Cohere error).
//   - body decode failure: (nil, latency, headers, err).
//   - large-response mode: (nil, latency, headers, nil). The body is not
//     returned and the fasthttp response is deliberately NOT released here;
//     presumably the large-response machinery takes ownership. Callers in this
//     file detect this case via schemas.BifrostContextKeyLargeResponseMode on ctx.
//   - success: (body, latency, headers, nil).
func (provider *CohereProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) {
// Create the request with the JSON body
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
// respOwned tracks whether this function must release resp. It is cleared
// when resp is handed to the large-response path below.
respOwned := true
defer func() {
if respOwned {
fasthttp.ReleaseResponse(resp)
}
}()
// Set any extra headers from network config (before the explicit headers below,
// so Content-Type/Authorization set afterwards overwrite any extra-header values)
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
req.SetRequestURI(url)
req.Header.SetMethod(http.MethodPost)
req.Header.SetContentType("application/json")
if key != "" {
req.Header.Set("Authorization", "Bearer "+key)
}
// In large-payload mode the helper supplies the request body itself; otherwise
// the already-marshalled jsonData is used.
usedLargePayloadBody := providerUtils.ApplyLargePayloadRequestBodyWithModelNormalization(ctx, req, schemas.Cohere)
if !usedLargePayloadBody {
req.SetBody(jsonData)
}
// Send the request with optional large response streaming
activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp)
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
// Registered after the release defers above, so (LIFO) wait() runs before
// req/resp are released.
defer wait()
// Drain the large-payload remainder whether or not the request succeeded.
if usedLargePayloadBody {
providerUtils.DrainLargePayloadRemainder(ctx)
}
if bifrostErr != nil {
return nil, latency, nil, bifrostErr
}
// Extract provider response headers before status check so error responses also forward them
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
// Handle error response
if resp.StatusCode() != fasthttp.StatusOK {
// Presumably buffers a streamed error body so parseCohereError can read it.
providerUtils.MaterializeStreamErrorBody(ctx, resp)
return nil, latency, providerResponseHeaders, parseCohereError(resp)
}
body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
if decodeErr != nil {
return nil, latency, providerResponseHeaders, decodeErr
}
if isLargeResp {
// Hand off resp: it must stay alive after this function returns.
respOwned = false
return nil, latency, providerResponseHeaders, nil
}
return body, latency, providerResponseHeaders, nil
}
// listModelsByKey fetches Cohere's model list (GET /v1/models) using a single
// API key and converts it into a BifrostListModelsResponse.
//
// Query parameters:
//   - page_size is always schemas.DefaultPageSize.
//   - ExtraParams "endpoint" is forwarded when it is a non-empty string.
//   - ExtraParams "default_only" is forwarded when it is true.
//
// Provider response headers are stored on ctx under
// BifrostContextKeyProviderResponseHeaders, for error responses as well.
// ToBifrostListModelsResponse receives the key's Models, BlacklistedModels and
// Aliases plus request.Unfiltered (presumably to filter or alias the list).
// Raw request/response payloads are attached when enabled.
func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
	req, resp := fasthttp.AcquireRequest(), fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	defer fasthttp.ReleaseResponse(resp)

	// Extra headers go first so the explicit headers set below overwrite them.
	providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)

	parsed, err := url.Parse(provider.buildRequestURL(ctx, "/v1/models", schemas.ListModelsRequest))
	if err != nil {
		return nil, providerUtils.NewBifrostOperationError("failed to parse request url", err)
	}
	query := parsed.Query()
	query.Set("page_size", strconv.Itoa(schemas.DefaultPageSize))
	// Looking up a key in a nil map is safe, so ExtraParams needs no nil guard.
	if ep, ok := request.ExtraParams["endpoint"].(string); ok && ep != "" {
		query.Set("endpoint", ep)
	}
	if only, ok := request.ExtraParams["default_only"].(bool); ok && only {
		query.Set("default_only", "true")
	}
	parsed.RawQuery = query.Encode()

	req.SetRequestURI(parsed.String())
	req.Header.SetMethod(http.MethodGet)
	req.Header.SetContentType("application/json")
	if apiKey := key.Value.GetValue(); apiKey != "" {
		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
	}

	latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
	defer wait()
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	// Headers are recorded before the status check so error responses forward them too.
	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
	if resp.StatusCode() != fasthttp.StatusOK {
		return nil, parseCohereError(resp)
	}

	body, err := providerUtils.CheckAndDecodeBody(resp)
	if err != nil {
		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
	}

	sendRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	sendRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	var listResp CohereListModelsResponse
	rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &listResp, nil, sendRawRequest, sendRawResponse)
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	result := listResp.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered)
	result.ExtraFields.Latency = latency.Milliseconds()
	if sendRawRequest {
		result.ExtraFields.RawRequest = rawRequest
	}
	if sendRawResponse {
		result.ExtraFields.RawResponse = rawResponse
	}
	return result, nil
}
// ListModels lists the models available from Cohere.
//
// It fails fast if the custom provider config disallows ListModelsRequest.
// For keyless custom providers, a single request is made with an empty key.
// Otherwise the per-key requests are dispatched through
// providerUtils.HandleMultipleListModelsRequests, which issues them
// concurrently.
func (provider *CohereProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
	if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ListModelsRequest); err != nil {
		return nil, err
	}

	cfg := provider.customProviderConfig
	if cfg == nil || !cfg.IsKeyLess {
		return providerUtils.HandleMultipleListModelsRequests(ctx, keys, request, provider.listModelsByKey)
	}

	fetchWithoutKey := func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
		return provider.listModelsByKey(ctx, schemas.Key{}, request)
	}
	return providerUtils.HandleKeylessListModelsRequest(provider.GetProviderKey(), fetchWithoutKey)
}
// TextCompletion is unsupported for Cohere. It makes no request and always
// returns an unsupported-operation error for schemas.TextCompletionRequest.
func (provider *CohereProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.TextCompletionRequest, provider.GetProviderKey())
	return nil, unsupported
}
// TextCompletionStream is not supported by the Cohere provider.
// It makes no request, returns a nil channel, and always returns an
// unsupported-operation error for schemas.TextCompletionStreamRequest.
func (provider *CohereProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey())
}
// ChatCompletion sends a non-streaming chat request to Cohere's /v2/chat
// endpoint and converts the reply into a BifrostChatResponse.
//
// Provider response headers are stored on ctx whenever the transport returned
// them, including on HTTP errors, and are also copied into ExtraFields.
// In large-response mode (schemas.BifrostContextKeyLargeResponseMode set on
// ctx) only the model, latency and headers are returned. Raw request/response
// payloads are attached when enabled by the provider config or ctx.
func (provider *CohereProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
	if bifrostErr := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ChatCompletionRequest); bifrostErr != nil {
		return nil, bifrostErr
	}

	buildBody := func() (providerUtils.RequestBodyWithExtraParams, error) {
		return ToCohereChatCompletionRequest(request)
	}
	jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(ctx, request, buildBody)
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	endpoint := provider.buildRequestURL(ctx, "/v2/chat", schemas.ChatCompletionRequest)
	responseBody, latency, headers, bifrostErr := provider.completeRequest(ctx, jsonBody, endpoint, key.Value.GetValue())
	if headers != nil {
		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, headers)
	}
	if bifrostErr != nil {
		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	// completeRequest returns no body in large-response mode, so report metadata only.
	if largeMode, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); largeMode {
		return &schemas.BifrostChatResponse{
			Model: request.Model,
			ExtraFields: schemas.BifrostResponseExtraFields{
				Latency:                 latency.Milliseconds(),
				ProviderResponseHeaders: headers,
			},
		}, nil
	}

	cohereResp := acquireCohereResponse()
	defer releaseCohereResponse(cohereResp)

	sendRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	sendRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	rawRequest, rawResponse, decodeErr := providerUtils.HandleProviderResponse(responseBody, cohereResp, jsonBody, sendRawRequest, sendRawResponse)
	if decodeErr != nil {
		return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	result := cohereResp.ToBifrostChatResponse(request.Model)
	result.ExtraFields.Latency = latency.Milliseconds()
	result.ExtraFields.ProviderResponseHeaders = headers
	if sendRawRequest {
		result.ExtraFields.RawRequest = rawRequest
	}
	if sendRawResponse {
		result.ExtraFields.RawResponse = rawResponse
	}
	return result, nil
}
// ChatCompletionStream performs a streaming chat completion request against
// Cohere's /v2/chat endpoint using Server-Sent Events (SSE).
//
// The converted request body always has Stream set to true.
//
// Failures before streaming begins are returned synchronously as a
// *schemas.BifrostError, enriched via EnrichError, with the fasthttp response
// released. These are: operation not allowed, body construction, transport
// error, cancellation/timeout, and non-200 status.
//
// On success it returns a channel of *schemas.BifrostStreamChunk that a
// background goroutine fills and then closes. In streaming-passthrough mode the
// returned channel is already closed and the raw upstream SSE is piped to the
// client instead of being converted here.
func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
// Check if chat completion stream is allowed
if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil {
return nil, err
}
// Build the Cohere v2 chat body with streaming enabled.
jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
ctx,
request,
func() (providerUtils.RequestBodyWithExtraParams, error) {
reqBody, err := ToCohereChatCompletionRequest(request)
if err != nil {
return nil, err
}
// NOTE(review): unlike ResponsesStream there is no nil check on reqBody here;
// this assumes ToCohereChatCompletionRequest never returns (nil, nil).
reqBody.Stream = schemas.Ptr(true)
return reqBody, nil
})
if bifrostErr != nil {
return nil, bifrostErr
}
// Provider-level raw flags, used for EnrichError on the synchronous error paths below.
sendBackRawRequest := provider.sendBackRawRequest
sendBackRawResponse := provider.sendBackRawResponse
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
// Read the response body as a stream instead of buffering it.
resp.StreamBody = true
// Only req is released when this function returns. resp is released on each
// synchronous error path below, or by the streaming goroutine. In passthrough
// mode it is handed to SetupStreamingPassthrough, which presumably takes ownership.
defer fasthttp.ReleaseRequest(req)
req.Header.SetMethod(http.MethodPost)
req.SetRequestURI(provider.buildRequestURL(ctx, "/v2/chat", schemas.ChatCompletionStreamRequest))
req.Header.SetContentType("application/json")
// Set any extra headers from network config
// NOTE(review): here extra headers are applied AFTER Content-Type (completeRequest
// applies them before), so an extra Content-Type header could win; confirm intended.
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
// Set headers
if key.Value.GetValue() != "" {
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
}
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
// In large-payload mode the helper supplies the request body; otherwise use jsonBody.
usedLargePayloadBody := providerUtils.ApplyLargePayloadRequestBodyWithModelNormalization(ctx, req, schemas.Cohere)
if !usedLargePayloadBody {
req.SetBody(jsonBody)
}
// Make the request
// The streaming client has no ReadTimeout; stalled streams are detected later
// by the idle-timeout reader inside the goroutine.
err := provider.streamingClient.Do(req, resp)
// Drain the large-payload remainder whether or not Do succeeded.
if usedLargePayloadBody {
providerUtils.DrainLargePayloadRemainder(ctx)
}
if err != nil {
defer providerUtils.ReleaseStreamingResponse(resp)
// Classify the transport error: cancellation, timeout, or generic request failure.
if errors.Is(err, context.Canceled) {
return nil, providerUtils.EnrichError(ctx, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Type: schemas.Ptr(schemas.RequestCancelled),
Message: schemas.ErrRequestCancelled,
Error: err,
},
}, jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
}
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
}
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
}
// Extract provider response headers before status check so error responses also forward them
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
// Check for HTTP errors
if resp.StatusCode() != fasthttp.StatusOK {
defer providerUtils.ReleaseStreamingResponse(resp)
return nil, providerUtils.EnrichError(ctx, parseCohereError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
}
// Large payload streaming passthrough — pipe raw upstream SSE to client
if providerUtils.SetupStreamingPassthrough(ctx, resp) {
responseChan := make(chan *schemas.BifrostStreamChunk)
close(responseChan)
return responseChan, nil
}
// Create response channel
responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
// Apply the network-config idle timeout (the helper name suggests: only if ctx has none yet).
providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds)
// Start streaming in a goroutine
go func() {
// Deferred calls run in reverse order on exit:
//   1. stop the cancellation watcher
//   2. stop the idle timer
//   3. release the gzip reader, then the response
//   4. report ctx cancellation/deadline via HandleStreamCancellation/HandleStreamTimeout
//   5. close responseChan
//   6. ensure the post-hook span finalizer ran
defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
defer func() {
if ctx.Err() == context.Canceled {
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
} else if ctx.Err() == context.DeadlineExceeded {
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
}
close(responseChan)
}()
defer providerUtils.ReleaseStreamingResponse(resp)
// Decompress gzip-encoded streams transparently (no-op for non-gzip)
reader, releaseGzip := providerUtils.DecompressStreamBody(resp)
defer releaseGzip()
// Wrap reader with idle timeout to detect stalled streams.
reader, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(reader, resp.BodyStream(), providerUtils.GetStreamIdleTimeout(ctx))
defer stopIdleTimeout()
// Setup cancellation handler to close the raw network stream on ctx cancellation,
// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger)
defer stopCancellation()
sseReader := providerUtils.GetSSEDataReader(ctx, reader)
chunkIndex := 0
startTime := time.Now()
lastChunkTime := startTime
// ID from the message-start event, stamped on every emitted chunk
// (empty for any chunk emitted before message-start is seen).
var responseID string
for {
// If context was cancelled/timed out, let defer handle it
if ctx.Err() != nil {
return
}
data, readErr := sseReader.ReadDataLine()
if readErr != nil {
// Non-EOF read errors (unless caused by ctx ending) are sent as a final
// error chunk. io.EOF just ends the loop.
// NOTE(review): a clean EOF does not set BifrostContextKeyStreamEndIndicator
// here; confirm downstream handles a stream that ends without a final chunk.
if readErr != io.EOF {
if ctx.Err() != nil {
return
}
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
provider.logger.Warn("Error reading stream: %v", readErr)
providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger, postHookSpanFinalizer)
return
}
break
}
// String copy of the event, used only as the raw response payload.
// NOTE(review): this allocates on every event even when raw responses are disabled.
eventData := string(data)
// Parse the unified streaming event
var event CohereStreamEvent
// Unparseable events are logged and skipped rather than ending the stream.
if err := sonic.Unmarshal(data, &event); err != nil {
provider.logger.Warn("Failed to parse stream event: %v", err)
continue
}
// Extract response ID from message-start events
if event.Type == StreamEventMessageStart && event.ID != nil {
responseID = *event.ID
}
response, bifrostErr, isLastChunk := event.ToBifrostChatCompletionStream()
if bifrostErr != nil {
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer)
break
}
// NOTE(review): if isLastChunk is true but response is nil, nothing is sent and the loop keeps reading.
if response != nil {
response.ID = responseID
// Per-chunk latency is the time since the previous emitted chunk.
response.ExtraFields = schemas.BifrostResponseExtraFields{
ChunkIndex: chunkIndex,
Latency: time.Since(lastChunkTime).Milliseconds(),
}
lastChunkTime = time.Now()
chunkIndex++
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
response.ExtraFields.RawResponse = eventData
}
if isLastChunk {
// Set raw request if enabled
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody)
}
// The final chunk reports total stream latency instead of the per-chunk delta.
response.ExtraFields.Latency = time.Since(startTime).Milliseconds()
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan, postHookSpanFinalizer)
break
}
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan, postHookSpanFinalizer)
}
}
}()
return responseChan, nil
}
// Responses performs a non-streaming Responses-API request.
//
// The request is converted to a Cohere v2 chat body and sent to the /v2/chat
// endpoint. The reply is converted back into a BifrostResponsesResponse whose
// Model is set to request.Model.
//
// Provider response headers are stored on ctx whenever the transport returned
// them, including on HTTP errors, and are copied into ExtraFields. In
// large-response mode (schemas.BifrostContextKeyLargeResponseMode set on ctx)
// only the model, latency and headers are returned. Raw request/response
// payloads are attached when enabled by the provider config or ctx.
func (provider *CohereProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
	// Check that the Responses operation is allowed for this provider.
	if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ResponsesRequest); err != nil {
		return nil, err
	}

	// Convert the Bifrost Responses request into a Cohere v2 chat request body.
	jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
		ctx,
		request,
		func() (providerUtils.RequestBodyWithExtraParams, error) {
			return ToCohereResponsesRequest(request)
		})
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	// Responses requests are served by Cohere's /v2/chat endpoint.
	endpoint := provider.buildRequestURL(ctx, "/v2/chat", schemas.ResponsesRequest)
	responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, endpoint, key.Value.GetValue())
	if providerResponseHeaders != nil {
		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
	}
	if err != nil {
		return nil, providerUtils.EnrichError(ctx, err, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	// Large response mode: completeRequest returned no body, so report metadata only.
	if isLargeResp, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); isLargeResp {
		return &schemas.BifrostResponsesResponse{
			Model: request.Model,
			ExtraFields: schemas.BifrostResponseExtraFields{
				Latency:                 latency.Milliseconds(),
				ProviderResponseHeaders: providerResponseHeaders,
			},
		}, nil
	}

	// Decode into a pooled Cohere chat response.
	response := acquireCohereResponse()
	defer releaseCohereResponse(response)

	sendRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	sendRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	rawRequest, rawResponse, decodeErr := providerUtils.HandleProviderResponse(responseBody, response, jsonBody, sendRawRequest, sendRawResponse)
	if decodeErr != nil {
		return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	bifrostResponse := response.ToBifrostResponsesResponse()
	bifrostResponse.Model = request.Model
	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
	if sendRawRequest {
		bifrostResponse.ExtraFields.RawRequest = rawRequest
	}
	if sendRawResponse {
		bifrostResponse.ExtraFields.RawResponse = rawResponse
	}
	return bifrostResponse, nil
}
// ResponsesStream performs a streaming Responses-API request.
//
// The request is converted to a Cohere v2 chat body with Stream=true and sent
// to /v2/chat. Each SSE event is converted into zero or more Responses stream
// chunks, using a stream-state object that persists across events.
//
// Failures before streaming begins are returned synchronously, enriched via
// EnrichError, with the fasthttp response released. These are: operation not
// allowed, body construction, transport error, cancellation/timeout, and
// non-200 status.
//
// On success it returns a channel of *schemas.BifrostStreamChunk that a
// background goroutine fills and then closes. In streaming-passthrough mode the
// returned channel is already closed and the raw upstream SSE is piped to the
// client instead.
func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
// Check if responses stream is allowed
if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil {
return nil, err
}
// Convert to Cohere v2 request and add streaming
jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
ctx,
request,
func() (providerUtils.RequestBodyWithExtraParams, error) {
reqBody, err := ToCohereResponsesRequest(request)
if err != nil {
return nil, err
}
// NOTE(review): if reqBody is a nil pointer, `return reqBody, nil` still yields a
// NON-nil interface holding a nil pointer; confirm the caller copes with that.
if reqBody != nil {
reqBody.Stream = schemas.Ptr(true)
}
return reqBody, nil
})
if bifrostErr != nil {
return nil, bifrostErr
}
// Provider-level raw flags, used for EnrichError on the synchronous error paths below.
sendBackRawRequest := provider.sendBackRawRequest
sendBackRawResponse := provider.sendBackRawResponse
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
// Read the response body as a stream instead of buffering it.
resp.StreamBody = true
// Only req is released when this function returns. resp is released on each
// synchronous error path below, or by the streaming goroutine. In passthrough
// mode it is handed to SetupStreamingPassthrough, which presumably takes ownership.
defer fasthttp.ReleaseRequest(req)
req.Header.SetMethod(http.MethodPost)
req.SetRequestURI(provider.buildRequestURL(ctx, "/v2/chat", schemas.ResponsesStreamRequest))
req.Header.SetContentType("application/json")
// NOTE(review): extra headers are applied AFTER Content-Type here (completeRequest
// applies them before), so an extra Content-Type header could win; confirm intended.
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
// Set headers
if key.Value.GetValue() != "" {
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
}
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
// In large-payload mode the helper supplies the request body; otherwise use jsonBody.
usedLargePayloadBody := providerUtils.ApplyLargePayloadRequestBodyWithModelNormalization(ctx, req, schemas.Cohere)
if !usedLargePayloadBody {
req.SetBody(jsonBody)
}
// Make the request
// The streaming client has no ReadTimeout; stalled streams are detected later
// by the idle-timeout reader inside the goroutine.
err := provider.streamingClient.Do(req, resp)
// Drain the large-payload remainder whether or not Do succeeded.
if usedLargePayloadBody {
providerUtils.DrainLargePayloadRemainder(ctx)
}
if err != nil {
defer providerUtils.ReleaseStreamingResponse(resp)
// Classify the transport error: cancellation, timeout, or generic request failure.
if errors.Is(err, context.Canceled) {
return nil, providerUtils.EnrichError(ctx, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Type: schemas.Ptr(schemas.RequestCancelled),
Message: schemas.ErrRequestCancelled,
Error: err,
},
}, jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
}
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
}
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
}
// Extract provider response headers before status check so error responses also forward them
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
// Check for HTTP errors
if resp.StatusCode() != fasthttp.StatusOK {
defer providerUtils.ReleaseStreamingResponse(resp)
return nil, providerUtils.EnrichError(ctx, parseCohereError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
}
// Large payload streaming passthrough — pipe raw upstream SSE to client
if providerUtils.SetupStreamingPassthrough(ctx, resp) {
responseChan := make(chan *schemas.BifrostStreamChunk)
close(responseChan)
return responseChan, nil
}
// Create response channel
responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
// Apply the network-config idle timeout (the helper name suggests: only if ctx has none yet).
providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds)
// Start streaming in a goroutine
go func() {
// Deferred calls run in reverse order on exit:
//   1. release the stream state
//   2. stop the cancellation watcher
//   3. stop the idle timer
//   4. release the gzip reader, then the response
//   5. report ctx cancellation/deadline via HandleStreamCancellation/HandleStreamTimeout
//   6. close responseChan
//   7. ensure the post-hook span finalizer ran
defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
defer func() {
if ctx.Err() == context.Canceled {
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
} else if ctx.Err() == context.DeadlineExceeded {
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
}
close(responseChan)
}()
defer providerUtils.ReleaseStreamingResponse(resp)
// Decompress gzip-encoded streams transparently (no-op for non-gzip)
reader, releaseGzip := providerUtils.DecompressStreamBody(resp)
defer releaseGzip()
// Wrap reader with idle timeout to detect stalled streams.
reader, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(reader, resp.BodyStream(), providerUtils.GetStreamIdleTimeout(ctx))
defer stopIdleTimeout()
// Setup cancellation handler to close the raw network stream on ctx cancellation,
// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger)
defer stopCancellation()
sseReader := providerUtils.GetSSEDataReader(ctx, reader)
chunkIndex := 0
startTime := time.Now()
lastChunkTime := startTime
// Create stream state for stateful conversions (outside loop to persist across events)
streamState := acquireCohereResponsesStreamState()
// The state keeps a pointer to the caller's request.Model, not a copy.
streamState.Model = &request.Model
defer releaseCohereResponsesStreamState(streamState)
for {
// If context was cancelled/timed out, let defer handle it
if ctx.Err() != nil {
return
}
data, readErr := sseReader.ReadDataLine()
if readErr != nil {
// Non-EOF read errors (unless caused by ctx ending) are sent as a final
// error chunk. io.EOF just ends the loop.
// NOTE(review): a clean EOF does not set BifrostContextKeyStreamEndIndicator
// here; confirm downstream handles a stream that ends without a final chunk.
if readErr != io.EOF {
if ctx.Err() != nil {
return
}
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
provider.logger.Warn("Error reading stream: %v", readErr)
providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger, postHookSpanFinalizer)
return
}
break
}
// String copy of the event. Every chunk produced from this event shares it as
// the raw response payload.
eventData := string(data)
// Parse the unified streaming event
var event CohereStreamEvent
// Unparseable events are logged and skipped rather than ending the stream.
if err := sonic.Unmarshal(data, &event); err != nil {
provider.logger.Warn("Failed to parse stream event: %v", err)
continue
}
// Note: response.created and response.in_progress are now emitted by ToBifrostResponsesStream
// from the message_start event, so we don't need to call them manually here
// chunkIndex is the index the first chunk of this event will receive.
responses, bifrostErr, isLastChunk := event.ToBifrostResponsesStream(chunkIndex, streamState)
if bifrostErr != nil {
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer)
break
}
// Handle each response in the slice
// nil entries are skipped. NOTE(review): if the final entry of a last-chunk
// event is nil, no response is treated as final and the loop keeps reading.
for i, response := range responses {
if response != nil {
// Per-chunk latency is the time since the previous emitted chunk.
response.ExtraFields = schemas.BifrostResponseExtraFields{
ChunkIndex: chunkIndex,
Latency: time.Since(lastChunkTime).Milliseconds(),
}
lastChunkTime = time.Now()
chunkIndex++
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
response.ExtraFields.RawResponse = eventData
}
// The final chunk is the last element produced by a last-chunk event.
if isLastChunk && i == len(responses)-1 {
// Guarantee the final chunk carries a (possibly empty) Response object.
if response.Response == nil {
response.Response = &schemas.BifrostResponsesResponse{}
}
// Set raw request if enabled
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody)
}
// The final chunk reports total stream latency instead of the per-chunk delta.
response.ExtraFields.Latency = time.Since(startTime).Milliseconds()
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan, postHookSpanFinalizer)
// return (not break) to leave both the inner range loop and the outer read loop.
return
}
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan, postHookSpanFinalizer)
}
}
}
}()
return responseChan, nil
}
// Embedding generates embeddings for the given input text(s) using Cohere's
// /v2/embed endpoint and returns them as a BifrostEmbeddingResponse.
//
// In large-response mode the body is not decoded. Only the model, latency and
// upstream response headers are returned. Otherwise the body is decoded into a
// pooled CohereEmbeddingResponse and converted.
//
// If the converted response carries no model name, request.Model is used. This
// matches the large-response path and the Rerank method, so callers always see
// the model.
func (provider *CohereProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
	// Check if embedding is allowed for this (possibly custom) provider config.
	if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.EmbeddingRequest); err != nil {
		return nil, err
	}
	jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
		ctx,
		request,
		func() (providerUtils.RequestBodyWithExtraParams, error) {
			return ToCohereEmbeddingRequest(request), nil
		})
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	// Send the request to Cohere's embed endpoint.
	responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/embed", schemas.EmbeddingRequest), key.Value.GetValue())
	// Record upstream headers before checking err, so headers returned
	// alongside an error are still exposed on the context.
	if providerResponseHeaders != nil {
		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
	}
	if err != nil {
		return nil, providerUtils.EnrichError(ctx, err, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
	// Large response mode: return lightweight response with metadata only.
	if isLargeResp, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); isLargeResp {
		return &schemas.BifrostEmbeddingResponse{
			Model: request.Model,
			ExtraFields: schemas.BifrostResponseExtraFields{
				Latency:                 latency.Milliseconds(),
				ProviderResponseHeaders: providerResponseHeaders,
			},
		}, nil
	}
	// Decode into a pooled object; it is returned to the pool when this
	// function returns, after the Bifrost response has been built from it.
	response := acquireCohereEmbeddingResponse()
	defer releaseCohereEmbeddingResponse(response)
	sendRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	sendRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
	rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, response, jsonBody, sendRawRequest, sendRawResponse)
	if bifrostErr != nil {
		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
	bifrostResponse := response.ToBifrostEmbeddingResponse()
	// Keep the model name consistent with the large-response path and Rerank.
	if bifrostResponse.Model == "" {
		bifrostResponse.Model = request.Model
	}
	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
	if sendRawRequest {
		bifrostResponse.ExtraFields.RawRequest = rawRequest
	}
	if sendRawResponse {
		bifrostResponse.ExtraFields.RawResponse = rawResponse
	}
	return bifrostResponse, nil
}
// Rerank ranks request.Documents against the query through Cohere's /v2/rerank
// endpoint and returns the result as a BifrostRerankResponse.
//
// Documents are echoed back in the results only when
// request.Params.ReturnDocuments is present and true. The response model is
// always set to request.Model. In large-response mode the body is not decoded
// and only metadata is returned.
func (provider *CohereProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
	if notAllowed := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.RerankRequest); notAllowed != nil {
		return nil, notAllowed
	}

	buildBody := func() (providerUtils.RequestBodyWithExtraParams, error) {
		return ToCohereRerankRequest(request), nil
	}
	jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(ctx, request, buildBody)
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	rerankURL := provider.buildRequestURL(ctx, "/v2/rerank", schemas.RerankRequest)
	responseBody, latency, headers, callErr := provider.completeRequest(ctx, jsonBody, rerankURL, key.Value.GetValue())
	// Headers are published on the context before the error check.
	if headers != nil {
		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, headers)
	}
	if callErr != nil {
		return nil, providerUtils.EnrichError(ctx, callErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	// Metadata-only shortcut for large responses.
	largeMode, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool)
	if largeMode {
		return &schemas.BifrostRerankResponse{
			Model: request.Model,
			ExtraFields: schemas.BifrostResponseExtraFields{
				Latency:                 latency.Milliseconds(),
				ProviderResponseHeaders: headers,
			},
		}, nil
	}

	cohereResp := acquireCohereRerankResponse()
	defer releaseCohereRerankResponse(cohereResp)

	wantRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	wantRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
	rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, cohereResp, jsonBody, wantRawRequest, wantRawResponse)
	if bifrostErr != nil {
		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	returnDocuments := false
	if params := request.Params; params != nil && params.ReturnDocuments != nil {
		returnDocuments = *params.ReturnDocuments
	}

	result := cohereResp.ToBifrostRerankResponse(request.Documents, returnDocuments)
	result.Model = request.Model
	result.ExtraFields.Latency = latency.Milliseconds()
	result.ExtraFields.ProviderResponseHeaders = headers
	if wantRawRequest {
		result.ExtraFields.RawRequest = rawRequest
	}
	if wantRawResponse {
		result.ExtraFields.RawResponse = rawResponse
	}
	return result, nil
}
// OCR is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.OCRRequest.
func (provider *CohereProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.OCRRequest, providerKey)
}
// Speech is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.SpeechRequest.
func (provider *CohereProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, providerKey)
}
// SpeechStream is not implemented for Cohere. No stream channel is created;
// every call returns an unsupported-operation error for
// schemas.SpeechStreamRequest.
func (provider *CohereProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, providerKey)
}
// Transcription is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.TranscriptionRequest.
func (provider *CohereProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, providerKey)
}
// TranscriptionStream is not implemented for Cohere. No stream channel is
// created; every call returns an unsupported-operation error for
// schemas.TranscriptionStreamRequest.
func (provider *CohereProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, providerKey)
}
// ImageGeneration is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.ImageGenerationRequest.
func (provider *CohereProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, providerKey)
}
// ImageGenerationStream is not implemented for Cohere. No stream channel is
// created; every call returns an unsupported-operation error for
// schemas.ImageGenerationStreamRequest.
func (provider *CohereProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, providerKey)
}
// ImageEdit is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.ImageEditRequest.
func (provider *CohereProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditRequest, providerKey)
}
// ImageEditStream is not implemented for Cohere. No stream channel is created;
// every call returns an unsupported-operation error for
// schemas.ImageEditStreamRequest.
func (provider *CohereProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, providerKey)
}
// ImageVariation is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.ImageVariationRequest.
func (provider *CohereProvider) ImageVariation(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageVariationRequest, providerKey)
}
// VideoGeneration is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.VideoGenerationRequest.
func (provider *CohereProvider) VideoGeneration(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoGenerationRequest, providerKey)
}
// VideoRetrieve is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.VideoRetrieveRequest.
func (provider *CohereProvider) VideoRetrieve(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRetrieveRequest, providerKey)
}
// VideoDownload is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.VideoDownloadRequest.
func (provider *CohereProvider) VideoDownload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDownloadRequest, providerKey)
}
// VideoDelete is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.VideoDeleteRequest.
func (provider *CohereProvider) VideoDelete(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDeleteRequest, providerKey)
}
// VideoList is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.VideoListRequest.
func (provider *CohereProvider) VideoList(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoListRequest, providerKey)
}
// VideoRemix is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.VideoRemixRequest.
func (provider *CohereProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, providerKey)
}
// BatchCreate is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.BatchCreateRequest.
func (provider *CohereProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, providerKey)
}
// BatchList is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.BatchListRequest.
func (provider *CohereProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, providerKey)
}
// BatchRetrieve is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.BatchRetrieveRequest.
func (provider *CohereProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, providerKey)
}
// BatchCancel is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.BatchCancelRequest.
func (provider *CohereProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, providerKey)
}
// BatchDelete is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.BatchDeleteRequest.
func (provider *CohereProvider) BatchDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchDeleteRequest, providerKey)
}
// BatchResults is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.BatchResultsRequest.
func (provider *CohereProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, providerKey)
}
// FileUpload is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.FileUploadRequest.
func (provider *CohereProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, providerKey)
}
// FileList is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.FileListRequest.
func (provider *CohereProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, providerKey)
}
// FileRetrieve is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.FileRetrieveRequest.
func (provider *CohereProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, providerKey)
}
// FileDelete is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.FileDeleteRequest.
func (provider *CohereProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, providerKey)
}
// FileContent is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.FileContentRequest.
func (provider *CohereProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, providerKey)
}
// CountTokens counts the tokens of a Responses-style request by calling
// Cohere's /v1/tokenize endpoint, and returns a BifrostCountTokensResponse.
//
// In large-response mode the body is not decoded and only metadata is
// returned. If the decoded tokenize response converts to nil, a
// response-decode error is returned instead.
func (provider *CohereProvider) CountTokens(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
	if notAllowed := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.CountTokensRequest); notAllowed != nil {
		return nil, notAllowed
	}

	buildBody := func() (providerUtils.RequestBodyWithExtraParams, error) {
		return ToCohereCountTokensRequest(request)
	}
	jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(ctx, request, buildBody)
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	tokenizeURL := provider.buildRequestURL(ctx, "/v1/tokenize", schemas.CountTokensRequest)
	responseBody, latency, headers, callErr := provider.completeRequest(ctx, jsonBody, tokenizeURL, key.Value.GetValue())
	// Headers are published on the context before the error check.
	if headers != nil {
		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, headers)
	}
	if callErr != nil {
		return nil, providerUtils.EnrichError(ctx, callErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	// Metadata-only shortcut for large responses.
	largeMode, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool)
	if largeMode {
		return &schemas.BifrostCountTokensResponse{
			Model: request.Model,
			ExtraFields: schemas.BifrostResponseExtraFields{
				Latency:                 latency.Milliseconds(),
				ProviderResponseHeaders: headers,
			},
		}, nil
	}

	wantRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	wantRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
	tokenizeResp := &CohereCountTokensResponse{}
	rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, tokenizeResp, jsonBody, wantRawRequest, wantRawResponse)
	if bifrostErr != nil {
		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	result := tokenizeResp.ToBifrostCountTokensResponse(request.Model)
	if result == nil {
		decodeErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, fmt.Errorf("nil cohere count tokens response"))
		return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	result.ExtraFields.Latency = latency.Milliseconds()
	result.ExtraFields.ProviderResponseHeaders = headers
	if wantRawRequest {
		result.ExtraFields.RawRequest = rawRequest
	}
	if wantRawResponse {
		result.ExtraFields.RawResponse = rawResponse
	}
	return result, nil
}
// ContainerCreate is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.ContainerCreateRequest.
func (provider *CohereProvider) ContainerCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerCreateRequest, providerKey)
}
// ContainerList is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.ContainerListRequest.
func (provider *CohereProvider) ContainerList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerListRequest, providerKey)
}
// ContainerRetrieve is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.ContainerRetrieveRequest.
func (provider *CohereProvider) ContainerRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerRetrieveRequest, providerKey)
}
// ContainerDelete is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.ContainerDeleteRequest.
func (provider *CohereProvider) ContainerDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerDeleteRequest, providerKey)
}
// ContainerFileCreate is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.ContainerFileCreateRequest.
func (provider *CohereProvider) ContainerFileCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileCreateRequest, providerKey)
}
// ContainerFileList is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.ContainerFileListRequest.
func (provider *CohereProvider) ContainerFileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileListRequest, providerKey)
}
// ContainerFileRetrieve is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.ContainerFileRetrieveRequest.
func (provider *CohereProvider) ContainerFileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileRetrieveRequest, providerKey)
}
// ContainerFileContent is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.ContainerFileContentRequest.
func (provider *CohereProvider) ContainerFileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileContentRequest, providerKey)
}
// ContainerFileDelete is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.ContainerFileDeleteRequest.
func (provider *CohereProvider) ContainerFileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileDeleteRequest, providerKey)
}
// Passthrough is not implemented for Cohere. Every call returns an
// unsupported-operation error for schemas.PassthroughRequest.
func (provider *CohereProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, providerKey)
}
// PassthroughStream is not supported by the Cohere provider. No stream channel
// is created; every call returns an unsupported-operation error for
// schemas.PassthroughStreamRequest.
func (provider *CohereProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey())
}