757 lines
28 KiB
Go
757 lines
28 KiB
Go
// Package telemetry provides Prometheus metrics collection and monitoring functionality
// for the Bifrost HTTP service. It includes middleware for HTTP request tracking
// and a plugin for tracking upstream provider metrics.
package telemetry
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
bifrost "github.com/maximhq/bifrost/core"
|
|
schemas "github.com/maximhq/bifrost/core/schemas"
|
|
"github.com/maximhq/bifrost/framework/modelcatalog"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/prometheus/client_golang/prometheus/collectors"
|
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
|
"github.com/prometheus/client_golang/prometheus/push"
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
// PluginName is the name reported by the telemetry plugin.
const PluginName = "telemetry"
|
|
|
|
const (
|
|
startTimeKey schemas.BifrostContextKey = "bf-prom-start-time"
|
|
)
|
|
|
|
// PushGatewayConfig holds the configuration for pushing metrics to a Prometheus Push Gateway.
|
|
// This enables accurate metrics aggregation in multi-node cluster deployments where
|
|
// traditional /metrics scraping may miss nodes behind load balancers.
|
|
type PushGatewayConfig struct {
|
|
// Enabled controls whether pushing metrics to the Push Gateway is active
|
|
Enabled bool `json:"enabled"`
|
|
// PushGatewayURL is the URL of the Prometheus Push Gateway (e.g., http://pushgateway:9091)
|
|
PushGatewayURL string `json:"push_gateway_url"`
|
|
// JobName is the job label for pushed metrics (default: "bifrost")
|
|
JobName string `json:"job_name"`
|
|
// InstanceID is the instance label for grouping metrics. If empty, hostname is used.
|
|
InstanceID string `json:"instance_id"`
|
|
// PushInterval is how often to push metrics in seconds (default: 15)
|
|
PushInterval int `json:"push_interval"`
|
|
// BasicAuth credentials for the Push Gateway
|
|
BasicAuth *BasicAuthConfig `json:"basic_auth"`
|
|
}
|
|
|
|
// BasicAuthConfig is the username/password pair used for HTTP basic
// authentication against the Push Gateway.
type BasicAuthConfig struct {
	// Username is the basic-auth user name.
	Username string `json:"username"`
	// Password is the basic-auth password.
	Password string `json:"password"`
}
|
|
|
|
// PrometheusPlugin implements the schemas.LLMPlugin interface for Prometheus metrics.
|
|
// It tracks metrics for upstream provider requests, including:
|
|
// - Total number of requests
|
|
// - Request latency
|
|
// - Error counts
|
|
type PrometheusPlugin struct {
|
|
pricingManager *modelcatalog.ModelCatalog
|
|
registry *prometheus.Registry
|
|
|
|
logger schemas.Logger
|
|
|
|
// Built-in collectors registered by this plugin
|
|
GoCollector prometheus.Collector
|
|
ProcessCollector prometheus.Collector
|
|
|
|
// Metrics are defined using promauto for automatic registration
|
|
HTTPRequestsTotal *prometheus.CounterVec
|
|
HTTPRequestDuration *prometheus.HistogramVec
|
|
HTTPRequestSizeBytes *prometheus.HistogramVec
|
|
HTTPResponseSizeBytes *prometheus.HistogramVec
|
|
UpstreamRequestsTotal *prometheus.CounterVec
|
|
UpstreamLatencySeconds *prometheus.HistogramVec
|
|
SuccessRequestsTotal *prometheus.CounterVec
|
|
ErrorRequestsTotal *prometheus.CounterVec
|
|
InputTokensTotal *prometheus.CounterVec
|
|
OutputTokensTotal *prometheus.CounterVec
|
|
CacheHitsTotal *prometheus.CounterVec
|
|
CostTotal *prometheus.CounterVec
|
|
StreamInterTokenLatencySeconds *prometheus.HistogramVec
|
|
StreamFirstTokenLatencySeconds *prometheus.HistogramVec
|
|
KeyRotationEventsTotal *prometheus.CounterVec
|
|
customLabels []string
|
|
|
|
defaultHTTPLabels []string
|
|
defaultBifrostLabels []string
|
|
|
|
// Push gateway fields
|
|
pushConfig *PushGatewayConfig
|
|
pusher *push.Pusher
|
|
pushCtx context.Context
|
|
pushCancel context.CancelFunc
|
|
pushWg sync.WaitGroup
|
|
pushMu sync.RWMutex
|
|
pushActive bool
|
|
}
|
|
|
|
type Config struct {
|
|
CustomLabels []string `json:"custom_labels"`
|
|
Registry *prometheus.Registry
|
|
PushGateway *PushGatewayConfig `json:"push_gateway"`
|
|
}
|
|
|
|
// Init creates a new PrometheusPlugin with initialized metrics.
|
|
func Init(config *Config, pricingManager *modelcatalog.ModelCatalog, logger schemas.Logger) (*PrometheusPlugin, error) {
|
|
if config == nil {
|
|
return nil, fmt.Errorf("config is required")
|
|
}
|
|
|
|
if pricingManager == nil {
|
|
logger.Warn("telemetry plugin requires model catalog to calculate cost, all cost calculations will be skipped.")
|
|
}
|
|
|
|
registry := config.Registry
|
|
// If config has no registry, create a new one
|
|
if registry == nil {
|
|
registry = prometheus.NewRegistry()
|
|
}
|
|
|
|
// Create collectors and store references for cleanup
|
|
goCollector := collectors.NewGoCollector()
|
|
if err := registry.Register(goCollector); err != nil {
|
|
return nil, fmt.Errorf("failed to register Go collector: %v", err)
|
|
}
|
|
|
|
processCollector := collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})
|
|
if err := registry.Register(processCollector); err != nil {
|
|
return nil, fmt.Errorf("failed to register process collector: %v", err)
|
|
}
|
|
|
|
defaultHTTPLabels := []string{"path", "method", "status"}
|
|
defaultBifrostLabels := []string{
|
|
"provider",
|
|
"model",
|
|
"alias",
|
|
"method",
|
|
"virtual_key_id",
|
|
"virtual_key_name",
|
|
"routing_engine_used",
|
|
"routing_rule_id",
|
|
"routing_rule_name",
|
|
"selected_key_id",
|
|
"selected_key_name",
|
|
"number_of_retries",
|
|
"fallback_index",
|
|
"team_id",
|
|
"team_name",
|
|
"customer_id",
|
|
"customer_name",
|
|
}
|
|
|
|
var filteredCustomLabels []string
|
|
if len(config.CustomLabels) > 0 {
|
|
for _, label := range config.CustomLabels {
|
|
if !containsLabel(defaultBifrostLabels, label) && !containsLabel(defaultHTTPLabels, label) {
|
|
filteredCustomLabels = append(filteredCustomLabels, label)
|
|
} else {
|
|
logger.Info("custom label %s is already a default label, it will be ignored", label)
|
|
}
|
|
}
|
|
}
|
|
|
|
factory := promauto.With(registry)
|
|
|
|
// Upstream LLM latency buckets - extended range for AI model inference times
|
|
upstreamLatencyBuckets := []float64{.005, .01, .025, .05, .1, .25, .5, 1, 2.5, 5, 10, 15, 30, 45, 60, 90} // in seconds
|
|
|
|
httpRequestsTotal := factory.NewCounterVec(
|
|
prometheus.CounterOpts{
|
|
Name: "http_requests_total",
|
|
Help: "Total number of HTTP requests.",
|
|
},
|
|
append(defaultHTTPLabels, filteredCustomLabels...),
|
|
)
|
|
|
|
// httpRequestDuration tracks the duration of HTTP requests
|
|
httpRequestDuration := factory.NewHistogramVec(
|
|
prometheus.HistogramOpts{
|
|
Name: "http_request_duration_seconds",
|
|
Help: "Duration of HTTP requests.",
|
|
Buckets: prometheus.DefBuckets,
|
|
},
|
|
append(defaultHTTPLabels, filteredCustomLabels...),
|
|
)
|
|
|
|
// httpRequestSizeBytes tracks the size of incoming HTTP requests
|
|
httpRequestSizeBytes := factory.NewHistogramVec(
|
|
prometheus.HistogramOpts{
|
|
Name: "http_request_size_bytes",
|
|
Help: "Size of HTTP requests.",
|
|
Buckets: prometheus.ExponentialBuckets(100, 10, 8), // 100B to 1GB
|
|
},
|
|
append(defaultHTTPLabels, filteredCustomLabels...),
|
|
)
|
|
|
|
// httpResponseSizeBytes tracks the size of outgoing HTTP responses
|
|
httpResponseSizeBytes := factory.NewHistogramVec(
|
|
prometheus.HistogramOpts{
|
|
Name: "http_response_size_bytes",
|
|
Help: "Size of HTTP responses.",
|
|
Buckets: prometheus.ExponentialBuckets(100, 10, 8), // 100B to 1GB
|
|
},
|
|
append(defaultHTTPLabels, filteredCustomLabels...),
|
|
)
|
|
|
|
// Bifrost Upstream Metrics
|
|
bifrostUpstreamRequestsTotal := factory.NewCounterVec(
|
|
prometheus.CounterOpts{
|
|
Name: "bifrost_upstream_requests_total",
|
|
Help: "Total number of requests forwarded to upstream providers by Bifrost.",
|
|
},
|
|
append(defaultBifrostLabels, filteredCustomLabels...),
|
|
)
|
|
|
|
bifrostUpstreamLatencySeconds := factory.NewHistogramVec(
|
|
prometheus.HistogramOpts{
|
|
Name: "bifrost_upstream_latency_seconds",
|
|
Help: "Latency of requests forwarded to upstream providers by Bifrost.",
|
|
Buckets: upstreamLatencyBuckets, // Extended range for AI model inference times
|
|
},
|
|
append(append(defaultBifrostLabels, "is_success"), filteredCustomLabels...),
|
|
)
|
|
|
|
bifrostSuccessRequestsTotal := factory.NewCounterVec(
|
|
prometheus.CounterOpts{
|
|
Name: "bifrost_success_requests_total",
|
|
Help: "Total number of successful requests forwarded to upstream providers by Bifrost.",
|
|
},
|
|
append(defaultBifrostLabels, filteredCustomLabels...),
|
|
)
|
|
|
|
bifrostErrorRequestsTotal := factory.NewCounterVec(
|
|
prometheus.CounterOpts{
|
|
Name: "bifrost_error_requests_total",
|
|
Help: "Total number of error requests forwarded to upstream providers by Bifrost.",
|
|
},
|
|
append(append(defaultBifrostLabels, "status_code"), filteredCustomLabels...),
|
|
)
|
|
|
|
bifrostInputTokensTotal := factory.NewCounterVec(
|
|
prometheus.CounterOpts{
|
|
Name: "bifrost_input_tokens_total",
|
|
Help: "Total number of input tokens forwarded to upstream providers by Bifrost.",
|
|
},
|
|
append(defaultBifrostLabels, filteredCustomLabels...),
|
|
)
|
|
|
|
bifrostOutputTokensTotal := factory.NewCounterVec(
|
|
prometheus.CounterOpts{
|
|
Name: "bifrost_output_tokens_total",
|
|
Help: "Total number of output tokens forwarded to upstream providers by Bifrost.",
|
|
},
|
|
append(defaultBifrostLabels, filteredCustomLabels...),
|
|
)
|
|
|
|
bifrostCacheHitsTotal := factory.NewCounterVec(
|
|
prometheus.CounterOpts{
|
|
Name: "bifrost_cache_hits_total",
|
|
Help: "Total number of cache hits forwarded to upstream providers by Bifrost, separated by cache type (direct/semantic).",
|
|
},
|
|
append(append(defaultBifrostLabels, "cache_type"), filteredCustomLabels...),
|
|
)
|
|
|
|
bifrostCostTotal := factory.NewCounterVec(
|
|
prometheus.CounterOpts{
|
|
Name: "bifrost_cost_total",
|
|
Help: "Total cost in USD for requests to upstream providers.",
|
|
},
|
|
append(defaultBifrostLabels, filteredCustomLabels...),
|
|
)
|
|
|
|
bifrostStreamInterTokenLatencySeconds := factory.NewHistogramVec(
|
|
prometheus.HistogramOpts{
|
|
Name: "bifrost_stream_inter_token_latency_seconds",
|
|
Help: "Latency of the intermediate tokens of a stream response.",
|
|
},
|
|
append(defaultBifrostLabels, filteredCustomLabels...),
|
|
)
|
|
|
|
bifrostStreamFirstTokenLatencySeconds := factory.NewHistogramVec(
|
|
prometheus.HistogramOpts{
|
|
Name: "bifrost_stream_first_token_latency_seconds",
|
|
Help: "Latency of the first token of a stream response.",
|
|
},
|
|
append(defaultBifrostLabels, filteredCustomLabels...),
|
|
)
|
|
|
|
// bifrostKeyRotationEventsTotal counts individual retry/rotation events from the attempt trail.
|
|
// One observation is emitted per failed attempt (where fail_reason is non-nil), not per request.
|
|
// Use this to track rate-limit pressure and network-error frequency per provider/key.
|
|
bifrostKeyRotationEventsTotal := factory.NewCounterVec(
|
|
prometheus.CounterOpts{
|
|
Name: "bifrost_key_rotation_events_total",
|
|
Help: "Number of key retry/rotation events, broken down by provider, key, and failure reason. One increment per failed attempt.",
|
|
},
|
|
[]string{"provider", "requested_model", "key_id", "key_name", "fail_reason"},
|
|
)
|
|
|
|
plugin := &PrometheusPlugin{
|
|
logger: logger,
|
|
pricingManager: pricingManager,
|
|
registry: registry,
|
|
GoCollector: goCollector,
|
|
ProcessCollector: processCollector,
|
|
HTTPRequestsTotal: httpRequestsTotal,
|
|
HTTPRequestDuration: httpRequestDuration,
|
|
HTTPRequestSizeBytes: httpRequestSizeBytes,
|
|
HTTPResponseSizeBytes: httpResponseSizeBytes,
|
|
UpstreamRequestsTotal: bifrostUpstreamRequestsTotal,
|
|
UpstreamLatencySeconds: bifrostUpstreamLatencySeconds,
|
|
SuccessRequestsTotal: bifrostSuccessRequestsTotal,
|
|
ErrorRequestsTotal: bifrostErrorRequestsTotal,
|
|
InputTokensTotal: bifrostInputTokensTotal,
|
|
OutputTokensTotal: bifrostOutputTokensTotal,
|
|
CacheHitsTotal: bifrostCacheHitsTotal,
|
|
CostTotal: bifrostCostTotal,
|
|
StreamInterTokenLatencySeconds: bifrostStreamInterTokenLatencySeconds,
|
|
StreamFirstTokenLatencySeconds: bifrostStreamFirstTokenLatencySeconds,
|
|
KeyRotationEventsTotal: bifrostKeyRotationEventsTotal,
|
|
customLabels: filteredCustomLabels,
|
|
defaultHTTPLabels: defaultHTTPLabels,
|
|
defaultBifrostLabels: defaultBifrostLabels,
|
|
}
|
|
|
|
// Start push gateway if configured
|
|
if config.PushGateway != nil && config.PushGateway.Enabled && config.PushGateway.PushGatewayURL != "" {
|
|
if err := plugin.EnablePushGateway(config.PushGateway); err != nil {
|
|
return nil, fmt.Errorf("failed to start push gateway: %w", err)
|
|
}
|
|
}
|
|
|
|
return plugin, nil
|
|
}
|
|
|
|
func (p *PrometheusPlugin) GetRegistry() *prometheus.Registry {
|
|
return p.registry
|
|
}
|
|
|
|
// GetName returns the name of the plugin.
|
|
func (p *PrometheusPlugin) GetName() string {
|
|
return PluginName
|
|
}
|
|
|
|
// HTTPTransportPreHook is not used for this plugin
|
|
func (p *PrometheusPlugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
// HTTPTransportPostHook is not used for this plugin
|
|
func (p *PrometheusPlugin) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error {
|
|
return nil
|
|
}
|
|
|
|
// HTTPTransportStreamChunkHook passes through streaming chunks unchanged
|
|
func (p *PrometheusPlugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) {
|
|
return chunk, nil
|
|
}
|
|
|
|
// PreLLMHook records the start time of the request in the context.
|
|
// This time is used later in PostLLMHook to calculate request duration.
|
|
func (p *PrometheusPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) {
|
|
ctx.SetValue(startTimeKey, time.Now())
|
|
return req, nil, nil
|
|
}
|
|
|
|
// PostLLMHook calculates duration and records upstream metrics for successful requests.
|
|
// It records:
|
|
// - Request latency
|
|
// - Total request count
|
|
func (p *PrometheusPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) {
|
|
requestType, provider, originalModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr)
|
|
|
|
// Determine effective model label and alias label (mirrors applyModelAlias logic in logging)
|
|
model := originalModel
|
|
alias := ""
|
|
if resolvedModel != "" {
|
|
model = resolvedModel
|
|
if resolvedModel != originalModel {
|
|
alias = originalModel
|
|
}
|
|
}
|
|
|
|
startTime, ok := ctx.Value(startTimeKey).(time.Time)
|
|
if !ok {
|
|
p.logger.Warn("Warning: startTime not found in context for Prometheus PostLLMHook")
|
|
return result, bifrostErr, nil
|
|
}
|
|
|
|
virtualKeyID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceVirtualKeyID)
|
|
virtualKeyName := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceVirtualKeyName)
|
|
routingRuleID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceRoutingRuleID)
|
|
routingRuleName := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceRoutingRuleName)
|
|
|
|
selectedKeyID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeySelectedKeyID)
|
|
selectedKeyName := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeySelectedKeyName)
|
|
|
|
numberOfRetries := bifrost.GetIntFromContext(ctx, schemas.BifrostContextKeyNumberOfRetries)
|
|
fallbackIndex := bifrost.GetIntFromContext(ctx, schemas.BifrostContextKeyFallbackIndex)
|
|
attemptTrail, _ := ctx.Value(schemas.BifrostContextKeyAttemptTrail).([]schemas.KeyAttemptRecord)
|
|
// Get routing engines array and join into comma-separated string
|
|
routingEngines := []string{}
|
|
if engines, ok := ctx.Value(schemas.BifrostContextKeyRoutingEnginesUsed).([]string); ok {
|
|
routingEngines = engines
|
|
}
|
|
routingEngineUsed := strings.Join(routingEngines, ",")
|
|
|
|
teamID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceTeamID)
|
|
teamName := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceTeamName)
|
|
customerID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceCustomerID)
|
|
customerName := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceCustomerName)
|
|
|
|
// Extract ALL context values BEFORE spawning the goroutine.
|
|
labelValues := map[string]string{
|
|
"provider": string(provider),
|
|
"model": model,
|
|
"alias": alias,
|
|
"method": string(requestType),
|
|
"virtual_key_id": virtualKeyID,
|
|
"virtual_key_name": virtualKeyName,
|
|
"routing_engine_used": routingEngineUsed,
|
|
"routing_rule_id": routingRuleID,
|
|
"routing_rule_name": routingRuleName,
|
|
"selected_key_id": selectedKeyID,
|
|
"selected_key_name": selectedKeyName,
|
|
"number_of_retries": strconv.Itoa(numberOfRetries),
|
|
"fallback_index": strconv.Itoa(fallbackIndex),
|
|
"team_id": teamID,
|
|
"team_name": teamName,
|
|
"customer_id": customerID,
|
|
"customer_name": customerName,
|
|
}
|
|
|
|
// Get all custom prometheus labels from context BEFORE the goroutine
|
|
for _, key := range p.customLabels {
|
|
if value := ctx.Value(schemas.BifrostContextKey(key)); value != nil {
|
|
if strValue, ok := value.(string); ok {
|
|
labelValues[key] = strValue
|
|
}
|
|
}
|
|
}
|
|
|
|
// Get label values in the correct order (cache_type will be handled separately for cache hits)
|
|
promLabelValues := getPrometheusLabelValues(append(p.defaultBifrostLabels, p.customLabels...), labelValues)
|
|
|
|
// Extract stream end indicator BEFORE the goroutine
|
|
streamEndIndicatorValue := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator)
|
|
isFinalChunk, hasFinalChunkIndicator := streamEndIndicatorValue.(bool)
|
|
|
|
pricingScopes := modelcatalog.PricingLookupScopesFromContext(ctx, string(provider))
|
|
|
|
// Calculate cost and record metrics in a separate goroutine to avoid blocking the main thread
|
|
go func() {
|
|
// For streaming requests, handle per-token metrics for intermediate chunks
|
|
if bifrost.IsStreamRequestType(requestType) {
|
|
// For intermediate chunks, record per-token metrics and exit.
|
|
// The final chunk will fall through to record full request metrics.
|
|
if !hasFinalChunkIndicator || !isFinalChunk {
|
|
// Record metrics for the first token
|
|
if result != nil {
|
|
extraFields := result.GetExtraFields()
|
|
if extraFields.ChunkIndex == 0 {
|
|
p.StreamFirstTokenLatencySeconds.WithLabelValues(promLabelValues...).Observe(float64(extraFields.Latency) / 1000.0)
|
|
} else {
|
|
p.StreamInterTokenLatencySeconds.WithLabelValues(promLabelValues...).Observe(float64(extraFields.Latency) / 1000.0)
|
|
}
|
|
}
|
|
return // Exit goroutine for intermediate chunks
|
|
}
|
|
}
|
|
|
|
cost := 0.0
|
|
if p.pricingManager != nil && result != nil {
|
|
cost = p.pricingManager.CalculateCost(result, pricingScopes)
|
|
}
|
|
|
|
// Emit one counter increment per failed attempt in the trail (fail_reason != nil).
|
|
// This decouples per-attempt retry visibility from the per-request metrics above.
|
|
for _, record := range attemptTrail {
|
|
if record.FailReason != nil {
|
|
p.KeyRotationEventsTotal.WithLabelValues(
|
|
string(provider), originalModel, record.KeyID, record.KeyName, *record.FailReason,
|
|
).Inc()
|
|
}
|
|
}
|
|
|
|
p.UpstreamRequestsTotal.WithLabelValues(promLabelValues...).Inc()
|
|
|
|
// Record latency
|
|
duration := time.Since(startTime).Seconds()
|
|
latencyLabelValues := make([]string, 0, len(promLabelValues)+1)
|
|
latencyLabelValues = append(latencyLabelValues, promLabelValues[:len(p.defaultBifrostLabels)]...) // all default labels
|
|
latencyLabelValues = append(latencyLabelValues, strconv.FormatBool(bifrostErr == nil)) // is_success
|
|
latencyLabelValues = append(latencyLabelValues, promLabelValues[len(p.defaultBifrostLabels):]...) // then custom labels
|
|
p.UpstreamLatencySeconds.WithLabelValues(latencyLabelValues...).Observe(duration)
|
|
|
|
// Record cost using the dedicated cost counter
|
|
if cost > 0 {
|
|
p.CostTotal.WithLabelValues(promLabelValues...).Add(cost)
|
|
}
|
|
|
|
// Record error and success counts
|
|
if bifrostErr != nil {
|
|
// Add status_code to label values (create new slice to avoid modifying original)
|
|
statusCode := "unknown"
|
|
if bifrostErr.StatusCode != nil {
|
|
statusCode = strconv.Itoa(*bifrostErr.StatusCode)
|
|
}
|
|
errorPromLabelValues := make([]string, 0, len(promLabelValues)+1)
|
|
errorPromLabelValues = append(errorPromLabelValues, promLabelValues[:len(p.defaultBifrostLabels)]...) // all default labels
|
|
errorPromLabelValues = append(errorPromLabelValues, statusCode) // status_code
|
|
errorPromLabelValues = append(errorPromLabelValues, promLabelValues[len(p.defaultBifrostLabels):]...) // then custom labels
|
|
|
|
p.ErrorRequestsTotal.WithLabelValues(errorPromLabelValues...).Inc()
|
|
} else {
|
|
p.SuccessRequestsTotal.WithLabelValues(promLabelValues...).Inc()
|
|
}
|
|
|
|
if result != nil {
|
|
// Record input and output tokens
|
|
var inputTokens, outputTokens int
|
|
|
|
switch {
|
|
case result.TextCompletionResponse != nil && result.TextCompletionResponse.Usage != nil:
|
|
inputTokens = result.TextCompletionResponse.Usage.PromptTokens
|
|
outputTokens = result.TextCompletionResponse.Usage.CompletionTokens
|
|
case result.ChatResponse != nil && result.ChatResponse.Usage != nil:
|
|
inputTokens = result.ChatResponse.Usage.PromptTokens
|
|
outputTokens = result.ChatResponse.Usage.CompletionTokens
|
|
case result.ResponsesResponse != nil && result.ResponsesResponse.Usage != nil:
|
|
inputTokens = result.ResponsesResponse.Usage.InputTokens
|
|
outputTokens = result.ResponsesResponse.Usage.OutputTokens
|
|
case result.ResponsesStreamResponse != nil && result.ResponsesStreamResponse.Response != nil && result.ResponsesStreamResponse.Response.Usage != nil:
|
|
inputTokens = result.ResponsesStreamResponse.Response.Usage.InputTokens
|
|
outputTokens = result.ResponsesStreamResponse.Response.Usage.OutputTokens
|
|
case result.EmbeddingResponse != nil && result.EmbeddingResponse.Usage != nil:
|
|
inputTokens = result.EmbeddingResponse.Usage.PromptTokens
|
|
outputTokens = result.EmbeddingResponse.Usage.CompletionTokens
|
|
case result.SpeechStreamResponse != nil && result.SpeechStreamResponse.Usage != nil:
|
|
inputTokens = result.SpeechStreamResponse.Usage.InputTokens
|
|
outputTokens = result.SpeechStreamResponse.Usage.OutputTokens
|
|
case result.TranscriptionResponse != nil && result.TranscriptionResponse.Usage != nil:
|
|
if result.TranscriptionResponse.Usage.InputTokens != nil {
|
|
inputTokens = *result.TranscriptionResponse.Usage.InputTokens
|
|
}
|
|
if result.TranscriptionResponse.Usage.OutputTokens != nil {
|
|
outputTokens = *result.TranscriptionResponse.Usage.OutputTokens
|
|
}
|
|
case result.TranscriptionStreamResponse != nil && result.TranscriptionStreamResponse.Usage != nil:
|
|
if result.TranscriptionStreamResponse.Usage.InputTokens != nil {
|
|
inputTokens = *result.TranscriptionStreamResponse.Usage.InputTokens
|
|
}
|
|
if result.TranscriptionStreamResponse.Usage.OutputTokens != nil {
|
|
outputTokens = *result.TranscriptionStreamResponse.Usage.OutputTokens
|
|
}
|
|
}
|
|
|
|
p.InputTokensTotal.WithLabelValues(promLabelValues...).Add(float64(inputTokens))
|
|
p.OutputTokensTotal.WithLabelValues(promLabelValues...).Add(float64(outputTokens))
|
|
|
|
// Record cache hits with cache type
|
|
extraFields := result.GetExtraFields()
|
|
if extraFields.CacheDebug != nil && extraFields.CacheDebug.CacheHit {
|
|
cacheType := "unknown"
|
|
if extraFields.CacheDebug.HitType != nil {
|
|
cacheType = *extraFields.CacheDebug.HitType
|
|
}
|
|
|
|
// Add cache_type to label values (create new slice to avoid modifying original)
|
|
cacheHitLabelValues := make([]string, 0, len(promLabelValues)+1)
|
|
cacheHitLabelValues = append(cacheHitLabelValues, promLabelValues[:len(p.defaultBifrostLabels)]...) // all default labels
|
|
cacheHitLabelValues = append(cacheHitLabelValues, cacheType) // cache_type
|
|
cacheHitLabelValues = append(cacheHitLabelValues, promLabelValues[len(p.defaultBifrostLabels):]...) // then custom labels
|
|
|
|
p.CacheHitsTotal.WithLabelValues(cacheHitLabelValues...).Inc()
|
|
}
|
|
}
|
|
}()
|
|
|
|
return result, bifrostErr, nil
|
|
}
|
|
|
|
// HTTPMiddleware wraps a FastHTTP handler to collect Prometheus metrics.
|
|
// It tracks:
|
|
// - Total number of requests
|
|
// - Request duration
|
|
// - Request and response sizes
|
|
// - HTTP status codes
|
|
// - Bifrost upstream requests and errors
|
|
func (p *PrometheusPlugin) HTTPMiddleware(handler fasthttp.RequestHandler) fasthttp.RequestHandler {
|
|
return func(ctx *fasthttp.RequestCtx) {
|
|
start := time.Now()
|
|
|
|
// Collect request metrics and headers
|
|
promKeyValues := collectPrometheusKeyValues(ctx)
|
|
reqSize := float64(ctx.Request.Header.ContentLength())
|
|
|
|
// Process the request
|
|
handler(ctx)
|
|
|
|
// Record metrics after request completion
|
|
duration := time.Since(start).Seconds()
|
|
status := strconv.Itoa(ctx.Response.StatusCode())
|
|
respSize := float64(ctx.Response.Header.ContentLength())
|
|
|
|
// Add status to the label values
|
|
promKeyValues["status"] = status
|
|
|
|
// Get label values in the correct order
|
|
promLabelValues := getPrometheusLabelValues(append([]string{"path", "method", "status"}, p.customLabels...), promKeyValues)
|
|
|
|
// Record all metrics with prometheus labels
|
|
p.HTTPRequestsTotal.WithLabelValues(promLabelValues...).Inc()
|
|
p.HTTPRequestDuration.WithLabelValues(promLabelValues...).Observe(duration)
|
|
if reqSize >= 0 {
|
|
safeObserve(p.HTTPRequestSizeBytes, reqSize, promLabelValues...)
|
|
}
|
|
if respSize >= 0 {
|
|
safeObserve(p.HTTPResponseSizeBytes, respSize, promLabelValues...)
|
|
}
|
|
}
|
|
}
|
|
|
|
// EnablePushGateway starts pushing metrics to a Prometheus Push Gateway.
|
|
// If push gateway is already active, it stops the existing one first.
|
|
func (p *PrometheusPlugin) EnablePushGateway(config *PushGatewayConfig) error {
|
|
if config == nil || config.PushGatewayURL == "" {
|
|
return fmt.Errorf("push_gateway_url is required")
|
|
}
|
|
|
|
// Stop existing push gateway if running
|
|
p.DisablePushGateway()
|
|
|
|
// Apply defaults
|
|
if config.JobName == "" {
|
|
config.JobName = "bifrost"
|
|
}
|
|
if config.PushInterval <= 0 {
|
|
config.PushInterval = 15
|
|
}
|
|
if config.InstanceID == "" {
|
|
hostname, err := os.Hostname()
|
|
if err != nil {
|
|
config.InstanceID = "unknown"
|
|
} else {
|
|
config.InstanceID = hostname
|
|
}
|
|
}
|
|
|
|
// Create the pusher with the registry
|
|
pusher := push.New(config.PushGatewayURL, config.JobName).
|
|
Gatherer(p.registry).
|
|
Grouping("instance", config.InstanceID)
|
|
|
|
if config.BasicAuth != nil && config.BasicAuth.Username != "" {
|
|
pusher = pusher.BasicAuth(config.BasicAuth.Username, config.BasicAuth.Password)
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
p.pushMu.Lock()
|
|
p.pushConfig = config
|
|
p.pusher = pusher
|
|
p.pushCtx = ctx
|
|
p.pushCancel = cancel
|
|
p.pushActive = true
|
|
p.pushWg.Add(1)
|
|
p.pushMu.Unlock()
|
|
|
|
go p.pushLoop()
|
|
|
|
p.logger.Info("push gateway started, pushing to %s every %d seconds",
|
|
config.PushGatewayURL, config.PushInterval)
|
|
|
|
return nil
|
|
}
|
|
|
|
// DisablePushGateway stops the push gateway loop if active
|
|
func (p *PrometheusPlugin) DisablePushGateway() {
|
|
p.pushMu.Lock()
|
|
if !p.pushActive {
|
|
p.pushMu.Unlock()
|
|
return
|
|
}
|
|
p.pushActive = false
|
|
p.pushCancel()
|
|
p.pushMu.Unlock()
|
|
|
|
p.pushWg.Wait()
|
|
p.logger.Info("push gateway stopped")
|
|
}
|
|
|
|
// GetPushGatewayConfig returns the current push gateway configuration
|
|
func (p *PrometheusPlugin) GetPushGatewayConfig() *PushGatewayConfig {
|
|
p.pushMu.RLock()
|
|
defer p.pushMu.RUnlock()
|
|
return p.pushConfig
|
|
}
|
|
|
|
// IsPushGatewayRunning returns whether the push gateway loop is active
|
|
func (p *PrometheusPlugin) IsPushGatewayRunning() bool {
|
|
p.pushMu.RLock()
|
|
defer p.pushMu.RUnlock()
|
|
return p.pushActive
|
|
}
|
|
|
|
// pushLoop periodically pushes metrics to the Push Gateway
|
|
func (p *PrometheusPlugin) pushLoop() {
|
|
defer p.pushWg.Done()
|
|
|
|
p.pushMu.RLock()
|
|
interval := p.pushConfig.PushInterval
|
|
p.pushMu.RUnlock()
|
|
|
|
ticker := time.NewTicker(time.Duration(interval) * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
// Initial push
|
|
p.doPush()
|
|
|
|
for {
|
|
select {
|
|
case <-p.pushCtx.Done():
|
|
// Final push before shutdown
|
|
p.logger.Info("push gateway shutting down, performing final push")
|
|
p.doPush()
|
|
return
|
|
case <-ticker.C:
|
|
p.doPush()
|
|
}
|
|
}
|
|
}
|
|
|
|
// doPush performs a single push to the Push Gateway
|
|
func (p *PrometheusPlugin) doPush() {
|
|
p.pushMu.RLock()
|
|
pusher := p.pusher
|
|
p.pushMu.RUnlock()
|
|
|
|
if pusher == nil {
|
|
return
|
|
}
|
|
|
|
if err := pusher.Push(); err != nil {
|
|
p.logger.Error("failed to push metrics to push gateway: %v", err)
|
|
}
|
|
}
|
|
|
|
func (p *PrometheusPlugin) Cleanup() error {
|
|
p.DisablePushGateway()
|
|
return nil
|
|
}
|