// Package prompts implements the Bifrost LLM plugin that resolves stored prompt templates
// from the config store and prepends their messages to chat and Responses API requests.
// HTTP clients select a prompt via x-bf-prompt-id / x-bf-prompt-version headers; optional
// custom PromptResolver implementations can override how ID and version are chosen.
package prompts

import (
	"context"
	"encoding/json"
	"fmt"
	"maps"
	"strconv"
	"strings"
	"sync"

	bifrost "github.com/maximhq/bifrost/core"
	"github.com/maximhq/bifrost/core/schemas"
	configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
)

const (
	// PluginName is the canonical name registered for the prompts plugin.
	PluginName = "prompts"

	// PromptIDHeader and PromptVersionHeader are request headers copied into BifrostContext
	// in HTTPTransportPreHook so PreLLMHook and custom resolvers can read them.
	PromptIDHeader      = "x-bf-prompt-id"
	PromptVersionHeader = "x-bf-prompt-version"

	// PromptIDKey and PromptVersionKey are context keys for the resolved header values.
	// They reuse the header names as their underlying string values.
	PromptIDKey      schemas.BifrostContextKey = PromptIDHeader
	PromptVersionKey schemas.BifrostContextKey = PromptVersionHeader
)

// InMemoryStore is the data source for prompts and all versions. Implementations typically
// wrap the framework config store; the plugin keeps an in-memory index built by loadCache.
type InMemoryStore interface {
	// GetPrompts returns stored prompts. loadCache calls it with a nil folderID to build
	// the full prompt index.
	GetPrompts(ctx context.Context, folderID *string) ([]configstoreTables.TablePrompt, error)
	// GetAllPromptVersions returns every stored prompt version across all prompts.
	GetAllPromptVersions(ctx context.Context) ([]configstoreTables.TablePromptVersion, error)
}

