first commit
This commit is contained in:
688
plugins/maxim/main.go
Normal file
688
plugins/maxim/main.go
Normal file
@@ -0,0 +1,688 @@
|
||||
// Package maxim provides integration for Maxim's SDK as a Bifrost plugin.
|
||||
// This file contains the main plugin implementation.
|
||||
package maxim
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/google/uuid"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/streaming"
|
||||
|
||||
"github.com/maximhq/maxim-go"
|
||||
"github.com/maximhq/maxim-go/logging"
|
||||
maximSchemas "github.com/maximhq/maxim-go/schemas"
|
||||
)
|
||||
|
||||
// Identifiers shared by the maxim plugin.
const (
	// PluginName is the canonical name for the maxim plugin; GetName returns it.
	PluginName string = "maxim"
	// PluginLoggerPrefix is placed in front of log messages emitted by this plugin.
	PluginLoggerPrefix string = "[Maxim Plugin]"
)
|
||||
|
||||
// Config is the configuration for the maxim plugin.
type Config struct {
	// LogRepoID is the optional default log repository ID. It may be empty, in
	// which case a log repo ID has to come from the request context.
	LogRepoID string `json:"log_repo_id,omitempty"`
	// APIKey is the API key handed to the Maxim SDK. Init rejects an empty value.
	APIKey string `json:"api_key"`
}
|
||||
|
||||
// Plugin implements the schemas.LLMPlugin interface for Maxim's logger.
|
||||
// It provides request and response tracing functionality using Maxim logger,
|
||||
// allowing detailed tracking of requests and responses across different log repositories.
|
||||
//
|
||||
// Fields:
|
||||
// - mx: The Maxim SDK instance for creating new loggers
|
||||
// - defaultLogRepoId: Default log repository ID from config (optional)
|
||||
// - loggers: Map of log repo ID to logger instances
|
||||
// - loggerMutex: RW mutex for thread-safe access to loggers map
|
||||
type Plugin struct {
|
||||
mx *maxim.Maxim
|
||||
defaultLogRepoID string
|
||||
loggers map[string]*logging.Logger
|
||||
loggerMutex *sync.RWMutex
|
||||
logger schemas.Logger
|
||||
}
|
||||
|
||||
// Init initializes and returns a Plugin instance for Maxim's logger.
|
||||
//
|
||||
// Parameters:
|
||||
// - config: Configuration for the maxim plugin
|
||||
//
|
||||
// Returns:
|
||||
// - schemas.LLMPlugin: A configured plugin instance for request/response tracing
|
||||
// - error: Any error that occurred during plugin initialization
|
||||
func Init(config *Config, logger schemas.Logger) (schemas.LLMPlugin, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
// check if Maxim Logger variables are set
|
||||
if config.APIKey == "" {
|
||||
return nil, fmt.Errorf("apiKey is not set")
|
||||
}
|
||||
|
||||
mx := maxim.Init(&maxim.MaximSDKConfig{ApiKey: config.APIKey})
|
||||
|
||||
plugin := &Plugin{
|
||||
mx: mx,
|
||||
defaultLogRepoID: config.LogRepoID,
|
||||
loggers: make(map[string]*logging.Logger),
|
||||
loggerMutex: &sync.RWMutex{},
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Initialize default logger if LogRepoId is provided
|
||||
if config.LogRepoID != "" {
|
||||
logger, err := mx.GetLogger(&logging.LoggerConfig{Id: config.LogRepoID})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize default logger: %w", err)
|
||||
}
|
||||
plugin.loggers[config.LogRepoID] = logger
|
||||
}
|
||||
|
||||
return plugin, nil
|
||||
}
|
||||
|
||||
// TraceIDKey is the context key used to store and retrieve trace IDs.
|
||||
// This constant provides a consistent key for tracking request traces
|
||||
// throughout the request/response lifecycle.
|
||||
const (
|
||||
SessionIDKey schemas.BifrostContextKey = "session-id"
|
||||
TraceIDKey schemas.BifrostContextKey = "trace-id"
|
||||
TraceNameKey schemas.BifrostContextKey = "trace-name"
|
||||
GenerationIDKey schemas.BifrostContextKey = "generation-id"
|
||||
GenerationNameKey schemas.BifrostContextKey = "generation-name"
|
||||
TagsKey schemas.BifrostContextKey = "maxim-tags"
|
||||
LogRepoIDKey schemas.BifrostContextKey = "log-repo-id"
|
||||
)
|
||||
|
||||
// convertAccResultToProcessedStreamResponse converts StreamAccumulatorResult to ProcessedStreamResponse
|
||||
func convertAccResultToProcessedStreamResponse(accResult *schemas.StreamAccumulatorResult) *streaming.ProcessedStreamResponse {
|
||||
if accResult == nil {
|
||||
return nil
|
||||
}
|
||||
// Determine StreamType based on the response content
|
||||
streamType := streaming.StreamTypeChat
|
||||
if accResult.AudioOutput != nil {
|
||||
streamType = streaming.StreamTypeAudio
|
||||
} else if accResult.TranscriptionOutput != nil {
|
||||
streamType = streaming.StreamTypeTranscription
|
||||
} else if len(accResult.OutputMessages) > 0 {
|
||||
streamType = streaming.StreamTypeResponses
|
||||
} else if accResult.ImageGenerationOutput != nil {
|
||||
streamType = streaming.StreamTypeImage
|
||||
}
|
||||
return &streaming.ProcessedStreamResponse{
|
||||
RequestID: accResult.RequestID,
|
||||
StreamType: streamType,
|
||||
RequestedModel: accResult.RequestedModel,
|
||||
ResolvedModel: accResult.ResolvedModel,
|
||||
Provider: accResult.Provider,
|
||||
Data: &streaming.AccumulatedData{
|
||||
Status: accResult.Status,
|
||||
Latency: accResult.Latency,
|
||||
TimeToFirstToken: accResult.TimeToFirstToken,
|
||||
OutputMessage: accResult.OutputMessage,
|
||||
OutputMessages: accResult.OutputMessages,
|
||||
TokenUsage: accResult.TokenUsage,
|
||||
Cost: accResult.Cost,
|
||||
ErrorDetails: accResult.ErrorDetails,
|
||||
AudioOutput: accResult.AudioOutput,
|
||||
TranscriptionOutput: accResult.TranscriptionOutput,
|
||||
FinishReason: accResult.FinishReason,
|
||||
RawResponse: accResult.RawResponse,
|
||||
},
|
||||
RawRequest: &accResult.RawRequest,
|
||||
}
|
||||
}
|
||||
|
||||
// The plugin provides request/response tracing functionality by integrating with Maxim's logging system.
|
||||
// It supports both chat completion and text completion requests, tracking the entire lifecycle of each request
|
||||
// including inputs, parameters, and responses.
|
||||
//
|
||||
// Key Features:
|
||||
// - Automatic trace and generation ID management
|
||||
// - Support for both chat and text completion requests
|
||||
// - Contextual tracking across request lifecycle
|
||||
// - Graceful handling of existing trace/generation IDs
|
||||
//
|
||||
// The plugin uses context values to maintain trace and generation IDs throughout the request lifecycle.
|
||||
// These IDs can be propagated from external systems through HTTP headers (x-bf-maxim-trace-id and x-bf-maxim-generation-id).
|
||||
|
||||
// GetName returns the name of the plugin.
|
||||
func (plugin *Plugin) GetName() string {
|
||||
return PluginName
|
||||
}
|
||||
|
||||
// HTTPTransportPreHook is not used for this plugin
|
||||
func (plugin *Plugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// HTTPTransportPostHook is not used for this plugin
|
||||
func (plugin *Plugin) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// HTTPTransportStreamChunkHook passes through streaming chunks unchanged
|
||||
func (plugin *Plugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) {
|
||||
return chunk, nil
|
||||
}
|
||||
|
||||
// getEffectiveLogRepoID determines which single log repo ID to use based on priority:
|
||||
// 1. Header log repo ID (if provided)
|
||||
// 2. Default log repo ID from config (if configured)
|
||||
// 3. Empty string (skip logging)
|
||||
func (plugin *Plugin) getEffectiveLogRepoID(ctx *schemas.BifrostContext) string {
|
||||
// Check for header log repo ID first (highest priority)
|
||||
if ctx != nil {
|
||||
if headerRepoID, ok := ctx.Value(LogRepoIDKey).(string); ok && headerRepoID != "" {
|
||||
return headerRepoID
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default log repo ID from config
|
||||
if plugin.defaultLogRepoID != "" {
|
||||
return plugin.defaultLogRepoID
|
||||
}
|
||||
|
||||
// Return empty string if neither header nor default is available
|
||||
return ""
|
||||
}
|
||||
|
||||
// getOrCreateLogger gets an existing logger or creates a new one for the given log repo ID
|
||||
func (plugin *Plugin) getOrCreateLogger(logRepoID string) (*logging.Logger, error) {
|
||||
// First, try to get existing logger (read lock)
|
||||
plugin.loggerMutex.RLock()
|
||||
if logger, exists := plugin.loggers[logRepoID]; exists {
|
||||
plugin.loggerMutex.RUnlock()
|
||||
return logger, nil
|
||||
}
|
||||
plugin.loggerMutex.RUnlock()
|
||||
|
||||
// Logger doesn't exist, create it (write lock)
|
||||
plugin.loggerMutex.Lock()
|
||||
defer plugin.loggerMutex.Unlock()
|
||||
|
||||
// Double-check in case another goroutine created it while we were waiting
|
||||
if logger, exists := plugin.loggers[logRepoID]; exists {
|
||||
return logger, nil
|
||||
}
|
||||
|
||||
// Create new logger
|
||||
logger, err := plugin.mx.GetLogger(&logging.LoggerConfig{Id: logRepoID})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create logger for repo ID %s: %w", logRepoID, err)
|
||||
}
|
||||
|
||||
plugin.loggers[logRepoID] = logger
|
||||
return logger, nil
|
||||
}
|
||||
|
||||
// PreLLMHook is called before a request is processed by Bifrost.
|
||||
// It manages trace and generation tracking for incoming requests by either:
|
||||
// - Creating a new trace if none exists
|
||||
// - Reusing an existing trace ID from the context
|
||||
// - Creating a new generation within an existing trace
|
||||
// - Skipping trace/generation creation if they already exist
|
||||
//
|
||||
// The function handles both chat completion and text completion requests,
|
||||
// capturing relevant metadata such as:
|
||||
// - Request type (chat/text completion)
|
||||
// - Model information
|
||||
// - Message content and role
|
||||
// - Model parameters
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Pointer to the schemas.BifrostContext that may contain existing trace/generation IDs
|
||||
// - req: The incoming Bifrost request to be traced
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.BifrostRequest: The original request, unmodified
|
||||
// - error: Any error that occurred during trace/generation creation
|
||||
func (plugin *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) {
|
||||
if req != nil && req.RequestType == schemas.RealtimeRequest {
|
||||
return req, nil, nil
|
||||
}
|
||||
|
||||
var traceID string
|
||||
var traceName string
|
||||
var sessionID string
|
||||
var generationName string
|
||||
|
||||
// Get effective log repo ID (header > default > skip)
|
||||
effectiveLogRepoID := plugin.getEffectiveLogRepoID(ctx)
|
||||
|
||||
// If no log repo ID available, skip logging
|
||||
if effectiveLogRepoID == "" {
|
||||
return req, nil, nil
|
||||
}
|
||||
|
||||
// Check if context already has traceID and generationID
|
||||
if ctx != nil {
|
||||
if existingGenerationID, ok := ctx.Value(GenerationIDKey).(string); ok && existingGenerationID != "" {
|
||||
// If generationID exists, return early
|
||||
return req, nil, nil
|
||||
}
|
||||
|
||||
if existingTraceID, ok := ctx.Value(TraceIDKey).(string); ok && existingTraceID != "" {
|
||||
// If traceID exists, and no generationID, create a new generation on the trace
|
||||
traceID = existingTraceID
|
||||
}
|
||||
|
||||
if existingSessionID, ok := ctx.Value(SessionIDKey).(string); ok && existingSessionID != "" {
|
||||
sessionID = existingSessionID
|
||||
}
|
||||
|
||||
if existingTraceName, ok := ctx.Value(TraceNameKey).(string); ok && existingTraceName != "" {
|
||||
traceName = existingTraceName
|
||||
}
|
||||
|
||||
if existingGenerationName, ok := ctx.Value(GenerationNameKey).(string); ok && existingGenerationName != "" {
|
||||
generationName = existingGenerationName
|
||||
}
|
||||
}
|
||||
|
||||
provider, model, _ := req.GetRequestFields()
|
||||
|
||||
// Determine request type and set appropriate tags
|
||||
var messages []maximSchemas.CompletionRequest
|
||||
var latestMessage string
|
||||
|
||||
modelParams := make(map[string]interface{})
|
||||
|
||||
switch req.RequestType {
|
||||
case schemas.TextCompletionRequest, schemas.TextCompletionStreamRequest:
|
||||
messages = append(messages, maximSchemas.CompletionRequest{
|
||||
Role: string(schemas.ChatMessageRoleUser),
|
||||
Content: req.TextCompletionRequest.Input,
|
||||
})
|
||||
if req.TextCompletionRequest.Input.PromptStr != nil {
|
||||
latestMessage = *req.TextCompletionRequest.Input.PromptStr
|
||||
} else {
|
||||
var stringBuilder strings.Builder
|
||||
for _, prompt := range req.TextCompletionRequest.Input.PromptArray {
|
||||
stringBuilder.WriteString(prompt)
|
||||
}
|
||||
latestMessage = stringBuilder.String()
|
||||
}
|
||||
|
||||
if req.TextCompletionRequest.Params != nil {
|
||||
// Convert the struct to a map using reflection or JSON marshaling
|
||||
jsonData, err := sonic.Marshal(req.TextCompletionRequest.Params)
|
||||
if err == nil {
|
||||
sonic.Unmarshal(jsonData, &modelParams)
|
||||
}
|
||||
}
|
||||
case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest:
|
||||
for _, message := range req.ChatRequest.Input {
|
||||
messages = append(messages, maximSchemas.CompletionRequest{
|
||||
Role: string(message.Role),
|
||||
Content: message.Content,
|
||||
})
|
||||
}
|
||||
if len(req.ChatRequest.Input) > 0 {
|
||||
lastMsg := req.ChatRequest.Input[len(req.ChatRequest.Input)-1]
|
||||
if lastMsg.Content.ContentStr != nil {
|
||||
latestMessage = *lastMsg.Content.ContentStr
|
||||
} else if lastMsg.Content.ContentBlocks != nil {
|
||||
// Find the last text content block
|
||||
for i := len(lastMsg.Content.ContentBlocks) - 1; i >= 0; i-- {
|
||||
block := (lastMsg.Content.ContentBlocks)[i]
|
||||
if block.Type == schemas.ChatContentBlockTypeText && block.Text != nil {
|
||||
latestMessage = *block.Text
|
||||
break
|
||||
}
|
||||
}
|
||||
// If no text block found, use placeholder
|
||||
if latestMessage == "" {
|
||||
latestMessage = "-"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if req.ChatRequest.Params != nil {
|
||||
// Convert the struct to a map using reflection or JSON marshaling
|
||||
jsonData, err := sonic.Marshal(req.ChatRequest.Params)
|
||||
if err == nil {
|
||||
sonic.Unmarshal(jsonData, &modelParams)
|
||||
}
|
||||
}
|
||||
case schemas.ResponsesRequest, schemas.ResponsesStreamRequest, schemas.WebSocketResponsesRequest:
|
||||
for _, message := range req.ResponsesRequest.Input {
|
||||
if message.Content != nil {
|
||||
role := schemas.ChatMessageRoleUser
|
||||
if message.Role != nil {
|
||||
role = schemas.ChatMessageRole(*message.Role)
|
||||
}
|
||||
messages = append(messages, maximSchemas.CompletionRequest{
|
||||
Role: string(role),
|
||||
Content: message.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
if len(req.ResponsesRequest.Input) > 0 {
|
||||
lastMsg := req.ResponsesRequest.Input[len(req.ResponsesRequest.Input)-1]
|
||||
// Initialize to placeholder in case content is missing or empty
|
||||
latestMessage = "-"
|
||||
|
||||
// Check if Content is nil before accessing its fields
|
||||
if lastMsg.Content != nil {
|
||||
if lastMsg.Content.ContentStr != nil {
|
||||
latestMessage = *lastMsg.Content.ContentStr
|
||||
} else if lastMsg.Content.ContentBlocks != nil {
|
||||
// Find the last text content block
|
||||
for i := len(lastMsg.Content.ContentBlocks) - 1; i >= 0; i-- {
|
||||
block := (lastMsg.Content.ContentBlocks)[i]
|
||||
if block.Text != nil {
|
||||
latestMessage = *block.Text
|
||||
break
|
||||
}
|
||||
}
|
||||
// If no text block found, keep the placeholder
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if req.ResponsesRequest.Params != nil {
|
||||
// Convert the struct to a map using reflection or JSON marshaling
|
||||
jsonData, err := sonic.Marshal(req.ResponsesRequest.Params)
|
||||
if err == nil {
|
||||
sonic.Unmarshal(jsonData, &modelParams)
|
||||
}
|
||||
}
|
||||
case schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest:
|
||||
if req.ImageGenerationRequest == nil || req.ImageGenerationRequest.Input == nil {
|
||||
break
|
||||
}
|
||||
messages = append(messages, maximSchemas.CompletionRequest{
|
||||
Role: string(schemas.ChatMessageRoleUser),
|
||||
Content: req.ImageGenerationRequest.Input.Prompt,
|
||||
})
|
||||
latestMessage = req.ImageGenerationRequest.Input.Prompt
|
||||
if req.ImageGenerationRequest.Params != nil {
|
||||
jsonData, err := sonic.Marshal(req.ImageGenerationRequest.Params)
|
||||
if err == nil {
|
||||
sonic.Unmarshal(jsonData, &modelParams)
|
||||
}
|
||||
}
|
||||
case schemas.ImageEditRequest, schemas.ImageEditStreamRequest:
|
||||
if req.ImageEditRequest == nil || req.ImageEditRequest.Input == nil {
|
||||
break
|
||||
}
|
||||
messages = append(messages, maximSchemas.CompletionRequest{
|
||||
Role: string(schemas.ChatMessageRoleUser),
|
||||
Content: req.ImageEditRequest.Input.Prompt,
|
||||
})
|
||||
latestMessage = req.ImageEditRequest.Input.Prompt
|
||||
if req.ImageEditRequest.Params != nil {
|
||||
jsonData, err := sonic.Marshal(req.ImageEditRequest.Params)
|
||||
if err == nil {
|
||||
sonic.Unmarshal(jsonData, &modelParams)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if traceID == "" {
|
||||
// If traceID is not set, create a new trace
|
||||
traceID = uuid.New().String()
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("bifrost_%s", string(req.RequestType))
|
||||
if traceName != "" {
|
||||
name = traceName
|
||||
}
|
||||
|
||||
traceConfig := logging.TraceConfig{
|
||||
Id: traceID,
|
||||
Name: maxim.StrPtr(name),
|
||||
}
|
||||
|
||||
if sessionID != "" {
|
||||
traceConfig.SessionId = &sessionID
|
||||
}
|
||||
|
||||
// Create trace in the effective log repository
|
||||
logger, err := plugin.getOrCreateLogger(effectiveLogRepoID)
|
||||
if err != nil {
|
||||
return req, nil, fmt.Errorf("failed to create trace: %w", err)
|
||||
}
|
||||
|
||||
trace := logger.Trace(&traceConfig)
|
||||
trace.SetInput(latestMessage)
|
||||
generationID := uuid.New().String()
|
||||
|
||||
generationConfig := logging.GenerationConfig{
|
||||
Id: generationID,
|
||||
Model: model,
|
||||
Provider: string(provider),
|
||||
Messages: messages,
|
||||
ModelParameters: modelParams,
|
||||
}
|
||||
|
||||
if generationName != "" {
|
||||
generationConfig.Name = &generationName
|
||||
}
|
||||
|
||||
// Add generation to the effective log repository
|
||||
logger.AddGenerationToTrace(traceID, &generationConfig)
|
||||
|
||||
// Extract and log attachments from message content
|
||||
for _, att := range ExtractAttachmentsFromRequest(req) {
|
||||
if att != nil {
|
||||
logger.GenerationAddAttachment(generationID, att)
|
||||
}
|
||||
}
|
||||
|
||||
if ctx != nil {
|
||||
if _, ok := ctx.Value(TraceIDKey).(string); !ok {
|
||||
ctx.SetValue(TraceIDKey, traceID)
|
||||
}
|
||||
ctx.SetValue(GenerationIDKey, generationID)
|
||||
|
||||
// Extract request ID from context, if not present, create a new one
|
||||
requestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string)
|
||||
if !ok || requestID == "" {
|
||||
// This should never happen since core/bifrost.go guarantees it's set before PreHooks
|
||||
requestID = uuid.New().String()
|
||||
plugin.logger.Warn("%s request ID missing in PreLLMHook, using fallback: %s", PluginLoggerPrefix, requestID)
|
||||
}
|
||||
|
||||
// If streaming, create accumulator via central tracer using traceID
|
||||
if bifrost.IsStreamRequestType(req.RequestType) {
|
||||
tracer, bifrostTraceID, err := bifrost.GetTracerFromContext(ctx)
|
||||
if err == nil && tracer != nil && bifrostTraceID != "" {
|
||||
tracer.CreateStreamAccumulator(bifrostTraceID, time.Now())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil, nil
|
||||
}
|
||||
|
||||
// PostLLMHook is called after a request has been processed by Bifrost.
|
||||
// It completes the request trace by:
|
||||
// - Adding response data to the generation if a generation ID exists
|
||||
// - Logging error details if bifrostErr is provided
|
||||
// - Ending the generation if it exists
|
||||
// - Ending the trace if a trace ID exists
|
||||
// - Flushing all pending log data
|
||||
//
|
||||
// The function gracefully handles cases where trace or generation IDs may be missing,
|
||||
// ensuring that partial logging is still performed when possible.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Pointer to the schemas.BifrostContext containing trace/generation IDs
|
||||
// - result: The Bifrost response to be traced
|
||||
// - bifrostErr: The BifrostError returned by the request, if any
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.BifrostResponse: The original response, unmodified
|
||||
// - *schemas.BifrostError: The original error, unmodified
|
||||
// - error: Never returns an error as it handles missing IDs gracefully
|
||||
func (plugin *Plugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) {
|
||||
requestType, _, _, _ := bifrost.GetResponseFields(result, bifrostErr)
|
||||
if requestType == schemas.RealtimeRequest {
|
||||
return result, bifrostErr, nil
|
||||
}
|
||||
|
||||
// Get effective log repo ID for this request
|
||||
effectiveLogRepoID := plugin.getEffectiveLogRepoID(ctx)
|
||||
if effectiveLogRepoID == "" {
|
||||
return result, bifrostErr, nil
|
||||
}
|
||||
if ctx == nil {
|
||||
return result, bifrostErr, nil
|
||||
}
|
||||
|
||||
requestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string)
|
||||
if !ok || requestID == "" {
|
||||
return result, bifrostErr, nil
|
||||
}
|
||||
|
||||
// Capture context values BEFORE goroutine to avoid race conditions
|
||||
// when the same context is reused across multiple requests
|
||||
generationID, hasGenerationID := ctx.Value(GenerationIDKey).(string)
|
||||
traceID, hasTraceID := ctx.Value(TraceIDKey).(string)
|
||||
tags, hasTags := ctx.Value(TagsKey).(map[string]string)
|
||||
|
||||
isFinalChunk := bifrost.IsFinalChunk(ctx)
|
||||
|
||||
go func() {
|
||||
requestType, _, originalModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr)
|
||||
modelTag := resolvedModel
|
||||
if modelTag == "" {
|
||||
modelTag = originalModel
|
||||
}
|
||||
|
||||
var streamResponse *streaming.ProcessedStreamResponse
|
||||
if bifrost.IsStreamRequestType(requestType) {
|
||||
// Use central tracer's accumulator
|
||||
tracer, bifrostTraceID, err := bifrost.GetTracerFromContext(ctx)
|
||||
if err == nil && tracer != nil && bifrostTraceID != "" {
|
||||
accResult := tracer.ProcessStreamingChunk(bifrostTraceID, isFinalChunk, result, bifrostErr)
|
||||
if accResult != nil {
|
||||
streamResponse = convertAccResultToProcessedStreamResponse(accResult)
|
||||
}
|
||||
}
|
||||
|
||||
// For streaming: only process on final chunk. Skip intermediate chunks.
|
||||
// When there's an error, streamResponse may be nil but we must still log bifrostErr.
|
||||
if !isFinalChunk {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
logger, err := plugin.getOrCreateLogger(effectiveLogRepoID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if hasGenerationID {
|
||||
if bifrostErr != nil {
|
||||
// Safely extract message from nested error
|
||||
message := ""
|
||||
code := ""
|
||||
errorType := ""
|
||||
if bifrostErr.Error != nil {
|
||||
message = bifrostErr.Error.Message
|
||||
if bifrostErr.Error.Code != nil {
|
||||
code = *bifrostErr.Error.Code
|
||||
}
|
||||
if bifrostErr.Error.Type != nil {
|
||||
errorType = *bifrostErr.Error.Type
|
||||
}
|
||||
}
|
||||
genErr := maximSchemas.GenerationError{
|
||||
Message: message,
|
||||
Code: &code,
|
||||
Type: &errorType,
|
||||
}
|
||||
logger.SetGenerationError(generationID, &genErr)
|
||||
|
||||
if bifrost.IsStreamRequestType(requestType) {
|
||||
// Cleanup via central tracer
|
||||
tracer, bifrostTraceID, err := bifrost.GetTracerFromContext(ctx)
|
||||
if err == nil && tracer != nil && bifrostTraceID != "" {
|
||||
tracer.CleanupStreamAccumulator(bifrostTraceID)
|
||||
}
|
||||
}
|
||||
} else if result != nil {
|
||||
switch requestType {
|
||||
case schemas.TextCompletionRequest, schemas.TextCompletionStreamRequest:
|
||||
if streamResponse != nil {
|
||||
logger.AddResultToGeneration(generationID, streamResponse.ToBifrostResponse().TextCompletionResponse)
|
||||
} else {
|
||||
logger.AddResultToGeneration(generationID, result.TextCompletionResponse)
|
||||
}
|
||||
case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest:
|
||||
if streamResponse != nil {
|
||||
logger.AddResultToGeneration(generationID, streamResponse.ToBifrostResponse().ChatResponse)
|
||||
} else {
|
||||
logger.AddResultToGeneration(generationID, result.ChatResponse)
|
||||
}
|
||||
case schemas.ResponsesRequest, schemas.ResponsesStreamRequest, schemas.WebSocketResponsesRequest:
|
||||
if streamResponse != nil {
|
||||
logger.AddResultToGeneration(generationID, streamResponse.ToBifrostResponse().ResponsesResponse)
|
||||
} else {
|
||||
logger.AddResultToGeneration(generationID, result.ResponsesResponse)
|
||||
}
|
||||
case schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest,
|
||||
schemas.ImageEditRequest, schemas.ImageEditStreamRequest:
|
||||
if streamResponse != nil {
|
||||
logger.AddResultToGeneration(generationID, streamResponse.ToBifrostResponse().ImageGenerationResponse)
|
||||
} else if result != nil {
|
||||
logger.AddResultToGeneration(generationID, result.ImageGenerationResponse)
|
||||
}
|
||||
}
|
||||
if streamResponse != nil && isFinalChunk {
|
||||
// Cleanup via central tracer
|
||||
tracer, bifrostTraceID, err := bifrost.GetTracerFromContext(ctx)
|
||||
if err == nil && tracer != nil && bifrostTraceID != "" {
|
||||
tracer.CleanupStreamAccumulator(bifrostTraceID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if hasTraceID {
|
||||
logger.EndTrace(traceID)
|
||||
}
|
||||
|
||||
// add tags to the generation and trace
|
||||
if hasTags {
|
||||
for key, value := range tags {
|
||||
if generationID != "" {
|
||||
logger.AddTagToGeneration(generationID, key, value)
|
||||
}
|
||||
if traceID != "" {
|
||||
logger.AddTagToTrace(traceID, key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
if hasGenerationID && generationID != "" && modelTag != "" {
|
||||
logger.AddTagToGeneration(generationID, "model", string(modelTag))
|
||||
}
|
||||
if hasTraceID && traceID != "" && modelTag != "" {
|
||||
logger.AddTagToTrace(traceID, "model", string(modelTag))
|
||||
}
|
||||
// Flush only the effective logger that was used for this request
|
||||
logger.Flush()
|
||||
}()
|
||||
return result, bifrostErr, nil
|
||||
}
|
||||
|
||||
func (plugin *Plugin) Cleanup() error {
|
||||
// Flush all loggers
|
||||
plugin.loggerMutex.RLock()
|
||||
for _, logger := range plugin.loggers {
|
||||
logger.Flush()
|
||||
}
|
||||
plugin.loggerMutex.RUnlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user