Files
bifrost/core/providers/azure/azure.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

2959 lines
105 KiB
Go

// Package azure implements the Azure provider.
package azure
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"maps"
"mime/multipart"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/providers/anthropic"
"github.com/maximhq/bifrost/core/providers/openai"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// Constants used by Azure authentication.
const (
	// AzureAuthorizationTokenKey is the context key under which a caller-supplied
	// Azure bearer token is looked up when building authentication headers.
	AzureAuthorizationTokenKey schemas.BifrostContextKey = "azure-authorization-token"

	// DefaultAzureScope is the default scope for Azure authentication.
	DefaultAzureScope = "https://cognitiveservices.azure.com/.default"
)
// AzureProvider implements the Provider interface for Azure's API.
// It holds the HTTP clients, the network configuration and a cache of Azure
// token credentials that are reused across requests made through the provider.
type AzureProvider struct {
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response)
streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader)
networkConfig schemas.NetworkConfig // Network configuration including extra headers
credentials sync.Map // Cache of azcore.TokenCredential values: service-principal credentials stored by getOrCreateAuth, plus the DefaultAzureCredential stored under a fixed key by getOrCreateDefaultAzureCredential
sendBackRawRequest bool // Whether to include raw request in BifrostResponse
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
}
// getOrCreateAuth returns the client-secret (service principal) credential for
// the given tenant, client and secret, creating and caching it on first use.
//
// The client secret is part of the cache key so that a rotated secret produces
// a fresh credential instead of silently reusing the one built from the old
// secret. The key only ever lives in the provider's in-memory credential map.
// Entries for superseded secrets are not evicted.
//
// Returns an error if the credential cannot be constructed.
func (provider *AzureProvider) getOrCreateAuth(
	tenantID, clientID, clientSecret string,
) (azcore.TokenCredential, error) {
	cacheKey := tenantID + ":" + clientID + ":" + clientSecret

	// Fast path: credential already cached for this exact identity and secret.
	if cached, ok := provider.credentials.Load(cacheKey); ok {
		return cached.(azcore.TokenCredential), nil
	}

	// Slow path: build a new credential.
	cred, err := azidentity.NewClientSecretCredential(
		tenantID,
		clientID,
		clientSecret,
		nil,
	)
	if err != nil {
		return nil, err
	}

	// LoadOrStore keeps whichever credential was stored first when several
	// goroutines reach the slow path for the same key at once.
	actual, _ := provider.credentials.LoadOrStore(cacheKey, cred)
	return actual.(azcore.TokenCredential), nil
}
// getOrCreateDefaultAzureCredential returns the provider-wide DefaultAzureCredential,
// constructing it on first use and caching it under a fixed key afterwards.
// DefaultAzureCredential detects the auth environment on its own: managed identity
// on Azure VMs/containers, workload identity in AKS, environment variables,
// Azure CLI, and more.
// Returns an error if the credential cannot be constructed.
func (provider *AzureProvider) getOrCreateDefaultAzureCredential() (azcore.TokenCredential, error) {
	const defaultCredentialKey = "default_azure_credential"

	cached, found := provider.credentials.Load(defaultCredentialKey)
	if found {
		return cached.(azcore.TokenCredential), nil
	}

	fresh, err := azidentity.NewDefaultAzureCredential(nil)
	if err != nil {
		return nil, err
	}

	// If another goroutine stored a credential in the meantime, use that one.
	stored, _ := provider.credentials.LoadOrStore(defaultCredentialKey, fresh)
	return stored.(azcore.TokenCredential), nil
}
// getAzureAuthHeaders builds the authentication headers for an Azure request.
// The first applicable source wins:
//  1. Service principal (tenant ID, client ID and client secret all non-empty) - Bearer token
//  2. Token stored in the context under AzureAuthorizationTokenKey - Bearer token
//  3. API key from key.Value - "x-api-key" header for Anthropic models, "api-key" otherwise
//  4. DefaultAzureCredential, used only when key.Value is empty - Bearer token
//
// Returns the header map, or a BifrostError when a credential or token cannot be obtained.
func (provider *AzureProvider) getAzureAuthHeaders(ctx *schemas.BifrostContext, key schemas.Key, isAnthropicModel bool) (map[string]string, *schemas.BifrostError) {
	headers := make(map[string]string)
	azureCfg := key.AzureKeyConfig

	// Service principal is usable only when all three parts are present and non-empty.
	hasServicePrincipal := azureCfg != nil &&
		azureCfg.ClientID != nil && azureCfg.ClientSecret != nil && azureCfg.TenantID != nil &&
		azureCfg.ClientID.GetValue() != "" &&
		azureCfg.ClientSecret.GetValue() != "" &&
		azureCfg.TenantID.GetValue() != ""

	if hasServicePrincipal {
		credential, credErr := provider.getOrCreateAuth(azureCfg.TenantID.GetValue(), azureCfg.ClientID.GetValue(), azureCfg.ClientSecret.GetValue())
		if credErr != nil {
			return nil, providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", credErr)
		}
		requestedScopes := getAzureScopes(azureCfg.Scopes)
		accessToken, tokenErr := credential.GetToken(ctx, policy.TokenRequestOptions{
			Scopes: requestedScopes,
		})
		if tokenErr != nil {
			return nil, providerUtils.NewBifrostOperationError("failed to get Azure access token", tokenErr)
		}
		if accessToken.Token == "" {
			return nil, providerUtils.NewBifrostOperationError("Azure access token is empty", errors.New("token is empty"))
		}
		headers["Authorization"] = "Bearer " + accessToken.Token
		return headers, nil
	}

	// Bearer token handed in through the request context.
	ctxToken, isString := ctx.Value(AzureAuthorizationTokenKey).(string)
	if isString && ctxToken != "" {
		headers["Authorization"] = "Bearer " + ctxToken
		return headers, nil
	}

	// Plain API key: the header name depends on the model family.
	apiKey := key.Value.GetValue()
	if apiKey != "" {
		headerName := "api-key"
		if isAnthropicModel {
			headerName = "x-api-key"
		}
		headers[headerName] = apiKey
		return headers, nil
	}

	// No explicit credentials provided - fall back to DefaultAzureCredential
	// auto-detection, honoring configured scopes when a key config is present.
	requestedScopes := getAzureScopes(nil)
	if azureCfg != nil {
		requestedScopes = getAzureScopes(azureCfg.Scopes)
	}
	credential, credErr := provider.getOrCreateDefaultAzureCredential()
	if credErr != nil {
		return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential unavailable", credErr)
	}
	accessToken, tokenErr := credential.GetToken(ctx, policy.TokenRequestOptions{Scopes: requestedScopes})
	if tokenErr != nil {
		return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential failed to get token", tokenErr)
	}
	if accessToken.Token == "" {
		return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential returned empty token", errors.New("token is empty"))
	}
	headers["Authorization"] = "Bearer " + accessToken.Token
	return headers, nil
}
// NewAzureProvider creates a new Azure provider instance.
// After applying configuration defaults it builds the unary HTTP client
// (timeouts, per-host connection limit, connection lifetime), applies proxy,
// dialer and TLS settings to it, and derives the streaming client from it.
// The returned error is always nil in the current implementation.
func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*AzureProvider, error) {
	config.CheckAndSetDefaults()

	timeout := time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds) * time.Second
	maxConnLifetime := time.Duration(schemas.DefaultMaxConnDurationInSeconds) * time.Second

	unaryClient := &fasthttp.Client{
		ReadTimeout:         timeout,
		WriteTimeout:        timeout,
		MaxConnWaitTimeout:  timeout,
		MaxConnsPerHost:     config.NetworkConfig.MaxConnsPerHost,
		MaxIdleConnDuration: 30 * time.Second,
		MaxConnDuration:     maxConnLifetime,
		ConnPoolStrategy:    fasthttp.FIFO,
	}

	// Layer proxy, dialer and TLS configuration onto the client, in that order.
	unaryClient = providerUtils.ConfigureProxy(unaryClient, config.ProxyConfig, logger)
	unaryClient = providerUtils.ConfigureDialer(unaryClient)
	unaryClient = providerUtils.ConfigureTLS(unaryClient, config.NetworkConfig, logger)

	azureProvider := &AzureProvider{
		logger:              logger,
		client:              unaryClient,
		streamingClient:     providerUtils.BuildStreamingClient(unaryClient),
		networkConfig:       config.NetworkConfig,
		sendBackRawRequest:  config.SendBackRawRequest,
		sendBackRawResponse: config.SendBackRawResponse,
	}
	return azureProvider, nil
}
// GetProviderKey reports which provider this implementation represents;
// it always returns schemas.Azure.
func (provider *AzureProvider) GetProviderKey() schemas.ModelProvider {
	const providerKey = schemas.Azure
	return providerKey
}
// completeRequest sends a unary POST request to Azure's API and returns the raw response.
//
// Parameters:
//   - ctx: request context, also consulted for extra headers and large-payload handling
//   - jsonData: JSON request body, used unless the large-payload helper has already set the body
//   - path: endpoint-relative path without a leading slash (e.g. "openai/v1/responses")
//   - key: key carrying the Azure endpoint, API version and credentials
//   - model: model name; decides between the Anthropic and OpenAI-style request shapes
//
// Returns the response body, the request latency, the provider response headers,
// and an error. On a non-200 status the raw error body is returned together with
// the parsed error. In large-response mode the body is nil and the response object
// is not released here.
//
// A missing Azure key config or an empty endpoint yields a configuration error
// before any token is fetched or any request is sent.
func (provider *AzureProvider) completeRequest(
	ctx *schemas.BifrostContext,
	jsonData []byte,
	path string,
	key schemas.Key,
	model string,
) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) {
	// Validate configuration first: getAzureAuthHeaders tolerates a nil
	// AzureKeyConfig, so this function must not dereference it blindly.
	if key.AzureKeyConfig == nil {
		return nil, 0, nil, providerUtils.NewConfigurationError("azure key config not set")
	}
	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
	if endpoint == "" {
		return nil, 0, nil, providerUtils.NewConfigurationError("endpoint not set")
	}

	// Create the request and response objects.
	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	respOwned := true
	defer func() {
		if respOwned {
			fasthttp.ReleaseResponse(resp)
		}
	}()

	isAnthropicModel := schemas.IsAnthropicModel(model)

	// Set any extra headers from network config.
	// For Anthropic models, exclude anthropic-beta — it is merged and filtered explicitly below.
	if isAnthropicModel {
		providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, []string{anthropic.AnthropicBetaHeader})
	} else {
		providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
	}
	req.Header.SetMethod(http.MethodPost)
	req.Header.SetContentType("application/json")

	// Get authentication headers and apply them to the request.
	authHeaders, bifrostErr := provider.getAzureAuthHeaders(ctx, key, isAnthropicModel)
	if bifrostErr != nil {
		return nil, 0, nil, bifrostErr
	}
	for k, v := range authHeaders {
		req.Header.Set(k, v)
	}

	// Named requestURL so the imported net/url package is not shadowed.
	var requestURL string
	if isAnthropicModel {
		req.Header.Set("anthropic-version", AzureAnthropicAPIVersionDefault)
		requestURL = fmt.Sprintf("%s/%s", endpoint, path)
		// Merge ExtraHeaders + context anthropic-beta, filter for Azure, then set as HTTP header.
		if betaHeaders := anthropic.FilterBetaHeadersForProvider(anthropic.MergeBetaHeaders(provider.networkConfig.ExtraHeaders, ctx), schemas.Azure, provider.networkConfig.BetaHeaderOverrides); len(betaHeaders) > 0 {
			req.Header.Set(anthropic.AnthropicBetaHeader, strings.Join(betaHeaders, ","))
		} else {
			req.Header.Del(anthropic.AnthropicBetaHeader)
		}
	} else {
		apiVersion := key.AzureKeyConfig.APIVersion
		if apiVersion == nil {
			apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
		}
		if path == "openai/v1/responses" {
			// The responses path always uses the "preview" API version.
			requestURL = fmt.Sprintf("%s/%s?api-version=preview", endpoint, path)
		} else {
			requestURL = fmt.Sprintf("%s/%s?api-version=%s", endpoint, path, apiVersion.GetValue())
		}
	}
	req.SetRequestURI(requestURL)

	// Fall back to the in-memory JSON body when the large-payload helper did not set one.
	if !providerUtils.ApplyLargePayloadRequestBodyWithModelNormalization(ctx, req, schemas.OpenAI) {
		req.SetBody(jsonData)
	}

	// Send the request with optional large response streaming.
	activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp)
	latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
	defer wait()
	if bifrostErr != nil {
		return nil, latency, nil, bifrostErr
	}

	// Extract provider response headers before body is copied — do this before status check
	// so error responses also carry provider headers (rate-limit info, request IDs, etc.)
	providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)

	// Handle error response.
	if resp.StatusCode() != fasthttp.StatusOK {
		providerUtils.MaterializeStreamErrorBody(ctx, resp)
		// Copy the body: resp is released when this function returns.
		rawErrBody := append([]byte(nil), resp.Body()...)
		return rawErrBody, latency, providerResponseHeaders, openai.ParseOpenAIError(resp)
	}

	body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
	if decodeErr != nil {
		return nil, latency, providerResponseHeaders, decodeErr
	}
	if isLargeResp {
		// Skip releasing resp here. NOTE(review): presumably the large-response
		// consumer takes ownership of it — confirm against providerUtils.
		respOwned = false
		return nil, latency, providerResponseHeaders, nil
	}
	return body, latency, providerResponseHeaders, nil
}
// listModelsByKey performs a list models request for a single key.
//
// It issues a GET to the key's endpoint at "/openai/models" (or the path taken
// from the context) with the configured or default API version, stores the
// provider response headers in ctx, and converts the Azure payload into a
// Bifrost list-models response filtered by the key's model settings.
//
// Returns a configuration error when the key has no Azure config or an empty
// endpoint, the parsed provider error on a non-200 status, or an operation
// error when decoding or conversion fails.
func (provider *AzureProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
	// Validate configuration before touching it; a key without Azure config
	// must produce an error rather than a nil-pointer panic.
	if key.AzureKeyConfig == nil {
		return nil, providerUtils.NewConfigurationError("azure key config not set")
	}
	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
	if endpoint == "" {
		return nil, providerUtils.NewConfigurationError("endpoint not set")
	}

	// Get API version, falling back to the default.
	apiVersion := key.AzureKeyConfig.APIVersion
	if apiVersion == nil {
		apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
	}

	// Create the request.
	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	defer fasthttp.ReleaseResponse(resp)

	// Set any extra headers from network config.
	providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
	req.SetRequestURI(endpoint + providerUtils.GetPathFromContext(ctx, fmt.Sprintf("/openai/models?api-version=%s", apiVersion.GetValue())))
	req.Header.SetMethod(http.MethodGet)
	req.Header.SetContentType("application/json")

	// Set Azure authentication.
	authHeaders, bifrostErr := provider.getAzureAuthHeaders(ctx, key, false)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	for k, v := range authHeaders {
		req.Header.Set(k, v)
	}

	// Send the request and measure latency.
	latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
	defer wait()
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	// Store provider response headers in context before status check so error responses also forward them.
	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))

	// Handle error response.
	if resp.StatusCode() != fasthttp.StatusOK {
		return nil, openai.ParseOpenAIError(resp)
	}

	body, err := providerUtils.CheckAndDecodeBody(resp)
	if err != nil {
		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
	}

	// Copy the body before the response is released, to avoid use-after-free
	// since resp.Body() references fasthttp's internal buffer.
	responseBody := append([]byte(nil), body...)

	sendRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	sendRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	// Parse Azure-specific response.
	azureResponse := &AzureListModelsResponse{}
	rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, azureResponse, nil, sendRawRequest, sendRawResponse)
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	// Convert to Bifrost response.
	response := azureResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered)
	if response == nil {
		return nil, providerUtils.NewBifrostOperationError("failed to convert Azure model list response", nil)
	}
	response.ExtraFields.Latency = latency.Milliseconds()

	// Attach raw request/response only when enabled.
	if sendRawRequest {
		response.ExtraFields.RawRequest = rawRequest
	}
	if sendRawResponse {
		response.ExtraFields.RawResponse = rawResponse
	}
	return response, nil
}
// ListModels lists the models available through the given Azure keys.
// The per-key work is done by listModelsByKey; iterating over the keys and
// combining the results is delegated to providerUtils.HandleMultipleListModelsRequests
// (presumably run concurrently — confirm in the helper).
func (provider *AzureProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
	perKeyLister := provider.listModelsByKey
	return providerUtils.HandleMultipleListModelsRequests(ctx, keys, request, perKeyLister)
}
// TextCompletion runs a non-streaming text completion against an Azure deployment.
// The request is converted with the OpenAI text converter, posted to
// "openai/deployments/<model>/completions" through completeRequest, and the reply
// is decoded into a BifrostTextCompletionResponse. Provider response headers are
// stored in ctx whenever they are available, including on errors.
// In large response mode only model, latency and headers are returned.
func (provider *AzureProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
	// Azure is OpenAI-compatible here, so the shared OpenAI converter builds the body.
	buildBody := func() (providerUtils.RequestBodyWithExtraParams, error) {
		return openai.ToOpenAITextCompletionRequest(request), nil
	}
	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(ctx, request, buildBody)
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	deploymentPath := "openai/deployments/" + request.Model + "/completions"
	responseBody, latency, providerResponseHeaders, reqErr := provider.completeRequest(ctx, jsonData, deploymentPath, key, request.Model)
	if providerResponseHeaders != nil {
		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
	}
	if reqErr != nil {
		return nil, providerUtils.EnrichError(ctx, reqErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	// Large response mode: hand back metadata only.
	largeMode, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool)
	if largeMode {
		lightweight := &schemas.BifrostTextCompletionResponse{
			Model: request.Model,
			ExtraFields: schemas.BifrostResponseExtraFields{
				Latency:                 latency.Milliseconds(),
				ProviderResponseHeaders: providerResponseHeaders,
			},
		}
		return lightweight, nil
	}

	wantRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	wantRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	response := &schemas.BifrostTextCompletionResponse{}
	rawRequest, rawResponse, parseErr := providerUtils.HandleProviderResponse(responseBody, response, jsonData, wantRawRequest, wantRawResponse)
	if parseErr != nil {
		return nil, providerUtils.EnrichError(ctx, parseErr, jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	response.ExtraFields.Latency = latency.Milliseconds()
	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
	// Raw payloads are attached only when enabled.
	if wantRawRequest {
		response.ExtraFields.RawRequest = rawRequest
	}
	if wantRawResponse {
		response.ExtraFields.RawResponse = rawResponse
	}
	return response, nil
}
// TextCompletionStream performs a streaming text completion request to Azure's API.
// It builds the deployment completions URL with the configured or default API
// version, obtains Azure authentication headers, and hands the request to the
// shared OpenAI text-completion streaming handler using the streaming client.
//
// Returns a channel of BifrostStreamChunk objects, or an error if the key has no
// Azure config, the endpoint is empty, authentication fails, or the handler fails.
func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	// Validate configuration first so a key without Azure config yields an
	// error instead of a nil-pointer panic.
	if key.AzureKeyConfig == nil {
		return nil, providerUtils.NewConfigurationError("azure key config not set")
	}
	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
	if endpoint == "" {
		return nil, providerUtils.NewConfigurationError("endpoint not set")
	}

	apiVersion := key.AzureKeyConfig.APIVersion
	if apiVersion == nil {
		apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
	}
	// Named requestURL so the imported net/url package is not shadowed.
	requestURL := fmt.Sprintf("%s/openai/deployments/%s/completions?api-version=%s", endpoint, request.Model, apiVersion.GetValue())

	// Get Azure authentication headers.
	authHeader, err := provider.getAzureAuthHeaders(ctx, key, false)
	if err != nil {
		return nil, err
	}

	// The handler's remaining optional arguments are left nil.
	return openai.HandleOpenAITextCompletionStreaming(
		ctx,
		provider.streamingClient,
		requestURL,
		request,
		authHeader,
		provider.networkConfig.ExtraHeaders,
		providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
		providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
		provider.GetProviderKey(),
		nil,
		postHookRunner,
		nil,
		nil,
		provider.logger,
		postHookSpanFinalizer,
	)
}
// ChatCompletion runs a non-streaming chat completion against Azure.
// Anthropic models are sent in Anthropic message format to "anthropic/v1/messages"
// and decoded from an Anthropic message response; every other model uses the
// OpenAI chat format against "openai/deployments/<model>/chat/completions".
// Provider response headers are stored in ctx whenever they are available.
// In large response mode only model, latency and headers are returned.
func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
	buildBody := func() (providerUtils.RequestBodyWithExtraParams, error) {
		if !schemas.IsAnthropicModel(request.Model) {
			return openai.ToOpenAIChatRequest(ctx, request), nil
		}
		anthropicBody, convErr := anthropic.ToAnthropicChatRequest(ctx, request)
		if convErr != nil {
			return nil, convErr
		}
		if anthropicBody != nil {
			// Add provider-aware beta headers for Azure.
			anthropic.AddMissingBetaHeadersToContext(ctx, anthropicBody, schemas.Azure)
		}
		return anthropicBody, nil
	}
	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(ctx, request, buildBody)
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	// Model family decides both the request path and the response decoder.
	isAnthropic := schemas.IsAnthropicModel(request.Model)
	path := "anthropic/v1/messages"
	if !isAnthropic {
		path = "openai/deployments/" + request.Model + "/chat/completions"
	}

	responseBody, latency, providerResponseHeaders, reqErr := provider.completeRequest(ctx, jsonData, path, key, request.Model)
	if providerResponseHeaders != nil {
		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
	}
	if reqErr != nil {
		return nil, providerUtils.EnrichError(ctx, reqErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	// Large response mode: hand back metadata only.
	largeMode, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool)
	if largeMode {
		lightweight := &schemas.BifrostChatResponse{
			Model: request.Model,
			ExtraFields: schemas.BifrostResponseExtraFields{
				Latency:                 latency.Milliseconds(),
				ProviderResponseHeaders: providerResponseHeaders,
			},
		}
		return lightweight, nil
	}

	wantRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	wantRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	var (
		response    *schemas.BifrostChatResponse
		rawRequest  any
		rawResponse any
	)
	if isAnthropic {
		anthropicResponse := anthropic.AcquireAnthropicMessageResponse()
		defer anthropic.ReleaseAnthropicMessageResponse(anthropicResponse)
		rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonData, wantRawRequest, wantRawResponse)
		if bifrostErr != nil {
			return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}
		response = anthropicResponse.ToBifrostChatResponse(ctx)
	} else {
		response = &schemas.BifrostChatResponse{}
		rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, response, jsonData, wantRawRequest, wantRawResponse)
		if bifrostErr != nil {
			return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}
	}

	response.ExtraFields.Latency = latency.Milliseconds()
	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
	// Raw payloads are attached only when enabled.
	if wantRawRequest {
		response.ExtraFields.RawRequest = rawRequest
	}
	if wantRawResponse {
		response.ExtraFields.RawResponse = rawResponse
	}
	return response, nil
}
// ChatCompletionStream performs a streaming chat completion request to Azure's API.
// Anthropic models are streamed through the shared Anthropic handler against
// "<endpoint>/anthropic/v1/messages" with the request body marked as streaming;
// all other models go through the shared OpenAI handler against the deployment's
// chat completions URL with the configured or default API version.
// Both api-key and Bearer token authentication are supported via getAzureAuthHeaders.
//
// Returns a channel of BifrostStreamChunk objects, or an error if the key has no
// Azure config, the endpoint is empty, authentication fails, the request body
// cannot be built, or the handler fails.
func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	// Validate configuration first so a key without Azure config yields an
	// error instead of a nil-pointer panic.
	if key.AzureKeyConfig == nil {
		return nil, providerUtils.NewConfigurationError("azure key config not set")
	}
	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
	if endpoint == "" {
		return nil, providerUtils.NewConfigurationError("endpoint not set")
	}

	if schemas.IsAnthropicModel(request.Model) {
		authHeader, err := provider.getAzureAuthHeaders(ctx, key, true)
		if err != nil {
			return nil, err
		}
		authHeader["anthropic-version"] = AzureAnthropicAPIVersionDefault
		// Named requestURL so the imported net/url package is not shadowed.
		requestURL := fmt.Sprintf("%s/anthropic/v1/messages", endpoint)

		jsonData, err := providerUtils.CheckContextAndGetRequestBody(
			ctx,
			request,
			func() (providerUtils.RequestBodyWithExtraParams, error) {
				reqBody, err := anthropic.ToAnthropicChatRequest(ctx, request)
				if err != nil {
					return nil, err
				}
				if reqBody != nil {
					reqBody.Stream = schemas.Ptr(true)
					// Add provider-aware beta headers for Azure.
					anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Azure)
				}
				return reqBody, nil
			})
		if err != nil {
			return nil, err
		}

		// Use shared streaming logic from Anthropic.
		return anthropic.HandleAnthropicChatCompletionStreaming(
			ctx,
			provider.streamingClient,
			requestURL,
			jsonData,
			authHeader,
			provider.networkConfig.ExtraHeaders,
			provider.networkConfig.BetaHeaderOverrides,
			providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
			providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
			provider.GetProviderKey(),
			postHookRunner,
			nil,
			provider.logger,
			postHookSpanFinalizer,
		)
	}

	authHeader, err := provider.getAzureAuthHeaders(ctx, key, false)
	if err != nil {
		return nil, err
	}
	apiVersion := key.AzureKeyConfig.APIVersion
	if apiVersion == nil {
		apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
	}
	requestURL := fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", endpoint, request.Model, apiVersion.GetValue())

	// Use shared streaming logic from OpenAI; its remaining optional arguments are left nil.
	return openai.HandleOpenAIChatCompletionStreaming(
		ctx,
		provider.streamingClient,
		requestURL,
		request,
		authHeader,
		provider.networkConfig.ExtraHeaders,
		providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
		providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
		provider.GetProviderKey(),
		postHookRunner,
		nil,
		nil,
		nil,
		nil,
		nil,
		provider.logger,
		postHookSpanFinalizer,
	)
}
// Responses runs a non-streaming responses request against Azure.
// Anthropic models get their body from getRequestBodyForAnthropicResponses and are
// posted to "anthropic/v1/messages"; other models use the OpenAI responses format
// and are posted to "openai/v1/responses". The reply is decoded with the matching
// decoder into a BifrostResponsesResponse. Provider response headers are stored in
// ctx whenever they are available. In large response mode only model, latency and
// headers are returned.
func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
	var (
		jsonData   []byte
		bifrostErr *schemas.BifrostError
	)
	if schemas.IsAnthropicModel(request.Model) {
		jsonData, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, request.Model, false)
	} else {
		buildBody := func() (providerUtils.RequestBodyWithExtraParams, error) {
			return openai.ToOpenAIResponsesRequest(request), nil
		}
		jsonData, bifrostErr = providerUtils.CheckContextAndGetRequestBody(ctx, request, buildBody)
	}
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	// Model family decides both the request path and the response decoder.
	isAnthropic := schemas.IsAnthropicModel(request.Model)
	path := "openai/v1/responses"
	if isAnthropic {
		path = "anthropic/v1/messages"
	}

	responseBody, latency, providerResponseHeaders, reqErr := provider.completeRequest(ctx, jsonData, path, key, request.Model)
	if providerResponseHeaders != nil {
		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
	}
	if reqErr != nil {
		return nil, providerUtils.EnrichError(ctx, reqErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	// Large response mode: hand back metadata only.
	largeMode, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool)
	if largeMode {
		lightweight := &schemas.BifrostResponsesResponse{
			Model: request.Model,
			ExtraFields: schemas.BifrostResponseExtraFields{
				Latency:                 latency.Milliseconds(),
				ProviderResponseHeaders: providerResponseHeaders,
			},
		}
		return lightweight, nil
	}

	wantRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	wantRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	var (
		response    *schemas.BifrostResponsesResponse
		rawRequest  any
		rawResponse any
	)
	if isAnthropic {
		anthropicResponse := anthropic.AcquireAnthropicMessageResponse()
		defer anthropic.ReleaseAnthropicMessageResponse(anthropicResponse)
		rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonData, wantRawRequest, wantRawResponse)
		if bifrostErr != nil {
			return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}
		response = anthropicResponse.ToBifrostResponsesResponse(ctx)
	} else {
		response = &schemas.BifrostResponsesResponse{}
		rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, response, jsonData, wantRawRequest, wantRawResponse)
		if bifrostErr != nil {
			return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}
	}

	response.ExtraFields.Latency = latency.Milliseconds()
	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
	// Raw payloads are attached only when enabled.
	if wantRawRequest {
		response.ExtraFields.RawRequest = rawRequest
	}
	if wantRawResponse {
		response.ExtraFields.RawResponse = rawResponse
	}
	return response, nil
}
// ResponsesStream performs a streaming responses request to Azure's API.
// Anthropic models are streamed through the shared Anthropic responses handler
// against "<endpoint>/anthropic/v1/messages"; all other models go through the
// shared OpenAI responses handler against
// "<endpoint>/openai/v1/responses?api-version=preview".
//
// Returns a channel of BifrostStreamChunk objects, or an error if the key has no
// Azure config, the endpoint is empty, authentication fails, the request body
// cannot be built, or the handler fails.
func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	// Validate configuration first so a key without Azure config yields an
	// error instead of a nil-pointer panic.
	if key.AzureKeyConfig == nil {
		return nil, providerUtils.NewConfigurationError("azure key config not set")
	}
	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
	if endpoint == "" {
		return nil, providerUtils.NewConfigurationError("endpoint not set")
	}

	if schemas.IsAnthropicModel(request.Model) {
		authHeader, err := provider.getAzureAuthHeaders(ctx, key, true)
		if err != nil {
			return nil, err
		}
		authHeader["anthropic-version"] = AzureAnthropicAPIVersionDefault
		// Named requestURL so the imported net/url package is not shadowed.
		requestURL := fmt.Sprintf("%s/anthropic/v1/messages", endpoint)

		jsonData, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, request.Model, true)
		if bifrostErr != nil {
			return nil, bifrostErr
		}

		// Use shared streaming logic from Anthropic.
		return anthropic.HandleAnthropicResponsesStream(
			ctx,
			provider.streamingClient,
			requestURL,
			jsonData,
			authHeader,
			provider.networkConfig.ExtraHeaders,
			provider.networkConfig.BetaHeaderOverrides,
			providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
			providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
			provider.GetProviderKey(),
			postHookRunner,
			nil,
			provider.logger,
			postHookSpanFinalizer,
		)
	}

	authHeader, err := provider.getAzureAuthHeaders(ctx, key, false)
	if err != nil {
		return nil, err
	}
	// The responses path always uses the "preview" API version.
	requestURL := fmt.Sprintf("%s/openai/v1/responses?api-version=preview", endpoint)

	// Use shared streaming logic from OpenAI; its remaining optional arguments are left nil.
	return openai.HandleOpenAIResponsesStreaming(
		ctx,
		provider.streamingClient,
		requestURL,
		request,
		authHeader,
		provider.networkConfig.ExtraHeaders,
		providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
		providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
		provider.GetProviderKey(),
		postHookRunner,
		nil,
		nil,
		nil,
		nil,
		provider.logger,
		postHookSpanFinalizer,
	)
}
// Embedding generates embeddings through an Azure deployment.
// The request is converted with the OpenAI embedding converter, posted to
// "openai/deployments/<model>/embeddings" through completeRequest, and the reply
// is decoded into a BifrostEmbeddingResponse. Provider response headers are stored
// in ctx whenever they are available. In large response mode only model, latency
// and headers are returned.
func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
	// The shared OpenAI converter builds the body.
	buildBody := func() (providerUtils.RequestBodyWithExtraParams, error) {
		return openai.ToOpenAIEmbeddingRequest(request), nil
	}
	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(ctx, request, buildBody)
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	deploymentPath := "openai/deployments/" + request.Model + "/embeddings"
	responseBody, latency, providerResponseHeaders, reqErr := provider.completeRequest(ctx, jsonData, deploymentPath, key, request.Model)
	if providerResponseHeaders != nil {
		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
	}
	if reqErr != nil {
		return nil, providerUtils.EnrichError(ctx, reqErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	// Large response mode: hand back metadata only.
	largeMode, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool)
	if largeMode {
		lightweight := &schemas.BifrostEmbeddingResponse{
			Model: request.Model,
			ExtraFields: schemas.BifrostResponseExtraFields{
				Latency:                 latency.Milliseconds(),
				ProviderResponseHeaders: providerResponseHeaders,
			},
		}
		return lightweight, nil
	}

	wantRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	wantRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	// Decode directly into the pre-allocated response.
	response := &schemas.BifrostEmbeddingResponse{}
	rawRequest, rawResponse, parseErr := providerUtils.HandleProviderResponse(responseBody, response, jsonData, wantRawRequest, wantRawResponse)
	if parseErr != nil {
		return nil, providerUtils.EnrichError(ctx, parseErr, jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	response.ExtraFields.Latency = latency.Milliseconds()
	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
	// Raw payloads are attached only when enabled.
	if wantRawRequest {
		response.ExtraFields.RawRequest = rawRequest
	}
	if wantRawResponse {
		response.ExtraFields.RawResponse = rawResponse
	}
	return response, nil
}
// Speech performs a non-streaming speech synthesis request against the Azure
// deployment named by request.Model, delegating to the shared OpenAI handler.
//
// An unset or empty API version on the key falls back to AzureAPIVersionDefault
// (an empty value would otherwise produce a "?api-version=" query with no
// version). An empty endpoint yields a configuration error before any request
// is made.
func (provider *AzureProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
	apiVersion := key.AzureKeyConfig.APIVersion
	if apiVersion == nil || apiVersion.GetValue() == "" {
		apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
	}

	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
	if endpoint == "" {
		return nil, providerUtils.NewConfigurationError("endpoint not set")
	}

	requestURL := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", endpoint, request.Model, apiVersion.GetValue())
	response, err := openai.HandleOpenAISpeechRequest(
		ctx,
		provider.client,
		requestURL,
		request,
		key,
		provider.networkConfig.ExtraHeaders,
		provider.GetProviderKey(),
		providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
		providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
		nil,
		provider.logger,
	)
	if err != nil {
		// Never surface a partially populated response alongside an error.
		return nil, err
	}
	return response, nil
}
// Rerank is not implemented for Azure; it always returns an
// unsupported-operation error tagged with this provider's key.
func (provider *AzureProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.RerankRequest, provider.GetProviderKey())
	return nil, unsupported
}
// OCR is not implemented for Azure; it always returns an
// unsupported-operation error tagged with this provider's key.
func (provider *AzureProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.OCRRequest, provider.GetProviderKey())
	return nil, unsupported
}
// SpeechStream streams synthesized speech from the Azure deployment named by
// request.Model.
//
// The request body is the OpenAI speech payload with StreamFormat set to "sse".
// The response is split into SSE events by hand (events end with "\r\n\r\n" or
// "\n\n"), each event's data is decoded as a JSON BifrostSpeechStreamResponse,
// and every event that carries audio is forwarded on the returned channel.
// A final "done" chunk is sent only when the provider emitted the [DONE]
// marker; a stream that ends without it is logged and closed without a done
// chunk. Non-EOF read errors and provider error events are forwarded as errors.
//
// The request is issued on the streaming client so that the unary client's
// ReadTimeout does not cut off long-running streams; stalls are detected by the
// idle-timeout reader instead.
//
// Returns a configuration error when the key has no endpoint, and an enriched
// BifrostError when the request cannot be built, sent, or is answered with a
// non-200 status.
func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	// Get Azure authentication headers
	authHeader, err := provider.getAzureAuthHeaders(ctx, key, false)
	if err != nil {
		return nil, err
	}

	// Fall back to the default API version when it is unset or empty,
	// matching the other Azure deployment endpoints in this file.
	apiVersion := key.AzureKeyConfig.APIVersion
	if apiVersion == nil || apiVersion.GetValue() == "" {
		apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
	}

	// Fail fast on a missing endpoint instead of sending a request to a relative URL.
	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
	if endpoint == "" {
		return nil, providerUtils.NewConfigurationError("endpoint not set")
	}
	requestURL := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", endpoint, request.Model, apiVersion.GetValue())

	// Create HTTP request for streaming
	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	resp.StreamBody = true
	defer fasthttp.ReleaseRequest(req)

	// Prepare headers
	headers := map[string]string{
		"Content-Type":    "application/json",
		"Accept":          "text/event-stream",
		"Cache-Control":   "no-cache",
		"Accept-Encoding": "identity",
	}
	maps.Copy(headers, authHeader)

	req.Header.SetMethod(http.MethodPost)
	req.SetRequestURI(requestURL)
	req.Header.SetContentType("application/json")
	providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)

	// Set headers (applied after the extra headers, so these take precedence)
	for k, v := range headers {
		req.Header.Set(k, v)
	}

	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	// Build request body
	jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
		ctx,
		request,
		func() (providerUtils.RequestBodyWithExtraParams, error) {
			reqBody := openai.ToOpenAISpeechRequest(request)
			if reqBody != nil {
				reqBody.StreamFormat = schemas.Ptr("sse")
			}
			return reqBody, nil
		})
	if bifrostErr != nil {
		// The response was acquired above but never handed to the client;
		// return it to the pool before bailing out.
		providerUtils.ReleaseStreamingResponse(resp)
		return nil, bifrostErr
	}

	if !providerUtils.ApplyLargePayloadRequestBodyWithModelNormalization(ctx, req, schemas.OpenAI) {
		req.SetBody(jsonBody)
	}

	// Make the request on the streaming client (no ReadTimeout on the body).
	requestErr := provider.streamingClient.Do(req, resp)
	if requestErr != nil {
		defer providerUtils.ReleaseStreamingResponse(resp)
		if errors.Is(requestErr, context.Canceled) {
			return nil, providerUtils.EnrichError(ctx, &schemas.BifrostError{
				IsBifrostError: false,
				Error: &schemas.ErrorField{
					Type:    schemas.Ptr(schemas.RequestCancelled),
					Message: schemas.ErrRequestCancelled,
					Error:   requestErr,
				},
			}, jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
		}
		if errors.Is(requestErr, fasthttp.ErrTimeout) || errors.Is(requestErr, context.DeadlineExceeded) {
			return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, requestErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
		}
		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, requestErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
	}

	// Extract provider response headers before status check so error responses also forward them
	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))

	// Check for HTTP errors
	if resp.StatusCode() != fasthttp.StatusOK {
		// Deferred so the response is still readable while the error is parsed.
		defer providerUtils.ReleaseStreamingResponse(resp)
		return nil, providerUtils.EnrichError(ctx, openai.ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
	}

	// Create response channel
	responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
	providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds)

	// Start streaming in a goroutine
	go func() {
		defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
		defer func() {
			ctxErr := ctx.Err()
			if errors.Is(ctxErr, context.Canceled) {
				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
			} else if errors.Is(ctxErr, context.DeadlineExceeded) {
				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
			}
			close(responseChan)
		}()
		// Always release response on exit; bodyStream close should prevent indefinite blocking.
		defer providerUtils.ReleaseStreamingResponse(resp)

		reader, releaseGzip := providerUtils.DecompressStreamBody(resp)
		defer releaseGzip()

		// Wrap reader with idle timeout to detect stalled streams (e.g., Azure TPM throttling
		// that stops sending data but keeps the TCP connection open).
		reader, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(reader, resp.BodyStream(), providerUtils.GetStreamIdleTimeout(ctx))
		defer stopIdleTimeout()

		// Setup cancellation handler to close the raw network stream on ctx cancellation,
		// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
		stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger)
		defer stopCancellation()

		chunkIndex := -1
		startTime := time.Now()
		lastChunkTime := startTime

		// Read SSE events manually instead of line-by-line with bufio.Scanner:
		// events are delimited by a blank line ("\r\n\r\n" or "\n\n"), and the
		// payload may contain single newline bytes that a line scanner would split on.
		readBuffer := make([]byte, 64*1024) // 64KB read chunks
		var accumulated []byte
		doneReceived := false

		for {
			// If context was cancelled/timed out, let defer handle it
			if ctx.Err() != nil {
				return
			}

			// Read from stream
			n, readErr := reader.Read(readBuffer)
			if n > 0 {
				accumulated = append(accumulated, readBuffer[:n]...)

				// Process complete SSE events (separated by \r\n\r\n or \n\n)
				for {
					// Find the next double-newline separator (try CRLF first, then LF)
					separatorLen := 0
					idx := bytes.Index(accumulated, []byte("\r\n\r\n"))
					if idx != -1 {
						separatorLen = 4 // \r\n\r\n
					} else {
						idx = bytes.Index(accumulated, []byte("\n\n"))
						if idx != -1 {
							separatorLen = 2 // \n\n
						}
					}
					if idx == -1 {
						// No complete event yet, need more data
						break
					}

					// Extract the event (everything up to the separator)
					event := accumulated[:idx]
					accumulated = accumulated[idx+separatorLen:] // Skip the separator

					// Skip empty events and comments
					if len(event) == 0 || bytes.HasPrefix(event, []byte(":")) {
						continue
					}

					// Parse the SSE event
					var audioData []byte
					if bytes.HasPrefix(event, []byte("data: ")) {
						audioData = event[6:] // Skip "data: " prefix
						// [DONE] marker: leave the event loop; the outer loop exits below.
						if bytes.Equal(audioData, []byte("[DONE]")) {
							doneReceived = true
							break
						}
					} else {
						// Data without the "data: " prefix is processed as-is.
						audioData = event
					}

					// Skip empty data
					if len(audioData) == 0 {
						continue
					}

					// Each event payload is expected to be a JSON speech stream response.
					var response schemas.BifrostSpeechStreamResponse
					if parseErr := sonic.Unmarshal(audioData, &response); parseErr != nil {
						// If JSON parsing fails, check whether this is an error payload
						// before giving up on the event.
						if errorNode, _ := sonic.Get(audioData, "error"); errorNode.Exists() {
							// Only unmarshal when we know there's an error
							var bifrostErr schemas.BifrostError
							if errParseErr := sonic.Unmarshal(audioData, &bifrostErr); errParseErr == nil {
								if bifrostErr.Error != nil && bifrostErr.Error.Message != "" {
									ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
									providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, provider.logger, postHookSpanFinalizer)
									return
								}
							}
						}
						// Not decodable and not an error payload: log and skip
						provider.logger.Warn("failed to parse speech stream response: %v", parseErr)
						continue
					}

					// Control events (done) and events without audio are not forwarded.
					if response.Type == schemas.SpeechStreamResponseTypeDone || len(response.Audio) == 0 {
						continue
					}

					chunkIndex++

					// Set extra fields for the response; latency is measured per chunk.
					response.ExtraFields = schemas.BifrostResponseExtraFields{
						ChunkIndex: chunkIndex,
						Latency:    time.Since(lastChunkTime).Milliseconds(),
					}
					lastChunkTime = time.Now()

					if sendBackRawResponse {
						response.ExtraFields.RawResponse = audioData
					}

					providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &response, nil, nil), responseChan, postHookSpanFinalizer)
				}

				// Check if we received [DONE] marker - break outer loop to send final response
				if doneReceived {
					break
				}
			}

			// Handle read errors
			if readErr != nil {
				// If context was cancelled/timed out, let defer handle it
				if ctx.Err() != nil {
					return
				}
				if readErr != io.EOF {
					// Non-EOF errors (e.g., connection reset by peer due to TPM throttling)
					// must be reported to the client instead of falling through to send
					// a fake "done" response with truncated audio.
					ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
					provider.logger.Warn("Error reading stream: %v", readErr)
					providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger, postHookSpanFinalizer)
					return
				}
				break
			}
		}

		// Send final "done" response only if we received the [DONE] marker from the provider.
		// Without [DONE], the stream ended abnormally (e.g., clean EOF without proper SSE termination).
		if chunkIndex >= 0 && doneReceived {
			finalResponse := schemas.BifrostSpeechStreamResponse{
				Type: schemas.SpeechStreamResponseTypeDone,
				ExtraFields: schemas.BifrostResponseExtraFields{
					ChunkIndex: chunkIndex + 1,
					Latency:    time.Since(startTime).Milliseconds(),
				},
			}

			if sendBackRawRequest {
				providerUtils.ParseAndSetRawRequest(&finalResponse.ExtraFields, jsonBody)
			}
			finalResponse.BackfillParams(request)
			ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
			providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &finalResponse, nil, nil), responseChan, postHookSpanFinalizer)
		} else if chunkIndex >= 0 && !doneReceived {
			provider.logger.Warn("Stream ended without receiving [DONE] marker after %d chunks", chunkIndex+1)
		}
		// Response is released via deferred ReleaseStreamingResponse(resp) above.
	}()

	return responseChan, nil
}
// Transcription performs a non-streaming audio transcription request against
// the Azure deployment named by request.Model, delegating to the shared OpenAI
// handler.
//
// An unset or empty API version on the key falls back to AzureAPIVersionDefault.
// An empty endpoint yields a configuration error before any request is made,
// consistent with the other Azure deployment endpoints in this file.
func (provider *AzureProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
	apiVersion := key.AzureKeyConfig.APIVersion
	if apiVersion == nil || apiVersion.GetValue() == "" {
		apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
	}

	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
	if endpoint == "" {
		return nil, providerUtils.NewConfigurationError("endpoint not set")
	}

	requestURL := fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", endpoint, request.Model, apiVersion.GetValue())
	response, err := openai.HandleOpenAITranscriptionRequest(
		ctx,
		provider.client,
		requestURL,
		request,
		key,
		provider.networkConfig.ExtraHeaders,
		provider.GetProviderKey(),
		providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
		nil,
		provider.logger,
	)
	if err != nil {
		// Never surface a partially populated response alongside an error.
		return nil, err
	}
	return response, nil
}
// TranscriptionStream is not implemented for Azure; it always returns an
// unsupported-operation error tagged with this provider's key and no channel.
func (provider *AzureProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ImageGeneration sends a non-streaming image generation request to the Azure
// deployment named by request.Model through the shared OpenAI handler.
// An unset or empty API version on the key falls back to AzureAPIVersionDefault;
// an empty endpoint yields a configuration error before any request is made.
func (provider *AzureProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key,
	request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	version := key.AzureKeyConfig.APIVersion
	if version == nil || version.GetValue() == "" {
		version = schemas.NewEnvVar(AzureAPIVersionDefault)
	}

	baseEndpoint := key.AzureKeyConfig.Endpoint.GetValue()
	if baseEndpoint == "" {
		return nil, providerUtils.NewConfigurationError("endpoint not set")
	}

	requestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseEndpoint, request.Model, version.GetValue())
	result, genErr := openai.HandleOpenAIImageGenerationRequest(
		ctx,
		provider.client,
		requestURL,
		request,
		key,
		provider.networkConfig.ExtraHeaders,
		provider.GetProviderKey(),
		providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
		providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
		provider.logger,
	)
	if genErr != nil {
		return nil, genErr
	}
	return result, nil
}
// ImageGenerationStream starts a streaming image generation request against the
// Azure deployment named by request.Model and returns the chunk channel produced
// by the shared OpenAI streaming handler.
// An unset or empty API version falls back to AzureAPIVersionDefault; an empty
// endpoint or a failure to obtain auth headers is returned as an error.
func (provider *AzureProvider) ImageGenerationStream(
	ctx *schemas.BifrostContext,
	postHookRunner schemas.PostHookRunner,
	postHookSpanFinalizer func(context.Context),
	key schemas.Key,
	request *schemas.BifrostImageGenerationRequest,
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	version := key.AzureKeyConfig.APIVersion
	if version == nil || version.GetValue() == "" {
		version = schemas.NewEnvVar(AzureAPIVersionDefault)
	}

	baseEndpoint := key.AzureKeyConfig.Endpoint.GetValue()
	if baseEndpoint == "" {
		return nil, providerUtils.NewConfigurationError("endpoint not set")
	}
	requestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseEndpoint, request.Model, version.GetValue())

	authHeaders, authErr := provider.getAzureAuthHeaders(ctx, key, false)
	if authErr != nil {
		return nil, authErr
	}

	// The Azure endpoint is handled by the OpenAI streaming implementation.
	return openai.HandleOpenAIImageGenerationStreaming(
		ctx,
		provider.streamingClient,
		requestURL,
		request,
		authHeaders,
		provider.networkConfig.ExtraHeaders,
		providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
		providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
		provider.GetProviderKey(),
		postHookRunner,
		nil,
		nil,
		nil,
		provider.logger,
		postHookSpanFinalizer,
	)
}
// ImageEdit sends a non-streaming image edit request to the Azure deployment
// named by request.Model through the shared OpenAI handler.
// An unset or empty API version falls back to AzureAPIVersionImageEditDefault;
// an empty endpoint yields a configuration error before any request is made.
func (provider *AzureProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	version := key.AzureKeyConfig.APIVersion
	if version == nil || version.GetValue() == "" {
		version = schemas.NewEnvVar(AzureAPIVersionImageEditDefault)
	}

	baseEndpoint := key.AzureKeyConfig.Endpoint.GetValue()
	if baseEndpoint == "" {
		return nil, providerUtils.NewConfigurationError("endpoint not set")
	}

	requestURL := fmt.Sprintf("%s/openai/deployments/%s/images/edits?api-version=%s", baseEndpoint, request.Model, version.GetValue())
	result, editErr := openai.HandleOpenAIImageEditRequest(
		ctx,
		provider.client,
		requestURL,
		request,
		key,
		provider.networkConfig.ExtraHeaders,
		false,
		providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
		provider.GetProviderKey(),
		provider.logger,
	)
	if editErr != nil {
		return nil, editErr
	}
	return result, nil
}
// ImageEditStream starts a streaming image edit request against the Azure
// deployment named by request.Model and returns the chunk channel produced by
// the shared OpenAI streaming handler.
// An unset or empty API version falls back to AzureAPIVersionImageEditDefault;
// an empty endpoint or a failure to obtain auth headers is returned as an error.
func (provider *AzureProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	version := key.AzureKeyConfig.APIVersion
	if version == nil || version.GetValue() == "" {
		version = schemas.NewEnvVar(AzureAPIVersionImageEditDefault)
	}

	baseEndpoint := key.AzureKeyConfig.Endpoint.GetValue()
	if baseEndpoint == "" {
		return nil, providerUtils.NewConfigurationError("endpoint not set")
	}
	requestURL := fmt.Sprintf("%s/openai/deployments/%s/images/edits?api-version=%s", baseEndpoint, request.Model, version.GetValue())

	authHeaders, authErr := provider.getAzureAuthHeaders(ctx, key, false)
	if authErr != nil {
		return nil, authErr
	}

	// The Azure endpoint is handled by the OpenAI streaming implementation.
	return openai.HandleOpenAIImageEditStreamRequest(
		ctx,
		provider.streamingClient,
		requestURL,
		request,
		authHeaders,
		provider.networkConfig.ExtraHeaders,
		false,
		providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
		provider.GetProviderKey(),
		postHookRunner,
		nil,
		nil,
		nil,
		provider.logger,
		postHookSpanFinalizer,
	)
}
// ImageVariation is not implemented for Azure; it always returns an
// unsupported-operation error tagged with this provider's key.
func (provider *AzureProvider) ImageVariation(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ImageVariationRequest, provider.GetProviderKey())
	return nil, unsupported
}
// VideoGeneration submits a video generation request to "<endpoint>/openai/v1/videos"
// by delegating to the shared OpenAI handler with this provider's client and key.
// An empty endpoint yields a configuration error before any request is made.
func (provider *AzureProvider) VideoGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	baseEndpoint := key.AzureKeyConfig.Endpoint.GetValue()
	if baseEndpoint == "" {
		return nil, providerUtils.NewConfigurationError("endpoint not set")
	}

	result, genErr := openai.HandleOpenAIVideoGenerationRequest(
		ctx,
		provider.client,
		fmt.Sprintf("%s/openai/v1/videos", baseEndpoint),
		request,
		key,
		provider.networkConfig.ExtraHeaders,
		provider.GetProviderKey(),
		providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
		providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
		provider.logger,
	)
	if genErr != nil {
		return nil, genErr
	}
	return result, nil
}
// VideoRetrieve fetches a video's status from "<endpoint>/openai/v1/videos/<id>"
// through the shared OpenAI handler, passing provider.VideoDownload as the
// download callback.
// request.ID is required; its provider suffix is stripped before building the
// URL. An empty endpoint or an auth-header failure is returned as an error.
func (provider *AzureProvider) VideoRetrieve(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	name := provider.GetProviderKey()
	if request.ID == "" {
		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
	}
	rawID := providerUtils.StripVideoIDProviderSuffix(request.ID, name)

	baseEndpoint := key.AzureKeyConfig.Endpoint.GetValue()
	if baseEndpoint == "" {
		return nil, providerUtils.NewConfigurationError("endpoint not set")
	}

	headers, authErr := provider.getAzureAuthHeaders(ctx, key, false)
	if authErr != nil {
		return nil, authErr
	}

	requestURL := fmt.Sprintf("%s/openai/v1/videos/%s", baseEndpoint, rawID)
	return openai.HandleOpenAIVideoRetrieveRequest(
		ctx,
		provider.client,
		requestURL,
		request,
		key,
		provider.networkConfig.ExtraHeaders,
		headers,
		name,
		providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
		providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
		provider.VideoDownload,
		provider.logger,
	)
}
// VideoDownload fetches the binary content of a video with a GET to
// "<endpoint>/openai/v1/videos/<id>/content".
//
// request.ID is required; its provider suffix is stripped for the URL while the
// original ID is echoed back in the response. The returned content is a copy of
// the decoded body, and the content type defaults to "video/mp4" when the
// response does not declare one. Any status other than 200 is parsed as an
// OpenAI-style error.
func (provider *AzureProvider) VideoDownload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
	name := provider.GetProviderKey()
	if request.ID == "" {
		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
	}
	rawID := providerUtils.StripVideoIDProviderSuffix(request.ID, name)

	baseEndpoint := key.AzureKeyConfig.Endpoint.GetValue()
	if baseEndpoint == "" {
		return nil, providerUtils.NewConfigurationError("endpoint not set")
	}

	// Pooled request/response objects are returned when the function exits.
	httpReq := fasthttp.AcquireRequest()
	httpResp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(httpReq)
	defer fasthttp.ReleaseResponse(httpResp)

	// Extra headers first, then target and method.
	providerUtils.SetExtraHeaders(ctx, httpReq, provider.networkConfig.ExtraHeaders, nil)
	httpReq.SetRequestURI(fmt.Sprintf("%s/openai/v1/videos/%s/content", baseEndpoint, rawID))
	httpReq.Header.SetMethod(http.MethodGet)

	// Authentication headers are applied last.
	authHeaders, authErr := provider.getAzureAuthHeaders(ctx, key, false)
	if authErr != nil {
		return nil, authErr
	}
	for headerName, headerValue := range authHeaders {
		httpReq.Header.Set(headerName, headerValue)
	}

	elapsed, reqErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, httpReq, httpResp)
	// Registered after the release defers, so wait runs before the pooled objects are released.
	defer wait()
	if reqErr != nil {
		return nil, reqErr
	}

	if httpResp.StatusCode() != fasthttp.StatusOK {
		return nil, openai.ParseOpenAIError(httpResp)
	}

	payload, decodeErr := providerUtils.CheckAndDecodeBody(httpResp)
	if decodeErr != nil {
		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr)
	}

	mimeType := string(httpResp.Header.ContentType())
	if mimeType == "" {
		// Fall back to video/mp4 when the response does not specify a type.
		mimeType = "video/mp4"
	}

	// Copy the body: httpResp is released when this function returns.
	content := append([]byte(nil), payload...)
	return &schemas.BifrostVideoDownloadResponse{
		VideoID:     request.ID,
		Content:     content,
		ContentType: mimeType,
		ExtraFields: schemas.BifrostResponseExtraFields{
			Latency: elapsed.Milliseconds(),
		},
	}, nil
}
// VideoDelete removes a video via "<endpoint>/openai/v1/videos/<id>" by
// delegating to the shared OpenAI handler.
// request.ID is required; its provider suffix is stripped before use. An empty
// endpoint yields a configuration error before any request is made.
func (provider *AzureProvider) VideoDelete(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
	name := provider.GetProviderKey()
	if request.ID == "" {
		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
	}
	rawID := providerUtils.StripVideoIDProviderSuffix(request.ID, name)

	baseEndpoint := key.AzureKeyConfig.Endpoint.GetValue()
	if baseEndpoint == "" {
		return nil, providerUtils.NewConfigurationError("endpoint not set")
	}

	result, delErr := openai.HandleOpenAIVideoDeleteRequest(
		ctx,
		provider.client,
		fmt.Sprintf("%s/openai/v1/videos/%s", baseEndpoint, rawID),
		rawID,
		key,
		provider.networkConfig.ExtraHeaders,
		name,
		providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
		providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
		provider.logger,
	)
	if delErr != nil {
		return nil, delErr
	}
	return result, nil
}
// VideoList lists videos via "<endpoint>/openai/v1/videos" by delegating to the
// shared OpenAI handler with this provider's client and key.
// An empty endpoint yields a configuration error before any request is made.
func (provider *AzureProvider) VideoList(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
	baseEndpoint := key.AzureKeyConfig.Endpoint.GetValue()
	if baseEndpoint == "" {
		return nil, providerUtils.NewConfigurationError("endpoint not set")
	}

	listURL := fmt.Sprintf("%s/openai/v1/videos", baseEndpoint)
	result, listErr := openai.HandleOpenAIVideoListRequest(
		ctx,
		provider.client,
		listURL,
		request,
		key,
		provider.networkConfig.ExtraHeaders,
		provider.GetProviderKey(),
		providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
		providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
		provider.logger,
	)
	if listErr != nil {
		return nil, listErr
	}
	return result, nil
}
// VideoRemix is not implemented for Azure; it ignores its arguments and always
// returns an unsupported-operation error tagged with this provider's key.
func (provider *AzureProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, provider.GetProviderKey())
	return nil, unsupported
}
// FileUpload uploads a file to Azure OpenAI as multipart form data posted to
// "<endpoint>/openai/files".
//
// request.File and request.Purpose are required. The form carries a "purpose"
// field and a "file" part whose filename defaults to "file.jsonl" when the
// request does not name one. An unset or empty API version on the key falls
// back to AzureAPIVersionDefault, and an empty endpoint yields a configuration
// error before the multipart body is built. Status 200 and 201 are both
// treated as success; anything else is parsed as an OpenAI-style error.
func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
	if len(request.File) == 0 {
		return nil, providerUtils.NewBifrostOperationError("file content is required", nil)
	}
	if request.Purpose == "" {
		return nil, providerUtils.NewBifrostOperationError("purpose is required", nil)
	}

	// Fail fast on a missing endpoint instead of sending a request to a relative
	// URL, consistent with the other Azure endpoints in this file.
	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
	if endpoint == "" {
		return nil, providerUtils.NewConfigurationError("endpoint not set")
	}

	// Get API version, falling back to the default when unset or empty.
	apiVersion := key.AzureKeyConfig.APIVersion
	if apiVersion == nil || apiVersion.GetValue() == "" {
		apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
	}

	// Create multipart form data
	var buf bytes.Buffer
	writer := multipart.NewWriter(&buf)

	// Add purpose field
	if err := writer.WriteField("purpose", string(request.Purpose)); err != nil {
		return nil, providerUtils.NewBifrostOperationError("failed to write purpose field", err)
	}

	// Add file field
	filename := request.Filename
	if filename == "" {
		filename = "file.jsonl"
	}
	part, err := writer.CreateFormFile("file", filename)
	if err != nil {
		return nil, providerUtils.NewBifrostOperationError("failed to create form file", err)
	}
	if _, err := part.Write(request.File); err != nil {
		return nil, providerUtils.NewBifrostOperationError("failed to write file content", err)
	}
	// Close writes the trailing boundary; the body is incomplete without it.
	if err := writer.Close(); err != nil {
		return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err)
	}

	// Create request
	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	defer fasthttp.ReleaseResponse(resp)

	// Build URL with query params
	baseURL := fmt.Sprintf("%s/openai/files", endpoint)
	values := url.Values{}
	values.Set("api-version", apiVersion.GetValue())
	requestURL := baseURL + "?" + values.Encode()

	// Set headers
	providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
	req.SetRequestURI(requestURL)
	req.Header.SetMethod(http.MethodPost)
	req.Header.SetContentType(writer.FormDataContentType())

	// Set Azure authentication
	if err := provider.setAzureAuth(ctx, req, key); err != nil {
		return nil, err
	}

	req.SetBody(buf.Bytes())

	// Make request
	latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
	// Registered after the release defers, so wait runs before the pooled objects are released.
	defer wait()
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	// Handle error response
	if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated {
		return nil, openai.ParseOpenAIError(resp)
	}

	body, err := providerUtils.CheckAndDecodeBody(resp)
	if err != nil {
		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
	}

	var openAIResp openai.OpenAIFileResponse
	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
	// No request body is handed to the response handler here (nil).
	rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &openAIResp, nil, sendBackRawRequest, sendBackRawResponse)
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	return openAIResp.ToBifrostFileUploadResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil
}
// FileList returns one page of files using serial pagination across keys.
//
// The serial list helper decodes request.After, selects the key to query and the
// provider-native cursor for it; a single GET to "<endpoint>/openai/files" is
// then made with that key. When the helper reports that every key is exhausted,
// an empty list with HasMore=false is returned without any request. The cursor
// for the next call is built by the helper from the provider's has_more flag and
// the ID of the last file on this page.
func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
	if len(keys) == 0 {
		return nil, providerUtils.NewConfigurationError("no Azure keys available for file list operation")
	}

	wantRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	wantRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	// Set up serial pagination from the caller's cursor.
	pager, cursorErr := providerUtils.NewSerialListHelper(keys, request.After, provider.logger)
	if cursorErr != nil {
		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", cursorErr)
	}

	// Pick the key to query on this call.
	activeKey, providerCursor, hasKey := pager.GetCurrentKey()
	if !hasKey {
		// Every key has been fully paged through.
		return &schemas.BifrostFileListResponse{
			Object:  "list",
			Data:    []schemas.FileObject{},
			HasMore: false,
		}, nil
	}

	version := activeKey.AzureKeyConfig.APIVersion
	if version == nil {
		version = schemas.NewEnvVar(AzureAPIVersionDefault)
	}

	httpReq := fasthttp.AcquireRequest()
	httpResp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(httpReq)
	defer fasthttp.ReleaseResponse(httpResp)

	// Assemble the query string; "api-version" is always present, so the
	// encoded query is never empty.
	query := url.Values{}
	query.Set("api-version", version.GetValue())
	if request.Purpose != "" {
		query.Set("purpose", string(request.Purpose))
	}
	if providerCursor != "" {
		// Provider-native cursor supplied by the serial helper.
		query.Set("after", providerCursor)
	}
	listURL := fmt.Sprintf("%s/openai/files", activeKey.AzureKeyConfig.Endpoint.GetValue()) + "?" + query.Encode()

	providerUtils.SetExtraHeaders(ctx, httpReq, provider.networkConfig.ExtraHeaders, nil)
	httpReq.SetRequestURI(listURL)
	httpReq.Header.SetMethod(http.MethodGet)
	httpReq.Header.SetContentType("application/json")

	if authErr := provider.setAzureAuth(ctx, httpReq, activeKey); authErr != nil {
		return nil, authErr
	}

	elapsed, reqErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, httpReq, httpResp)
	// Registered after the release defers, so wait runs before the pooled objects are released.
	defer wait()
	if reqErr != nil {
		return nil, reqErr
	}

	if httpResp.StatusCode() != fasthttp.StatusOK {
		return nil, openai.ParseOpenAIError(httpResp)
	}

	payload, decodeErr := providerUtils.CheckAndDecodeBody(httpResp)
	if decodeErr != nil {
		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr)
	}

	var listing openai.OpenAIFileListResponse
	// The raw request/response values returned by the handler are discarded here.
	if _, _, parseErr := providerUtils.HandleProviderResponse(payload, &listing, nil, wantRawRequest, wantRawResponse); parseErr != nil {
		return nil, parseErr
	}

	// Map provider file entries onto Bifrost file objects.
	converted := make([]schemas.FileObject, 0, len(listing.Data))
	for i := range listing.Data {
		entry := listing.Data[i]
		converted = append(converted, schemas.FileObject{
			ID:            entry.ID,
			Object:        entry.Object,
			Bytes:         entry.Bytes,
			CreatedAt:     entry.CreatedAt,
			Filename:      entry.Filename,
			Purpose:       schemas.FilePurpose(entry.Purpose),
			Status:        openai.ToBifrostFileStatus(entry.Status),
			StatusDetails: entry.StatusDetails,
		})
	}

	// The last file's ID (empty for an empty page) seeds the next cursor.
	var tailID string
	if count := len(listing.Data); count > 0 {
		tailID = listing.Data[count-1].ID
	}
	next, more := pager.BuildNextCursor(listing.HasMore, tailID)

	page := &schemas.BifrostFileListResponse{
		Object:  "list",
		Data:    converted,
		HasMore: more,
		ExtraFields: schemas.BifrostResponseExtraFields{
			Latency: elapsed.Milliseconds(),
		},
	}
	if next != "" {
		page.After = &next
	}
	return page, nil
}
// FileRetrieve retrieves file metadata from Azure OpenAI by trying each key until found.
//
// Keys are tried in order against GET {endpoint}/openai/files/{file_id}. The first
// key that yields a 200 response whose body decodes successfully wins. When every
// key fails, the error produced by the last attempted key is returned.
//
// An empty request.FileID yields an operation error and an empty keys slice yields
// a configuration error, so the function never returns a nil response together
// with a nil error.
func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
	providerName := provider.GetProviderKey()
	if request.FileID == "" {
		return nil, providerUtils.NewBifrostOperationError("file_id is required", nil)
	}
	// Without this guard the loop below would be skipped entirely and the
	// caller would receive (nil, nil).
	if len(keys) == 0 {
		return nil, providerUtils.NewConfigurationError("no Azure keys available for file retrieve operation")
	}
	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	// retrieveWithKey performs a single attempt with one key. It is a closure so
	// that the deferred cleanup runs at the end of each attempt rather than at
	// the end of FileRetrieve.
	retrieveWithKey := func(key schemas.Key) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
		// Fall back to the default API version when the key does not pin one.
		apiVersion := key.AzureKeyConfig.APIVersion
		if apiVersion == nil {
			apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
		}

		req := fasthttp.AcquireRequest()
		resp := fasthttp.AcquireResponse()
		// Deferred in this order so that, together with the deferred wait below,
		// cleanup runs as: wait, release request, release response.
		defer fasthttp.ReleaseResponse(resp)
		defer fasthttp.ReleaseRequest(req)

		// Build URL with query params
		baseURL := fmt.Sprintf("%s/openai/files/%s", key.AzureKeyConfig.Endpoint.GetValue(), url.PathEscape(request.FileID))
		values := url.Values{}
		values.Set("api-version", apiVersion.GetValue())
		requestURL := baseURL + "?" + values.Encode()

		// Set headers
		providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
		req.SetRequestURI(requestURL)
		req.Header.SetMethod(http.MethodGet)
		req.Header.SetContentType("application/json")

		// Set Azure authentication
		if authErr := provider.setAzureAuth(ctx, req, key); authErr != nil {
			return nil, authErr
		}

		// Make request
		latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
		defer wait()
		if bifrostErr != nil {
			return nil, bifrostErr
		}

		// Handle error response
		if resp.StatusCode() != fasthttp.StatusOK {
			return nil, openai.ParseOpenAIError(resp)
		}

		body, err := providerUtils.CheckAndDecodeBody(resp)
		if err != nil {
			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
		}

		var openAIResp openai.OpenAIFileResponse
		rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &openAIResp, nil, sendBackRawRequest, sendBackRawResponse)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		return openAIResp.ToBifrostFileRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil
	}

	var lastErr *schemas.BifrostError
	for _, key := range keys {
		result, attemptErr := retrieveWithKey(key)
		if attemptErr == nil {
			return result, nil
		}
		lastErr = attemptErr
	}
	return nil, lastErr
}
// FileDelete removes a file from Azure OpenAI, attempting each key in turn until one succeeds.
//
// A 204 response is reported as a successful deletion of request.FileID. A 200
// response is decoded and its ID, Object and Deleted fields are returned. Any other
// outcome is remembered and the next key is tried; if no key succeeds, the error
// from the last attempt is returned.
func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
	if request.FileID == "" {
		return nil, providerUtils.NewBifrostOperationError("file_id is required", nil)
	}
	if len(keys) == 0 {
		return nil, providerUtils.NewConfigurationError("no Azure keys available for file delete operation")
	}

	wantRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	wantRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	var finalErr *schemas.BifrostError
	for _, azureKey := range keys {
		// Use the default API version when the key does not specify one.
		version := azureKey.AzureKeyConfig.APIVersion
		if version == nil {
			version = schemas.NewEnvVar(AzureAPIVersionDefault)
		}

		httpReq := fasthttp.AcquireRequest()
		httpResp := fasthttp.AcquireResponse()
		// releasePooled hands both pooled objects back; every exit from this
		// iteration must go through it.
		releasePooled := func() {
			fasthttp.ReleaseRequest(httpReq)
			fasthttp.ReleaseResponse(httpResp)
		}

		// Assemble the request URL including the api-version query parameter.
		target := fmt.Sprintf("%s/openai/files/%s", azureKey.AzureKeyConfig.Endpoint.GetValue(), url.PathEscape(request.FileID))
		query := url.Values{}
		query.Set("api-version", version.GetValue())
		target += "?" + query.Encode()

		providerUtils.SetExtraHeaders(ctx, httpReq, provider.networkConfig.ExtraHeaders, nil)
		httpReq.SetRequestURI(target)
		httpReq.Header.SetMethod(http.MethodDelete)
		httpReq.Header.SetContentType("application/json")

		if authErr := provider.setAzureAuth(ctx, httpReq, azureKey); authErr != nil {
			releasePooled()
			finalErr = authErr
			continue
		}

		latency, reqErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, httpReq, httpResp)
		// finish invokes the wait callback returned by MakeRequestWithContext
		// and only then releases the pooled objects.
		finish := func() {
			wait()
			releasePooled()
		}
		if reqErr != nil {
			finish()
			finalErr = reqErr
			continue
		}

		switch httpResp.StatusCode() {
		case fasthttp.StatusNoContent:
			// No body to decode: report the deletion using the requested ID.
			finish()
			return &schemas.BifrostFileDeleteResponse{
				ID:      request.FileID,
				Object:  "file",
				Deleted: true,
				ExtraFields: schemas.BifrostResponseExtraFields{
					Latency: latency.Milliseconds(),
				},
			}, nil
		case fasthttp.StatusOK:
			// Decoded below.
		default:
			// The error is parsed before the response is released.
			finalErr = openai.ParseOpenAIError(httpResp)
			finish()
			continue
		}

		body, decodeErr := providerUtils.CheckAndDecodeBody(httpResp)
		if decodeErr != nil {
			finish()
			finalErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr)
			continue
		}

		var deleted openai.OpenAIFileDeleteResponse
		rawRequest, rawResponse, parseErr := providerUtils.HandleProviderResponse(body, &deleted, nil, wantRawRequest, wantRawResponse)
		if parseErr != nil {
			finish()
			finalErr = parseErr
			continue
		}
		finish()

		out := &schemas.BifrostFileDeleteResponse{
			ID:      deleted.ID,
			Object:  deleted.Object,
			Deleted: deleted.Deleted,
			ExtraFields: schemas.BifrostResponseExtraFields{
				Latency: latency.Milliseconds(),
			},
		}
		if wantRawRequest {
			out.ExtraFields.RawRequest = rawRequest
		}
		if wantRawResponse {
			out.ExtraFields.RawResponse = rawResponse
		}
		return out, nil
	}
	return nil, finalErr
}
// FileContent fetches the raw content of a file from Azure OpenAI, attempting each key in order.
//
// The first key that returns a 200 response with a decodable body wins. The
// returned ContentType comes from the response header and falls back to
// "application/octet-stream" when the header is empty. If every key fails, the
// error from the last attempt is returned.
func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
	if request.FileID == "" {
		return nil, providerUtils.NewBifrostOperationError("file_id is required", nil)
	}
	if len(keys) == 0 {
		return nil, providerUtils.NewConfigurationError("no Azure keys available for file content operation")
	}

	var finalErr *schemas.BifrostError
	for _, azureKey := range keys {
		// Use the default API version when the key does not specify one.
		version := azureKey.AzureKeyConfig.APIVersion
		if version == nil {
			version = schemas.NewEnvVar(AzureAPIVersionDefault)
		}

		httpReq := fasthttp.AcquireRequest()
		httpResp := fasthttp.AcquireResponse()
		// releasePooled hands both pooled objects back; every exit from this
		// iteration must go through it.
		releasePooled := func() {
			fasthttp.ReleaseRequest(httpReq)
			fasthttp.ReleaseResponse(httpResp)
		}

		// Assemble the request URL including the api-version query parameter.
		target := fmt.Sprintf("%s/openai/files/%s/content", azureKey.AzureKeyConfig.Endpoint.GetValue(), url.PathEscape(request.FileID))
		query := url.Values{}
		query.Set("api-version", version.GetValue())
		target += "?" + query.Encode()

		providerUtils.SetExtraHeaders(ctx, httpReq, provider.networkConfig.ExtraHeaders, nil)
		httpReq.SetRequestURI(target)
		httpReq.Header.SetMethod(http.MethodGet)

		if authErr := provider.setAzureAuth(ctx, httpReq, azureKey); authErr != nil {
			releasePooled()
			finalErr = authErr
			continue
		}

		latency, reqErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, httpReq, httpResp)
		// finish invokes the wait callback returned by MakeRequestWithContext
		// and only then releases the pooled objects.
		finish := func() {
			wait()
			releasePooled()
		}
		if reqErr != nil {
			finish()
			finalErr = reqErr
			continue
		}

		if httpResp.StatusCode() != fasthttp.StatusOK {
			// The error is parsed before the response is released.
			finalErr = openai.ParseOpenAIError(httpResp)
			finish()
			continue
		}

		body, decodeErr := providerUtils.CheckAndDecodeBody(httpResp)
		if decodeErr != nil {
			finish()
			finalErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr)
			continue
		}

		mediaType := string(httpResp.Header.ContentType())
		if mediaType == "" {
			mediaType = "application/octet-stream"
		}

		// Copy the bytes before the response is released; presumably body can
		// alias memory owned by the pooled response (NOTE(review): confirm).
		payload := append([]byte(nil), body...)
		finish()

		return &schemas.BifrostFileContentResponse{
			FileID:      request.FileID,
			Content:     payload,
			ContentType: mediaType,
			ExtraFields: schemas.BifrostResponseExtraFields{
				Latency: latency.Milliseconds(),
			},
		}, nil
	}
	return nil, finalErr
}
// BatchCreate submits a new batch job to Azure OpenAI via POST {endpoint}/openai/batches.
//
// When request.InputFileID is empty but inline requests are present, the requests
// are converted to JSONL and uploaded through FileUpload with purpose "batch"; the
// uploaded file's ID is then used as the batch input. If neither is available an
// operation error is returned. The completion window defaults to "24h" when unset.
// Both 200 and 201 responses are treated as success.
func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
	fileID := request.InputFileID

	// Inline requests without a file ID are uploaded first.
	if fileID == "" && len(request.Requests) > 0 {
		jsonl, convErr := openai.ConvertRequestsToJSONL(request.Requests)
		if convErr != nil {
			return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", convErr)
		}
		uploaded, uploadErr := provider.FileUpload(ctx, key, &schemas.BifrostFileUploadRequest{
			File:     jsonl,
			Filename: "batch_requests.jsonl",
			Purpose:  "batch",
		})
		if uploadErr != nil {
			return nil, uploadErr
		}
		fileID = uploaded.ID
	}

	// A file ID must exist by now, either supplied or freshly uploaded.
	if fileID == "" {
		return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests array is required for Azure batch API", nil)
	}

	// Use the default API version when the key does not specify one.
	version := key.AzureKeyConfig.APIVersion
	if version == nil {
		version = schemas.NewEnvVar(AzureAPIVersionDefault)
	}

	httpReq := fasthttp.AcquireRequest()
	httpResp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(httpReq)
	defer fasthttp.ReleaseResponse(httpResp)

	// Assemble the request URL including the api-version query parameter.
	target := fmt.Sprintf("%s/openai/batches", key.AzureKeyConfig.Endpoint.GetValue())
	query := url.Values{}
	query.Set("api-version", version.GetValue())
	target += "?" + query.Encode()

	providerUtils.SetExtraHeaders(ctx, httpReq, provider.networkConfig.ExtraHeaders, nil)
	httpReq.SetRequestURI(target)
	httpReq.Header.SetMethod(http.MethodPost)
	httpReq.Header.SetContentType("application/json")

	if authErr := provider.setAzureAuth(ctx, httpReq, key); authErr != nil {
		return nil, authErr
	}

	// Resolve the completion window up front, applying the default.
	window := request.CompletionWindow
	if window == "" {
		window = "24h"
	}
	payload := &openai.OpenAIBatchRequest{
		InputFileID:      fileID,
		Endpoint:         string(request.Endpoint),
		CompletionWindow: window,
		Metadata:         request.Metadata,
	}
	jsonData, marshalErr := providerUtils.MarshalSorted(payload)
	if marshalErr != nil {
		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, marshalErr)
	}
	httpReq.SetBody(jsonData)

	latency, reqErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, httpReq, httpResp)
	defer wait()
	if reqErr != nil {
		return nil, providerUtils.EnrichError(ctx, reqErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	if status := httpResp.StatusCode(); status != fasthttp.StatusOK && status != fasthttp.StatusCreated {
		return nil, openai.ParseOpenAIError(httpResp)
	}

	body, decodeErr := providerUtils.CheckAndDecodeBody(httpResp)
	if decodeErr != nil {
		wrapped := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr)
		return nil, providerUtils.EnrichError(ctx, wrapped, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	var created openai.OpenAIBatchResponse
	wantRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	wantRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
	rawRequest, rawResponse, parseErr := providerUtils.HandleProviderResponse(body, &created, jsonData, wantRawRequest, wantRawResponse)
	if parseErr != nil {
		return nil, providerUtils.EnrichError(ctx, parseErr, jsonData, body, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	return created.ToBifrostBatchCreateResponse(latency, wantRawRequest, wantRawResponse, rawRequest, rawResponse), nil
}
// BatchList lists batch jobs using serial pagination across keys.
//
// Each call queries only the key selected by the serial list helper from
// request.After, forwarding the helper's native cursor as the "after" query
// parameter. The helper then builds the cursor for the following call from the
// provider's has_more flag and the ID of the last batch on this page. When the
// helper reports no current key, an empty list with HasMore=false is returned.
func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
	wantRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	wantRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	if len(keys) == 0 {
		return nil, providerUtils.NewConfigurationError("no Azure keys available for batch list operation")
	}

	pager, pagerErr := providerUtils.NewSerialListHelper(keys, request.After, provider.logger)
	if pagerErr != nil {
		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", pagerErr)
	}

	azureKey, cursor, hasKey := pager.GetCurrentKey()
	if !hasKey {
		// Nothing left to query.
		return &schemas.BifrostBatchListResponse{
			Object:  "list",
			Data:    []schemas.BifrostBatchRetrieveResponse{},
			HasMore: false,
		}, nil
	}

	// Use the default API version when the key does not specify one.
	version := azureKey.AzureKeyConfig.APIVersion
	if version == nil {
		version = schemas.NewEnvVar(AzureAPIVersionDefault)
	}

	httpReq := fasthttp.AcquireRequest()
	httpResp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(httpReq)
	defer fasthttp.ReleaseResponse(httpResp)

	// Assemble the request URL with api-version plus optional limit/after.
	target := fmt.Sprintf("%s/openai/batches", azureKey.AzureKeyConfig.Endpoint.GetValue())
	query := url.Values{}
	query.Set("api-version", version.GetValue())
	if request.Limit > 0 {
		query.Set("limit", fmt.Sprintf("%d", request.Limit))
	}
	if cursor != "" {
		query.Set("after", cursor)
	}
	target += "?" + query.Encode()

	providerUtils.SetExtraHeaders(ctx, httpReq, provider.networkConfig.ExtraHeaders, nil)
	httpReq.SetRequestURI(target)
	httpReq.Header.SetMethod(http.MethodGet)
	httpReq.Header.SetContentType("application/json")

	if authErr := provider.setAzureAuth(ctx, httpReq, azureKey); authErr != nil {
		return nil, authErr
	}

	latency, reqErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, httpReq, httpResp)
	defer wait()
	if reqErr != nil {
		return nil, reqErr
	}

	if httpResp.StatusCode() != fasthttp.StatusOK {
		return nil, openai.ParseOpenAIError(httpResp)
	}

	body, decodeErr := providerUtils.CheckAndDecodeBody(httpResp)
	if decodeErr != nil {
		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr)
	}

	var page openai.OpenAIBatchListResponse
	rawRequest, rawResponse, parseErr := providerUtils.HandleProviderResponse(body, &page, nil, wantRawRequest, wantRawResponse)
	if parseErr != nil {
		return nil, parseErr
	}

	// Convert every batch on the page to the Bifrost representation.
	converted := make([]schemas.BifrostBatchRetrieveResponse, 0, len(page.Data))
	for _, entry := range page.Data {
		converted = append(converted, *entry.ToBifrostBatchRetrieveResponse(latency, wantRawRequest, wantRawResponse, rawRequest, rawResponse))
	}

	// The ID of the final batch on the page seeds the next cursor; it stays
	// empty when the page has no entries.
	var lastID string
	if count := len(page.Data); count > 0 {
		lastID = page.Data[count-1].ID
	}
	nextCursor, more := pager.BuildNextCursor(page.HasMore, lastID)

	listResp := &schemas.BifrostBatchListResponse{
		Object:  "list",
		Data:    converted,
		HasMore: more,
		ExtraFields: schemas.BifrostResponseExtraFields{
			Latency: latency.Milliseconds(),
		},
	}
	if nextCursor != "" {
		listResp.NextCursor = &nextCursor
	}
	return listResp, nil
}
// BatchRetrieve fetches a single batch job from Azure OpenAI, attempting each key in order.
//
// The first key that returns a 200 response with a decodable body wins. If every
// key fails, the error from the last attempt is returned.
func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
	if request.BatchID == "" {
		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil)
	}
	if len(keys) == 0 {
		return nil, providerUtils.NewConfigurationError("no Azure keys available for batch retrieve operation")
	}

	wantRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	wantRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	var finalErr *schemas.BifrostError
	for _, azureKey := range keys {
		// Use the default API version when the key does not specify one.
		version := azureKey.AzureKeyConfig.APIVersion
		if version == nil {
			version = schemas.NewEnvVar(AzureAPIVersionDefault)
		}

		httpReq := fasthttp.AcquireRequest()
		httpResp := fasthttp.AcquireResponse()
		// releasePooled hands both pooled objects back; every exit from this
		// iteration must go through it.
		releasePooled := func() {
			fasthttp.ReleaseRequest(httpReq)
			fasthttp.ReleaseResponse(httpResp)
		}

		// Assemble the request URL including the api-version query parameter.
		target := fmt.Sprintf("%s/openai/batches/%s", azureKey.AzureKeyConfig.Endpoint.GetValue(), url.PathEscape(request.BatchID))
		query := url.Values{}
		query.Set("api-version", version.GetValue())
		target += "?" + query.Encode()

		providerUtils.SetExtraHeaders(ctx, httpReq, provider.networkConfig.ExtraHeaders, nil)
		httpReq.SetRequestURI(target)
		httpReq.Header.SetMethod(http.MethodGet)
		httpReq.Header.SetContentType("application/json")

		if authErr := provider.setAzureAuth(ctx, httpReq, azureKey); authErr != nil {
			releasePooled()
			finalErr = authErr
			continue
		}

		latency, reqErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, httpReq, httpResp)
		// finish invokes the wait callback returned by MakeRequestWithContext
		// and only then releases the pooled objects.
		finish := func() {
			wait()
			releasePooled()
		}
		if reqErr != nil {
			finish()
			finalErr = reqErr
			continue
		}

		if httpResp.StatusCode() != fasthttp.StatusOK {
			// The error is parsed before the response is released.
			finalErr = openai.ParseOpenAIError(httpResp)
			finish()
			continue
		}

		body, decodeErr := providerUtils.CheckAndDecodeBody(httpResp)
		if decodeErr != nil {
			finish()
			finalErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr)
			continue
		}

		var batch openai.OpenAIBatchResponse
		rawRequest, rawResponse, parseErr := providerUtils.HandleProviderResponse(body, &batch, nil, wantRawRequest, wantRawResponse)
		if parseErr != nil {
			finish()
			finalErr = parseErr
			continue
		}
		finish()

		return batch.ToBifrostBatchRetrieveResponse(latency, wantRawRequest, wantRawResponse, rawRequest, rawResponse), nil
	}
	return nil, finalErr
}
// BatchCancel requests cancellation of a batch job on Azure OpenAI, attempting each key in order.
//
// It issues POST {endpoint}/openai/batches/{batch_id}/cancel. The first key that
// returns a 200 response with a decodable body wins, and the decoded status,
// timestamps and (when present) request counts are returned. If every key fails,
// the error from the last attempt is returned.
func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
	if request.BatchID == "" {
		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil)
	}
	if len(keys) == 0 {
		return nil, providerUtils.NewConfigurationError("no Azure keys available for batch cancel operation")
	}

	wantRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	wantRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	var finalErr *schemas.BifrostError
	for _, azureKey := range keys {
		// Use the default API version when the key does not specify one.
		version := azureKey.AzureKeyConfig.APIVersion
		if version == nil {
			version = schemas.NewEnvVar(AzureAPIVersionDefault)
		}

		httpReq := fasthttp.AcquireRequest()
		httpResp := fasthttp.AcquireResponse()
		// releasePooled hands both pooled objects back; every exit from this
		// iteration must go through it.
		releasePooled := func() {
			fasthttp.ReleaseRequest(httpReq)
			fasthttp.ReleaseResponse(httpResp)
		}

		// Assemble the request URL including the api-version query parameter.
		target := fmt.Sprintf("%s/openai/batches/%s/cancel", azureKey.AzureKeyConfig.Endpoint.GetValue(), url.PathEscape(request.BatchID))
		query := url.Values{}
		query.Set("api-version", version.GetValue())
		target += "?" + query.Encode()

		providerUtils.SetExtraHeaders(ctx, httpReq, provider.networkConfig.ExtraHeaders, nil)
		httpReq.SetRequestURI(target)
		httpReq.Header.SetMethod(http.MethodPost)
		httpReq.Header.SetContentType("application/json")

		if authErr := provider.setAzureAuth(ctx, httpReq, azureKey); authErr != nil {
			releasePooled()
			finalErr = authErr
			continue
		}

		latency, reqErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, httpReq, httpResp)
		// finish invokes the wait callback returned by MakeRequestWithContext
		// and only then releases the pooled objects.
		finish := func() {
			wait()
			releasePooled()
		}
		if reqErr != nil {
			finish()
			finalErr = reqErr
			continue
		}

		if httpResp.StatusCode() != fasthttp.StatusOK {
			// The error is parsed before the response is released.
			finalErr = openai.ParseOpenAIError(httpResp)
			finish()
			continue
		}

		body, decodeErr := providerUtils.CheckAndDecodeBody(httpResp)
		if decodeErr != nil {
			finish()
			finalErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr)
			continue
		}

		var batch openai.OpenAIBatchResponse
		rawRequest, rawResponse, parseErr := providerUtils.HandleProviderResponse(body, &batch, nil, wantRawRequest, wantRawResponse)
		if parseErr != nil {
			finish()
			finalErr = parseErr
			continue
		}
		finish()

		out := &schemas.BifrostBatchCancelResponse{
			ID:           batch.ID,
			Object:       batch.Object,
			Status:       openai.ToBifrostBatchStatus(batch.Status),
			CancellingAt: batch.CancellingAt,
			CancelledAt:  batch.CancelledAt,
			ExtraFields: schemas.BifrostResponseExtraFields{
				Latency: latency.Milliseconds(),
			},
		}
		// Request counts are optional in the provider response.
		if counts := batch.RequestCounts; counts != nil {
			out.RequestCounts = schemas.BatchRequestCounts{
				Total:     counts.Total,
				Completed: counts.Completed,
				Failed:    counts.Failed,
			}
		}
		if wantRawRequest {
			out.ExtraFields.RawRequest = rawRequest
		}
		if wantRawResponse {
			out.ExtraFields.RawResponse = rawResponse
		}
		return out, nil
	}
	return nil, finalErr
}
// BatchDelete always fails with an unsupported-operation error: the Azure
// provider does not implement batch deletion. All arguments are ignored.
func (provider *AzureProvider) BatchDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.BatchDeleteRequest, schemas.Azure)
	return nil, unsupported
}
// BatchResults returns the parsed results of a batch job on Azure OpenAI.
//
// It first retrieves the batch (trying all keys) to learn its output file ID,
// then downloads that file (again trying all keys) and parses it as JSONL, one
// result item per line. Lines that fail to parse are logged as warnings and
// surfaced through ExtraFields.ParseErrors rather than failing the call. The
// reported latency is that of the file download only.
func (provider *AzureProvider) BatchResults(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
	// Step 1: look up the batch to discover where its output lives.
	lookup := &schemas.BifrostBatchRetrieveRequest{
		Provider: request.Provider,
		BatchID:  request.BatchID,
	}
	batch, retrieveErr := provider.BatchRetrieve(ctx, keys, lookup)
	if retrieveErr != nil {
		return nil, retrieveErr
	}

	outputID := batch.OutputFileID
	if outputID == nil || *outputID == "" {
		return nil, providerUtils.NewBifrostOperationError("batch results not available: output_file_id is empty (batch may not be completed)", nil)
	}

	// Step 2: download the output file.
	download, downloadErr := provider.FileContent(ctx, keys, &schemas.BifrostFileContentRequest{
		Provider: request.Provider,
		FileID:   *outputID,
	})
	if downloadErr != nil {
		return nil, downloadErr
	}

	// Step 3: decode each JSONL line into a result item.
	var items []schemas.BatchResultItem
	collect := func(line []byte) error {
		var item schemas.BatchResultItem
		err := sonic.Unmarshal(line, &item)
		if err != nil {
			provider.logger.Warn("failed to parse batch result line: %v", err)
			return err
		}
		items = append(items, item)
		return nil
	}
	parsed := providerUtils.ParseJSONL(download.Content, collect)

	out := &schemas.BifrostBatchResultsResponse{
		BatchID: request.BatchID,
		Results: items,
		ExtraFields: schemas.BifrostResponseExtraFields{
			Latency: download.ExtraFields.Latency,
		},
	}
	if len(parsed.Errors) > 0 {
		out.ExtraFields.ParseErrors = parsed.Errors
	}
	return out, nil
}
// CountTokens always fails with an unsupported-operation error: the Azure
// provider does not implement token counting. All arguments are ignored.
func (provider *AzureProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerCreate always fails with an unsupported-operation error: the Azure
// provider does not implement container creation. All arguments are ignored.
func (provider *AzureProvider) ContainerCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerCreateRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerList always fails with an unsupported-operation error: the Azure
// provider does not implement container listing. All arguments are ignored.
func (provider *AzureProvider) ContainerList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerListRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerRetrieve always fails with an unsupported-operation error: the Azure
// provider does not implement container retrieval. All arguments are ignored.
func (provider *AzureProvider) ContainerRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerRetrieveRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerDelete always fails with an unsupported-operation error: the Azure
// provider does not implement container deletion. All arguments are ignored.
func (provider *AzureProvider) ContainerDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerDeleteRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerFileCreate always fails with an unsupported-operation error: the Azure
// provider does not implement container file creation. All arguments are ignored.
func (provider *AzureProvider) ContainerFileCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerFileCreateRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerFileList always fails with an unsupported-operation error: the Azure
// provider does not implement container file listing. All arguments are ignored.
func (provider *AzureProvider) ContainerFileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerFileListRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerFileRetrieve always fails with an unsupported-operation error: the Azure
// provider does not implement container file retrieval. All arguments are ignored.
func (provider *AzureProvider) ContainerFileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerFileRetrieveRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerFileContent always fails with an unsupported-operation error: the Azure
// provider does not implement container file content download. All arguments are ignored.
func (provider *AzureProvider) ContainerFileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerFileContentRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerFileDelete always fails with an unsupported-operation error: the Azure
// provider does not implement container file deletion. All arguments are ignored.
func (provider *AzureProvider) ContainerFileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerFileDeleteRequest, provider.GetProviderKey())
	return nil, unsupported
}
// Passthrough relays a raw request to Azure's API and returns the raw response.
//
// The method, body and caller-supplied safe headers are forwarded as given.
// Headers are applied in three layers — configured extra headers, then
// req.SafeHeaders, then Azure auth headers — so later layers overwrite earlier
// ones on conflict. The upstream status code is returned as-is (non-2xx responses
// are not converted into errors). The body is decoded, and the Content-Encoding
// and Content-Length headers are dropped from the returned header map.
func (provider *AzureProvider) Passthrough(
	ctx *schemas.BifrostContext,
	key schemas.Key,
	req *schemas.BifrostPassthroughRequest,
) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
	targetURL := provider.buildPassthroughURL(key, req.Path, req.RawQuery)

	outbound := fasthttp.AcquireRequest()
	inbound := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseResponse(inbound)
	defer fasthttp.ReleaseRequest(outbound)

	outbound.Header.SetMethod(req.Method)
	outbound.SetRequestURI(targetURL)

	providerUtils.SetExtraHeaders(ctx, outbound, provider.networkConfig.ExtraHeaders, nil)
	for name, value := range req.SafeHeaders {
		outbound.Header.Set(name, value)
	}

	// Auth headers are applied last so they cannot be overridden by caller headers.
	authHeaders, authErr := provider.getAzureAuthHeaders(ctx, key, schemas.IsAnthropicModel(req.Model))
	if authErr != nil {
		return nil, authErr
	}
	for name, value := range authHeaders {
		outbound.Header.Set(name, value)
	}

	outbound.SetBody(req.Body)

	latency, reqErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, outbound, inbound)
	defer wait()
	if reqErr != nil {
		return nil, reqErr
	}

	respHeaders := providerUtils.ExtractProviderResponseHeaders(inbound)
	decoded, decodeErr := providerUtils.CheckAndDecodeBody(inbound)
	if decodeErr != nil {
		return nil, providerUtils.NewBifrostOperationError("failed to decode response body", decodeErr)
	}

	// The body has been decoded and buffered, so wire-level encoding headers no
	// longer describe it; downstream should recalculate them.
	for name := range respHeaders {
		wireLevel := strings.EqualFold(name, "Content-Encoding") || strings.EqualFold(name, "Content-Length")
		if !wireLevel {
			continue
		}
		delete(respHeaders, name)
	}

	out := &schemas.BifrostPassthroughResponse{
		StatusCode: inbound.StatusCode(),
		Headers:    respHeaders,
		Body:       decoded,
	}
	out.ExtraFields.Latency = latency.Milliseconds()

	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
		providerUtils.ParseAndSetRawRequestIfJSON(outbound, &out.ExtraFields)
	}
	return out, nil
}
// PassthroughStream forwards a raw streaming request to Azure's API without any transformation.
// Chunks are piped back as raw bytes, preserving the upstream SSE or binary stream format.
//
// The outgoing request takes its method, body and target URL (see
// buildPassthroughURL) from req. Headers are applied in order: the provider's
// configured extra headers, req.SafeHeaders, "Connection: close", and finally
// the Azure auth headers for key.
//
// On success the returned channel carries one chunk per successful read of the
// upstream body (at most 4096 bytes each), every chunk repeating the upstream
// status code and headers. The channel is closed when the stream reaches EOF,
// a read fails, or ctx is done. On a clean EOF postHookRunner is invoked with a
// body-less final response carrying the total latency; read errors are routed
// through providerUtils.ProcessAndSendError.
//
// A non-nil *schemas.BifrostError is returned (and no channel) when the auth
// headers cannot be obtained, when the request fails (cancelled, timed out, or
// any other transport error), or when the response has no body stream. The
// pooled response is released on every one of those paths.
func (provider *AzureProvider) PassthroughStream(
	ctx *schemas.BifrostContext,
	postHookRunner schemas.PostHookRunner,
	postHookSpanFinalizer func(context.Context),
	key schemas.Key,
	req *schemas.BifrostPassthroughRequest,
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	targetURL := provider.buildPassthroughURL(key, req.Path, req.RawQuery)

	fasthttpReq := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	resp.StreamBody = true
	// Only the request is released on return; the response outlives this call
	// and is released by the streaming goroutine (or by an early-exit path below).
	defer fasthttp.ReleaseRequest(fasthttpReq)

	fasthttpReq.Header.SetMethod(req.Method)
	fasthttpReq.SetRequestURI(targetURL)

	providerUtils.SetExtraHeaders(ctx, fasthttpReq, provider.networkConfig.ExtraHeaders, nil)
	for k, v := range req.SafeHeaders {
		fasthttpReq.Header.Set(k, v)
	}
	fasthttpReq.Header.Set("Connection", "close")

	authHeaders, bifrostErr := provider.getAzureAuthHeaders(ctx, key, schemas.IsAnthropicModel(req.Model))
	if bifrostErr != nil {
		// No request has been issued yet, so the response carries no body
		// stream; return it to the pool directly instead of dropping it.
		fasthttp.ReleaseResponse(resp)
		return nil, bifrostErr
	}
	for k, v := range authHeaders {
		fasthttpReq.Header.Set(k, v)
	}

	fasthttpReq.SetBody(req.Body)

	activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.streamingClient, resp)
	providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds)

	startTime := time.Now()
	if err := activeClient.Do(fasthttpReq, resp); err != nil {
		providerUtils.ReleaseStreamingResponse(resp)
		if errors.Is(err, context.Canceled) {
			return nil, &schemas.BifrostError{
				IsBifrostError: false,
				Error: &schemas.ErrorField{
					Type:    schemas.Ptr(schemas.RequestCancelled),
					Message: schemas.ErrRequestCancelled,
					Error:   err,
				},
			}
		}
		if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
			return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
		}
		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err)
	}

	headers := providerUtils.ExtractProviderResponseHeaders(resp)

	rawBodyStream := resp.BodyStream()
	if rawBodyStream == nil {
		providerUtils.ReleaseStreamingResponse(resp)
		return nil, providerUtils.NewBifrostOperationError("provider returned an empty stream body", fmt.Errorf("provider returned an empty stream body"))
	}

	// Reads go through the idle-timeout wrapper; cancellation is wired to the
	// raw stream. NOTE(review): exact semantics of both helpers live in
	// providerUtils and are not visible here.
	bodyStream, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(rawBodyStream, rawBodyStream, providerUtils.GetStreamIdleTimeout(ctx))
	stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger)

	extraFields := schemas.BifrostResponseExtraFields{}
	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
		providerUtils.ParseAndSetRawRequestIfJSON(fasthttpReq, &extraFields)
	}

	// Captured before the goroutine starts so chunks do not touch resp for it.
	statusCode := resp.StatusCode()

	ch := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
	go func() {
		// Deferred calls run in reverse order of registration: stop the
		// cancellation watcher, stop the idle timer, release the response,
		// report cancellation/timeout and close the channel, then make sure
		// the span finalizer has been called.
		defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
		defer func() {
			if ctx.Err() == context.Canceled {
				providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger, postHookSpanFinalizer)
			} else if ctx.Err() == context.DeadlineExceeded {
				providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger, postHookSpanFinalizer)
			}
			close(ch)
		}()
		defer providerUtils.ReleaseStreamingResponse(resp)
		defer stopIdleTimeout()
		defer stopCancellation()

		buf := make([]byte, 4096)
		for {
			n, readErr := bodyStream.Read(buf)
			if n > 0 {
				// buf is reused on the next read, so each chunk gets its own copy.
				chunk := make([]byte, n)
				copy(chunk, buf[:n])
				select {
				case ch <- &schemas.BifrostStreamChunk{
					BifrostPassthroughResponse: &schemas.BifrostPassthroughResponse{
						StatusCode:  statusCode,
						Headers:     headers,
						Body:        chunk,
						ExtraFields: extraFields,
					},
				}:
				case <-ctx.Done():
					return
				}
			}
			if readErr == io.EOF {
				// Clean end of stream: run post hooks once with a body-less
				// summary response that carries the total latency.
				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
				extraFields.Latency = time.Since(startTime).Milliseconds()
				finalResp := &schemas.BifrostResponse{
					PassthroughResponse: &schemas.BifrostPassthroughResponse{
						StatusCode:  statusCode,
						Headers:     headers,
						ExtraFields: extraFields,
					},
				}
				postHookRunner(ctx, finalResp, nil)
				if postHookSpanFinalizer != nil {
					postHookSpanFinalizer(ctx)
				}
				return
			}
			if readErr != nil {
				// A done context is reported by the deferred handler above,
				// not as a read error.
				if ctx.Err() != nil {
					return
				}
				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
				extraFields.Latency = time.Since(startTime).Milliseconds()
				providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger, postHookSpanFinalizer)
				return
			}
		}
	}()

	return ch, nil
}
// buildPassthroughURL constructs the full Azure URL for a passthrough request.
//
// The result is the key's configured endpoint followed by path, then the
// client's query string with any api-version parameter removed, then the
// api-version selected here (if any). Before use, the first occurrence of
// "/openai/responses" or "/openai/videos" in path is rewritten to its
// "/openai/v1/..." form.
//
// The api-version is selected as follows:
//   - paths containing "openai/v1/responses" always use AzureAPIVersionPreview;
//   - paths starting with "/anthropic/" or "/openai/v1/videos" carry none;
//   - every other path uses the key's APIVersion, falling back to the
//     api-version the client supplied in rawQuery when the key has none and
//     rawQuery parsed without error.
//
// The remaining client parameters are re-encoded with url.Values.Encode (so
// they come out sorted by key), and the selected api-version is query-escaped
// before being appended.
func (provider *AzureProvider) buildPassthroughURL(key schemas.Key, path, rawQuery string) string {
	endpoint := key.AzureKeyConfig.Endpoint.GetValue()

	// Normalise the responses and videos paths emitted by the Azure SDK.
	path = strings.Replace(path, "/openai/responses", "/openai/v1/responses", 1)
	path = strings.Replace(path, "/openai/videos", "/openai/v1/videos", 1)

	// Parse the client query once. ParseQuery always returns a non-nil map
	// holding every pair it could decode, so the values remain usable for
	// re-encoding even when parseErr is set; malformed pairs are dropped.
	queryValues, parseErr := url.ParseQuery(rawQuery)

	var apiVersion string
	switch {
	case strings.Contains(path, "openai/v1/responses"):
		apiVersion = AzureAPIVersionPreview
	case strings.HasPrefix(path, "/anthropic/"),
		strings.HasPrefix(path, "/openai/v1/videos"):
		apiVersion = ""
	default:
		if key.AzureKeyConfig.APIVersion != nil {
			apiVersion = key.AzureKeyConfig.APIVersion.GetValue()
		}
		// Only trust the client-supplied version when the whole query parsed.
		if apiVersion == "" && parseErr == nil {
			apiVersion = queryValues.Get("api-version")
		}
	}

	// Strip api-version from the client query to avoid duplicates.
	queryValues.Del("api-version")

	params := make([]string, 0, 2)
	if cleanQuery := queryValues.Encode(); cleanQuery != "" {
		params = append(params, cleanQuery)
	}
	if apiVersion != "" {
		// The value may be a percent-decoded client parameter; escape it so it
		// cannot break the URL or smuggle in additional query parameters.
		params = append(params, "api-version="+url.QueryEscape(apiVersion))
	}

	fullURL := endpoint + path
	if len(params) == 0 {
		return fullURL
	}
	return fullURL + "?" + strings.Join(params, "&")
}