Files
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

263 lines
7.6 KiB
Go

package jsonparser
import (
"strings"
"sync"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
const (
PluginName = "streaming-json-parser"
)
type Usage string
const (
AllRequests Usage = "all_requests"
PerRequest Usage = "per_request"
)
// AccumulatedContent holds both the content and timestamp for a request
type AccumulatedContent struct {
Content *strings.Builder
Timestamp time.Time
}
// JsonParserPlugin provides JSON parsing capabilities for streaming responses
// It handles partial JSON chunks by accumulating them and making the accumulated content valid JSON
type JsonParserPlugin struct {
usage Usage
// State management for accumulating chunks
accumulatedContent map[string]*AccumulatedContent // requestID -> accumulated content with timestamp
mutex sync.RWMutex
// Cleanup configuration
cleanupInterval time.Duration
maxAge time.Duration
stopCleanup chan struct{}
stopOnce sync.Once
}
// PluginConfig holds configuration options for the JSON parser plugin
type PluginConfig struct {
Usage Usage
CleanupInterval time.Duration
MaxAge time.Duration
}
const (
EnableStreamingJSONParser schemas.BifrostContextKey = "enable-streaming-json-parser"
)
// Init creates a new JSON parser plugin instance with custom configuration
func Init(config PluginConfig) (*JsonParserPlugin, error) {
// Set defaults if not provided
if config.CleanupInterval <= 0 {
config.CleanupInterval = 5 * time.Minute
}
if config.MaxAge <= 0 {
config.MaxAge = 30 * time.Minute
}
if config.Usage == "" {
config.Usage = PerRequest
}
plugin := &JsonParserPlugin{
usage: config.Usage,
accumulatedContent: make(map[string]*AccumulatedContent),
cleanupInterval: config.CleanupInterval,
maxAge: config.MaxAge,
stopCleanup: make(chan struct{}),
}
// Start the cleanup goroutine
go plugin.startCleanupGoroutine()
return plugin, nil
}
// GetName returns the plugin name
func (p *JsonParserPlugin) GetName() string {
return PluginName
}
// HTTPTransportPreHook is not used for this plugin
func (p *JsonParserPlugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) {
return nil, nil
}
// HTTPTransportPostHook is not used for this plugin
func (p *JsonParserPlugin) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error {
return nil
}
// HTTPTransportStreamChunkHook passes through streaming chunks unchanged
func (p *JsonParserPlugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) {
return chunk, nil
}
// PreLLMHook is not used for this plugin as we only process responses
// Parameters:
// - ctx: The Bifrost context
// - req: The Bifrost request
//
// Returns:
// - *schemas.BifrostRequest: The processed request
// - *schemas.LLMPluginShortCircuit: The plugin short circuit if the request is not allowed
// - error: Any error that occurred during processing
func (p *JsonParserPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) {
return req, nil, nil
}
// PostLLMHook processes streaming responses by accumulating chunks and making accumulated content valid JSON
// Parameters:
// - ctx: The Bifrost context
// - result: The Bifrost response to be processed
// - err: The Bifrost error to be processed
//
// Returns:
// - *schemas.BifrostResponse: The processed response
// - *schemas.BifrostError: The processed error
// - error: Any error that occurred during processing
func (p *JsonParserPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) {
// If there's an error, don't process
if err != nil {
return result, err, nil
}
extraFields := result.GetExtraFields()
// Check if plugin should run based on usage type
if !p.shouldRun(ctx, extraFields.RequestType) {
return result, err, nil
}
// If no chat response, return as is
if result == nil || result.ChatResponse == nil {
return result, err, nil
}
// Get request ID for state management, if it's not set, return as is
requestID := p.getRequestID(ctx, result)
if requestID == "" {
return result, err, nil
}
// Create a deep copy of the result to avoid modifying the original pointer
// This ensures other plugins using the same pointer don't get corrupted data
resultCopy := p.deepCopyBifrostResponse(result)
if resultCopy == nil || resultCopy.ChatResponse == nil {
return result, err, nil
}
// Process only streaming choices to accumulate and fix partial JSON
if len(resultCopy.ChatResponse.Choices) > 0 {
for i := range resultCopy.ChatResponse.Choices {
choice := &resultCopy.ChatResponse.Choices[i]
// Handle only streaming response
if choice.ChatStreamResponseChoice != nil {
if choice.ChatStreamResponseChoice.Delta.Content != nil {
content := *choice.ChatStreamResponseChoice.Delta.Content
if content != "" {
// Accumulate the content
accumulated := p.accumulateContent(requestID, content)
// Process the accumulated content to make it valid JSON
fixedContent := p.parsePartialJSON(accumulated)
if !p.isValidJSON(fixedContent) {
err = &schemas.BifrostError{
Error: &schemas.ErrorField{
Message: "Invalid JSON in streaming response",
},
StreamControl: &schemas.StreamControl{
SkipStream: bifrost.Ptr(true),
},
}
return nil, err, nil
}
// Replace the delta content with the complete valid JSON
choice.ChatStreamResponseChoice.Delta.Content = &fixedContent
}
}
}
}
}
// If this is the final chunk, cleanup the accumulated content for this request
if streamEndIndicatorValue := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator); streamEndIndicatorValue != nil {
isFinalChunk, ok := streamEndIndicatorValue.(bool)
if ok && isFinalChunk {
p.ClearRequestState(requestID)
}
}
// Return the modified copy instead of the original
return resultCopy, err, nil
}
// Cleanup performs plugin cleanup and clears accumulated content
func (p *JsonParserPlugin) Cleanup() error {
// Stop the cleanup goroutine
p.StopCleanup()
p.mutex.Lock()
defer p.mutex.Unlock()
// Clear accumulated content
p.accumulatedContent = make(map[string]*AccumulatedContent)
return nil
}
// ClearRequestState clears the accumulated content for a specific request
func (p *JsonParserPlugin) ClearRequestState(requestID string) {
p.mutex.Lock()
defer p.mutex.Unlock()
delete(p.accumulatedContent, requestID)
}
// CLEANUP METHODS
// startCleanupGoroutine starts a goroutine that periodically cleans up old accumulated content
func (p *JsonParserPlugin) startCleanupGoroutine() {
ticker := time.NewTicker(p.cleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
p.cleanupOldEntries()
case <-p.stopCleanup:
return
}
}
}
// cleanupOldEntries removes accumulated content entries that are older than maxAge
func (p *JsonParserPlugin) cleanupOldEntries() {
p.mutex.Lock()
defer p.mutex.Unlock()
now := time.Now()
cutoff := now.Add(-p.maxAge)
for requestID, content := range p.accumulatedContent {
if content.Timestamp.Before(cutoff) {
delete(p.accumulatedContent, requestID)
}
}
}
// StopCleanup stops the cleanup goroutine
func (p *JsonParserPlugin) StopCleanup() {
p.stopOnce.Do(func() {
close(p.stopCleanup)
})
}