// PromptResolver decides which prompt and version to inject for a given request.
// Returning an empty promptID means no injection for this request.
type PromptResolver interface {
	// Resolve reports the prompt ID and version number to apply to req.
	Resolve(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (promptID string, versionNumber int, err error)
}

// headerResolver is the default OSS resolver. It reads the prompt ID and version from the
// context keys that HTTPTransportPreHook fills from the x-bf-prompt-id and
// x-bf-prompt-version request headers.
type headerResolver struct {
	logger schemas.Logger
}

// Resolve looks up the prompt ID and version number stored in ctx.
//
// Returns:
//   - string: the prompt ID; empty means "do not inject a prompt for this request"
//   - int: the requested version number; 0 means "use latest" when handed to resolveVersion
//   - error: non-nil only when a version value is present but cannot be parsed as an integer
func (r *headerResolver) Resolve(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (string, int, error) {
	id := bifrost.GetStringFromContext(ctx, PromptIDKey)
	if id == "" {
		// No prompt selected: the version value is never inspected.
		return "", 0, nil
	}

	version, parseErr := parseNumberFromContext(ctx, PromptVersionKey)
	if parseErr != nil {
		return "", 0, fmt.Errorf("failed to parse version number: %w", parseErr)
	}
	return id, version, nil
}

// Plugin implements schemas.LLMPlugin (and HTTP transport hooks) for server-side prompt injection.
// It loads prompts and versions into memory, resolves which version to use per request, merges
// the version's model parameters with the client request (request wins), and prepends template
// messages before chat or Responses input.
// // Fields: // - store: backing persistence for prompts and versions // - logger: Bifrost logger for non-fatal merge/param warnings // - resolver: chooses prompt ID and version; defaults to headerResolver // - mu: protects promptsByID and versionsByPromptAndNumber // - promptsByID: prompt ID → prompt row (includes LatestVersion when using “latest”) // - versionsByPromptAndNumber: prompt ID → version number → version row type Plugin struct { store InMemoryStore logger schemas.Logger resolver PromptResolver mu sync.RWMutex promptsByID map[string]*configstoreTables.TablePrompt versionsByPromptAndNumber map[string]map[int]*configstoreTables.TablePromptVersion } // Init constructs a Plugin using the default header-based resolver (x-bf-prompt-id / x-bf-prompt-version). // // Parameters: // - ctx: used for the initial loadCache call // - store: required config store backend for prompts // - logger: used by the default resolver and param merge paths // // Returns: // - schemas.LLMPlugin: the initialized plugin // - error: if the store is missing or the initial cache load fails func Init(ctx context.Context, store InMemoryStore, logger schemas.Logger) (schemas.LLMPlugin, error) { return InitWithResolver(ctx, store, &headerResolver{logger: logger}, logger) } // InitWithResolver constructs a Plugin with an explicit PromptResolver (nil falls back to headerResolver). 
// // Parameters: // - ctx: used for the initial loadCache call // - store: required config store backend for prompts // - resolver: custom resolution logic; if nil, headerResolver is used // - logger: passed to the default resolver when it is constructed internally // // Returns: // - *Plugin: the initialized plugin (concrete type for Reload and handler integration) // - error: if the store is missing or the initial cache load fails func InitWithResolver(ctx context.Context, store InMemoryStore, resolver PromptResolver, logger schemas.Logger) (*Plugin, error) { if store == nil { return nil, fmt.Errorf("config store is required for prompts plugin") } if resolver == nil { resolver = &headerResolver{logger: logger} } p := &Plugin{ store: store, logger: logger, resolver: resolver, promptsByID: make(map[string]*configstoreTables.TablePrompt), versionsByPromptAndNumber: make(map[string]map[int]*configstoreTables.TablePromptVersion), } if err := p.loadCache(ctx); err != nil { return nil, fmt.Errorf("failed to load prompts into memory: %w", err) } return p, nil } // loadCache rebuilds the in-memory maps with exactly two DB queries: // one for all prompts (with their latest version), one for all versions. 
func (p *Plugin) loadCache(ctx context.Context) error { prompts, err := p.store.GetPrompts(ctx, nil) if err != nil { return err } versions, err := p.store.GetAllPromptVersions(ctx) if err != nil { return fmt.Errorf("loading all prompt versions: %w", err) } newPrompts := make(map[string]*configstoreTables.TablePrompt, len(prompts)) for i := range prompts { newPrompts[prompts[i].ID] = &prompts[i] } newVersionsByPromptAndNumber := make(map[string]map[int]*configstoreTables.TablePromptVersion) for i := range versions { v := &versions[i] if _, ok := newVersionsByPromptAndNumber[v.PromptID]; !ok { newVersionsByPromptAndNumber[v.PromptID] = make(map[int]*configstoreTables.TablePromptVersion) } newVersionsByPromptAndNumber[v.PromptID][v.VersionNumber] = v } p.mu.Lock() p.promptsByID = newPrompts p.versionsByPromptAndNumber = newVersionsByPromptAndNumber p.mu.Unlock() return nil } // Reload refreshes the in-memory cache from the store. Called by the HTTP handler // after any create/update/delete operation on prompts or versions. func (p *Plugin) Reload(ctx context.Context) error { return p.loadCache(ctx) } // GetName returns the plugin identifier ("prompts"). func (p *Plugin) GetName() string { return PluginName } // HTTPTransportPreHook copies x-bf-prompt-id and x-bf-prompt-version from the incoming HTTP request // into BifrostContext so the default header resolver and PreLLMHook can read them. func (p *Plugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { if req == nil { return nil, nil } if id := strings.TrimSpace(req.CaseInsensitiveHeaderLookup(PromptIDHeader)); id != "" { ctx.SetValue(PromptIDKey, id) } if v := strings.TrimSpace(req.CaseInsensitiveHeaderLookup(PromptVersionHeader)); v != "" { ctx.SetValue(PromptVersionKey, v) } return nil, nil } // HTTPTransportPostHook is a no-op; this plugin does not modify HTTP response headers. 
func (p *Plugin) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error { return nil } // HTTPTransportStreamChunkHook passes streaming chunks through unchanged; prompt injection // happens in PreLLMHook before the provider call. func (p *Plugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) { return chunk, nil } // PreLLMHook resolves the prompt via PromptResolver, loads the version from the in-memory // cache, sets governance/observability context (selected prompt name and version), merges // version ModelParams with the request (request overrides), converts stored messages to // chat messages, and prepends them to Chat or Responses input. Non-HTTP transports rely // on context keys set by callers instead of HTTPTransportPreHook. // // Parameters: // - ctx: may set BifrostContextKeySelectedPromptName, BifrostContextKeySelectedPromptID and BifrostContextKeySelectedPromptVersion when a prompt is applied // - req: chat or Responses request to mutate in place // // Returns: // - *schemas.BifrostRequest: possibly modified request // - *schemas.LLMPluginShortCircuit: always nil // - error: resolution failure or missing prompt/version; invalid or empty template returns // the request unchanged with a nil error func (p *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { if req == nil { return req, nil, nil } promptID, versionNumber, err := p.resolver.Resolve(ctx, req) if err != nil { p.logger.Warn("prompts plugin: failed to resolve prompt: %v", err) return req, nil, nil } if promptID == "" { return req, nil, nil } prompt, version, found := p.resolveVersion(promptID, versionNumber) if !found { p.logger.Warn("prompts plugin: prompt or version not found: promptID=%s versionNumber=%d", promptID, versionNumber) return 
req, nil, nil } if version == nil { p.logger.Warn("prompts plugin: prompt has no resolved version: promptID=%s", promptID) return req, nil, nil } if prompt != nil && prompt.Name != "" { ctx.SetValue(schemas.BifrostContextKeySelectedPromptID, prompt.ID) ctx.SetValue(schemas.BifrostContextKeySelectedPromptName, prompt.Name) } ctx.SetValue(schemas.BifrostContextKeySelectedPromptVersion, strconv.Itoa(version.VersionNumber)) // Apply model params from the version (version params are defaults; request params win). switch { case req.ChatRequest != nil: applyVersionParamsToChatRequest(version, req.ChatRequest, p.logger) case req.ResponsesRequest != nil: applyVersionParamsToResponsesRequest(version, req.ResponsesRequest, p.logger) } template, err := chatMessagesFromVersionMessages(version.Messages) if err != nil { p.logger.Warn("prompts plugin: failed to convert version messages to chat messages: %v", err) return req, nil, nil } if len(template) == 0 { p.logger.Warn("prompts plugin: no template messages found for prompt %s version %d", promptID, version.VersionNumber) return req, nil, nil } switch { case req.ChatRequest != nil: mergeChatMessages(&req.ChatRequest.Input, template) case req.ResponsesRequest != nil: mergeResponsesMessages(&req.ResponsesRequest.Input, template) } return req, nil, nil } // PostLLMHook is a no-op; the plugin does not modify responses. func (p *Plugin) PostLLMHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { return resp, bifrostErr, nil } // knownSyntheticChatParamKeys are flat JSON keys that ChatParameters.UnmarshalJSON // promotes into nested structs. They should not be treated as ExtraParams even though // they won't appear as top-level keys in a re-marshaled ChatParameters. 
var knownSyntheticChatParamKeys = map[string]struct{}{ "reasoning_effort": {}, "reasoning_max_tokens": {}, } // buildMergedParamsMap builds a merged map[string]interface{} where version params // serve as defaults and request params take priority. reqParamsBytes is the JSON of // the request's standard params (ExtraParams excluded); reqExtraParams is its ExtraParams map. func buildMergedParamsMap(versionParams configstoreTables.ModelParams, reqParamsBytes []byte, reqExtraParams map[string]interface{}) (map[string]interface{}, error) { merged := make(map[string]interface{}, len(versionParams)) maps.Copy(merged, versionParams) if len(reqParamsBytes) > 0 && string(reqParamsBytes) != "null" { var reqMap map[string]interface{} if err := schemas.Unmarshal(reqParamsBytes, &reqMap); err != nil { return nil, fmt.Errorf("unmarshal request params: %w", err) } maps.Copy(merged, reqMap) } maps.Copy(merged, reqExtraParams) return merged, nil } // applyVersionParamsToChatRequest applies the prompt version's ModelParams to the // chat request. Version params are defaults; params already set in the request win. 
func applyVersionParamsToChatRequest(version *configstoreTables.TablePromptVersion, req *schemas.BifrostChatRequest, logger schemas.Logger) { if len(version.ModelParams) == 0 { return } var reqParamsBytes []byte var reqExtraParams map[string]interface{} if req.Params != nil { b, err := schemas.Marshal(req.Params) if err != nil { logger.Warn("prompts plugin: failed to marshal chat request params: %v", err) return } reqParamsBytes = b reqExtraParams = req.Params.ExtraParams } merged, err := buildMergedParamsMap(version.ModelParams, reqParamsBytes, reqExtraParams) if err != nil { logger.Warn("prompts plugin: failed to build merged chat params: %v", err) return } mergedJSON, err := schemas.Marshal(merged) if err != nil { logger.Warn("prompts plugin: failed to marshal merged chat params: %v", err) return } var result schemas.ChatParameters if err := schemas.Unmarshal(mergedJSON, &result); err != nil { logger.Warn("prompts plugin: failed to unmarshal merged chat params: %v", err) return } // Detect keys from merged that were not recognized as standard ChatParameters fields // (i.e. they won't appear in the re-marshaled output) and put them in ExtraParams. var recognizedMap map[string]interface{} recognizedBytes, err := schemas.Marshal(&result) if err != nil { logger.Warn("prompts plugin: failed to marshal result chat params: %v", err) return } if err := schemas.Unmarshal(recognizedBytes, &recognizedMap); err != nil { logger.Warn("prompts plugin: failed to unmarshal recognized chat params: %v", err) return } for k, v := range merged { if _, ok := recognizedMap[k]; ok { continue } if _, synthetic := knownSyntheticChatParamKeys[k]; synthetic { continue } if result.ExtraParams == nil { result.ExtraParams = make(map[string]interface{}) } if _, alreadySet := result.ExtraParams[k]; !alreadySet { result.ExtraParams[k] = v } } req.Params = &result } // applyVersionParamsToResponsesRequest applies the prompt version's ModelParams to the // responses request. 
Version params are defaults; params already set in the request win. func applyVersionParamsToResponsesRequest(version *configstoreTables.TablePromptVersion, req *schemas.BifrostResponsesRequest, logger schemas.Logger) { if len(version.ModelParams) == 0 { return } var reqParamsBytes []byte var reqExtraParams map[string]interface{} if req.Params != nil { b, err := schemas.Marshal(req.Params) if err != nil { logger.Warn("prompts plugin: failed to marshal responses request params: %v", err) return } reqParamsBytes = b reqExtraParams = req.Params.ExtraParams } merged, err := buildMergedParamsMap(version.ModelParams, reqParamsBytes, reqExtraParams) if err != nil { logger.Warn("prompts plugin: failed to build merged responses params: %v", err) return } mergedJSON, err := schemas.Marshal(merged) if err != nil { logger.Warn("prompts plugin: failed to marshal merged responses params: %v", err) return } var result schemas.ResponsesParameters if err := schemas.Unmarshal(mergedJSON, &result); err != nil { logger.Warn("prompts plugin: failed to unmarshal merged responses params: %v", err) return } // Detect unrecognized keys and add them to ExtraParams. var recognizedMap map[string]interface{} recognizedBytes, err := schemas.Marshal(&result) if err != nil { logger.Warn("prompts plugin: failed to marshal result responses params: %v", err) return } if err := schemas.Unmarshal(recognizedBytes, &recognizedMap); err != nil { logger.Warn("prompts plugin: failed to unmarshal recognized responses params: %v", err) return } for k, v := range merged { if _, ok := recognizedMap[k]; ok { continue } if result.ExtraParams == nil { result.ExtraParams = make(map[string]interface{}) } if _, alreadySet := result.ExtraParams[k]; !alreadySet { result.ExtraParams[k] = v } } req.Params = &result } // resolveVersion centralises the map-lookup logic shared by setPromptStreamFromVersionForTransport // and PreLLMHook. It returns the prompt and its resolved version. 
// // If versionNumber > 0, that explicit version is loaded from versionsByPromptAndNumber (from // x-bf-prompt-version header or a custom PromptResolver such as deployment traffic routing). // If versionNumber == 0, the prompt's latest version is used (no header / resolver chose latest). func (p *Plugin) resolveVersion(promptID string, versionNumber int) ( *configstoreTables.TablePrompt, *configstoreTables.TablePromptVersion, bool, ) { p.mu.RLock() defer p.mu.RUnlock() prompt, ok := p.promptsByID[promptID] if !ok || prompt == nil { return nil, nil, false } if versionNumber > 0 { byNumber, ok := p.versionsByPromptAndNumber[promptID] if !ok { return nil, nil, false } v, found := byNumber[versionNumber] if !found || v == nil { return nil, nil, false } return prompt, v, true } return prompt, prompt.LatestVersion, true } // Cleanup releases plugin resources; the prompts plugin has nothing to tear down. func (p *Plugin) Cleanup() error { return nil } // parseNumberFromContext parses a decimal integer from a string context value. Missing or // empty values yield 0 with no error (treated as “no explicit version”). func parseNumberFromContext(ctx *schemas.BifrostContext, key schemas.BifrostContextKey) (num int, err error) { s, ok := ctx.Value(key).(string) if !ok { return 0, nil } s = strings.TrimSpace(s) if s == "" { return 0, nil } n, err := strconv.ParseInt(s, 10, 64) if err != nil { return 0, err } return int(n), nil } // chatMessagePopulated reports whether a ChatMessage carries any meaningful content for injection. func chatMessagePopulated(cm schemas.ChatMessage) bool { if strings.TrimSpace(string(cm.Role)) != "" { return true } if cm.Content != nil { return true } if cm.Name != nil && strings.TrimSpace(*cm.Name) != "" { return true } if cm.ChatToolMessage != nil { return true } if cm.ChatAssistantMessage != nil { return true } return false } // convertVersionMessagesToChatMessages unmarshals prompt-repo JSON into ChatMessage. 
func convertVersionMessagesToChatMessages(data []byte) (schemas.ChatMessage, error) { s := strings.TrimSpace(string(data)) if s == "" || s == "null" { return schemas.ChatMessage{}, fmt.Errorf("empty message") } data = []byte(s) var msg struct { OriginalType string `json:"originalType"` Payload json.RawMessage `json:"payload"` } if err := schemas.Unmarshal(data, &msg); err == nil { ps := strings.TrimSpace(string(msg.Payload)) if ps != "" && ps != "null" { if msg.OriginalType == "completion_result" { var result struct { Choices []struct { Message *schemas.ChatMessage `json:"message"` } `json:"choices"` } if err := schemas.Unmarshal([]byte(ps), &result); err == nil && len(result.Choices) > 0 && result.Choices[0].Message != nil { if chatMessagePopulated(*result.Choices[0].Message) { return *result.Choices[0].Message, nil } } } // completion_request / tool_result / legacy envelope: payload is a direct ChatMessage. var message schemas.ChatMessage if err := schemas.Unmarshal([]byte(ps), &message); err != nil { return schemas.ChatMessage{}, fmt.Errorf("decoding prompt message envelope payload: %w", err) } if chatMessagePopulated(message) { return message, nil } } } var chatMessage schemas.ChatMessage if err := schemas.Unmarshal(data, &chatMessage); err != nil { return schemas.ChatMessage{}, err } return chatMessage, nil } // chatMessagesFromVersionMessages decodes each stored row into schemas.ChatMessage, preferring // Message bytes and falling back to MessageJSON when needed. 
func chatMessagesFromVersionMessages(messages []configstoreTables.TablePromptVersionMessage) ([]schemas.ChatMessage, error) { out := make([]schemas.ChatMessage, 0, len(messages)) for i := range messages { row := &messages[i] data := row.Message if len(data) == 0 && row.MessageJSON != "" { data = []byte(row.MessageJSON) } cm, err := convertVersionMessagesToChatMessages(data) if err != nil { return nil, fmt.Errorf("stored prompt message is not valid chat JSON: %w", err) } out = append(out, cm) } return out, nil } // mergeChatMessages prepends prefix to the chat input slice (template first, then client messages). func mergeChatMessages(dest *[]schemas.ChatMessage, prefix []schemas.ChatMessage) { if dest == nil || len(prefix) == 0 { return } cur := *dest merged := make([]schemas.ChatMessage, 0, len(prefix)+len(cur)) merged = append(merged, prefix...) merged = append(merged, cur...) *dest = merged } // mergeResponsesMessages converts template chat messages to ResponsesMessage entries and // prepends them before the client’s Responses input. func mergeResponsesMessages(dest *[]schemas.ResponsesMessage, template []schemas.ChatMessage) { if dest == nil || len(template) == 0 { return } var prefix []schemas.ResponsesMessage for i := range template { prefix = append(prefix, template[i].ToResponsesMessages()...) } cur := *dest merged := make([]schemas.ResponsesMessage, 0, len(prefix)+len(cur)) merged = append(merged, prefix...) merged = append(merged, cur...) *dest = merged }