607 lines
22 KiB
Go
607 lines
22 KiB
Go
// Package prompts implements the Bifrost LLM plugin that resolves stored prompt templates
|
||
// from the config store and prepends their messages to chat and Responses API requests.
|
||
// HTTP clients select a prompt via x-bf-prompt-id / x-bf-prompt-version headers; optional
|
||
// custom PromptResolver implementations can override how ID and version are chosen.
|
||
package prompts
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"maps"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
|
||
bifrost "github.com/maximhq/bifrost/core"
|
||
"github.com/maximhq/bifrost/core/schemas"
|
||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||
)
|
||
|
||
const (
|
||
// PluginName is the canonical name registered for the prompts plugin.
|
||
PluginName = "prompts"
|
||
|
||
// PromptIDHeader and PromptVersionHeader are request headers copied into BifrostContext
|
||
// in HTTPTransportPreHook so PreLLMHook and custom resolvers can read them.
|
||
PromptIDHeader = "x-bf-prompt-id"
|
||
PromptVersionHeader = "x-bf-prompt-version"
|
||
|
||
// PromptIDKey and PromptVersionKey are context keys for the resolved header values.
|
||
PromptIDKey schemas.BifrostContextKey = PromptIDHeader
|
||
PromptVersionKey schemas.BifrostContextKey = PromptVersionHeader
|
||
)
|
||
|
||
// InMemoryStore is the data source for prompts and all versions. Implementations typically
|
||
// wrap the framework config store; the plugin keeps an in-memory index built by loadCache.
|
||
type InMemoryStore interface {
|
||
GetPrompts(ctx context.Context, folderID *string) ([]configstoreTables.TablePrompt, error)
|
||
GetAllPromptVersions(ctx context.Context) ([]configstoreTables.TablePromptVersion, error)
|
||
}
|
||
|
||
// PromptResolver decides which prompt and version to inject for a given request.
|
||
// Returning an empty promptID means no injection for this request.
|
||
type PromptResolver interface {
|
||
Resolve(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (promptID string, versionNumber int, err error)
|
||
}
|
||
|
||
// headerResolver is the default OSS resolver: it reads prompt ID and version from context
|
||
// keys populated from HTTP headers in HTTPTransportPreHook (x-bf-prompt-id, x-bf-prompt-version).
|
||
type headerResolver struct {
|
||
logger schemas.Logger
|
||
}
|
||
|
||
// Resolve returns the prompt ID and version number from context. An empty promptID means
|
||
// no prompt injection for this request. Version 0 means “use latest” when passed to resolveVersion.
|
||
func (r *headerResolver) Resolve(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (string, int, error) {
|
||
promptID := bifrost.GetStringFromContext(ctx, PromptIDKey)
|
||
if promptID == "" {
|
||
return "", 0, nil
|
||
}
|
||
versionNumber, err := parseNumberFromContext(ctx, PromptVersionKey)
|
||
if err != nil {
|
||
return "", 0, fmt.Errorf("failed to parse version number: %w", err)
|
||
}
|
||
return promptID, versionNumber, nil
|
||
}
|
||
|
||
// Plugin implements schemas.LLMPlugin (and HTTP transport hooks) for server-side prompt injection.
|
||
// It loads prompts and versions into memory, resolves which version to use per request, merges
|
||
// the version’s model parameters with the client request (request wins), and prepends template
|
||
// messages before chat or Responses input.
|
||
//
|
||
// Fields:
|
||
// - store: backing persistence for prompts and versions
|
||
// - logger: Bifrost logger for non-fatal merge/param warnings
|
||
// - resolver: chooses prompt ID and version; defaults to headerResolver
|
||
// - mu: protects promptsByID and versionsByPromptAndNumber
|
||
// - promptsByID: prompt ID → prompt row (includes LatestVersion when using “latest”)
|
||
// - versionsByPromptAndNumber: prompt ID → version number → version row
|
||
type Plugin struct {
|
||
store InMemoryStore
|
||
logger schemas.Logger
|
||
resolver PromptResolver
|
||
|
||
mu sync.RWMutex
|
||
promptsByID map[string]*configstoreTables.TablePrompt
|
||
versionsByPromptAndNumber map[string]map[int]*configstoreTables.TablePromptVersion
|
||
}
|
||
|
||
// Init constructs a Plugin using the default header-based resolver (x-bf-prompt-id / x-bf-prompt-version).
|
||
//
|
||
// Parameters:
|
||
// - ctx: used for the initial loadCache call
|
||
// - store: required config store backend for prompts
|
||
// - logger: used by the default resolver and param merge paths
|
||
//
|
||
// Returns:
|
||
// - schemas.LLMPlugin: the initialized plugin
|
||
// - error: if the store is missing or the initial cache load fails
|
||
func Init(ctx context.Context, store InMemoryStore, logger schemas.Logger) (schemas.LLMPlugin, error) {
|
||
return InitWithResolver(ctx, store, &headerResolver{logger: logger}, logger)
|
||
}
|
||
|
||
// InitWithResolver constructs a Plugin with an explicit PromptResolver (nil falls back to headerResolver).
|
||
//
|
||
// Parameters:
|
||
// - ctx: used for the initial loadCache call
|
||
// - store: required config store backend for prompts
|
||
// - resolver: custom resolution logic; if nil, headerResolver is used
|
||
// - logger: passed to the default resolver when it is constructed internally
|
||
//
|
||
// Returns:
|
||
// - *Plugin: the initialized plugin (concrete type for Reload and handler integration)
|
||
// - error: if the store is missing or the initial cache load fails
|
||
func InitWithResolver(ctx context.Context, store InMemoryStore, resolver PromptResolver, logger schemas.Logger) (*Plugin, error) {
|
||
if store == nil {
|
||
return nil, fmt.Errorf("config store is required for prompts plugin")
|
||
}
|
||
if resolver == nil {
|
||
resolver = &headerResolver{logger: logger}
|
||
}
|
||
p := &Plugin{
|
||
store: store,
|
||
logger: logger,
|
||
resolver: resolver,
|
||
promptsByID: make(map[string]*configstoreTables.TablePrompt),
|
||
versionsByPromptAndNumber: make(map[string]map[int]*configstoreTables.TablePromptVersion),
|
||
}
|
||
if err := p.loadCache(ctx); err != nil {
|
||
return nil, fmt.Errorf("failed to load prompts into memory: %w", err)
|
||
}
|
||
return p, nil
|
||
}
|
||
|
||
// loadCache rebuilds the in-memory maps with exactly two DB queries:
|
||
// one for all prompts (with their latest version), one for all versions.
|
||
func (p *Plugin) loadCache(ctx context.Context) error {
|
||
prompts, err := p.store.GetPrompts(ctx, nil)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
versions, err := p.store.GetAllPromptVersions(ctx)
|
||
if err != nil {
|
||
return fmt.Errorf("loading all prompt versions: %w", err)
|
||
}
|
||
|
||
newPrompts := make(map[string]*configstoreTables.TablePrompt, len(prompts))
|
||
for i := range prompts {
|
||
newPrompts[prompts[i].ID] = &prompts[i]
|
||
}
|
||
|
||
newVersionsByPromptAndNumber := make(map[string]map[int]*configstoreTables.TablePromptVersion)
|
||
for i := range versions {
|
||
v := &versions[i]
|
||
if _, ok := newVersionsByPromptAndNumber[v.PromptID]; !ok {
|
||
newVersionsByPromptAndNumber[v.PromptID] = make(map[int]*configstoreTables.TablePromptVersion)
|
||
}
|
||
newVersionsByPromptAndNumber[v.PromptID][v.VersionNumber] = v
|
||
}
|
||
|
||
p.mu.Lock()
|
||
p.promptsByID = newPrompts
|
||
p.versionsByPromptAndNumber = newVersionsByPromptAndNumber
|
||
p.mu.Unlock()
|
||
return nil
|
||
}
|
||
|
||
// Reload refreshes the in-memory cache from the store. Called by the HTTP handler
|
||
// after any create/update/delete operation on prompts or versions.
|
||
func (p *Plugin) Reload(ctx context.Context) error {
|
||
return p.loadCache(ctx)
|
||
}
|
||
|
||
// GetName returns the plugin identifier ("prompts").
|
||
func (p *Plugin) GetName() string {
|
||
return PluginName
|
||
}
|
||
|
||
// HTTPTransportPreHook copies x-bf-prompt-id and x-bf-prompt-version from the incoming HTTP request
|
||
// into BifrostContext so the default header resolver and PreLLMHook can read them.
|
||
func (p *Plugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) {
|
||
if req == nil {
|
||
return nil, nil
|
||
}
|
||
if id := strings.TrimSpace(req.CaseInsensitiveHeaderLookup(PromptIDHeader)); id != "" {
|
||
ctx.SetValue(PromptIDKey, id)
|
||
}
|
||
if v := strings.TrimSpace(req.CaseInsensitiveHeaderLookup(PromptVersionHeader)); v != "" {
|
||
ctx.SetValue(PromptVersionKey, v)
|
||
}
|
||
return nil, nil
|
||
}
|
||
|
||
// HTTPTransportPostHook is a no-op; this plugin does not modify HTTP response headers.
|
||
func (p *Plugin) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error {
|
||
return nil
|
||
}
|
||
|
||
// HTTPTransportStreamChunkHook passes streaming chunks through unchanged; prompt injection
|
||
// happens in PreLLMHook before the provider call.
|
||
func (p *Plugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) {
|
||
return chunk, nil
|
||
}
|
||
|
||
// PreLLMHook resolves the prompt via PromptResolver, loads the version from the in-memory
|
||
// cache, sets governance/observability context (selected prompt name and version), merges
|
||
// version ModelParams with the request (request overrides), converts stored messages to
|
||
// chat messages, and prepends them to Chat or Responses input. Non-HTTP transports rely
|
||
// on context keys set by callers instead of HTTPTransportPreHook.
|
||
//
|
||
// Parameters:
|
||
// - ctx: may set BifrostContextKeySelectedPromptName, BifrostContextKeySelectedPromptID and BifrostContextKeySelectedPromptVersion when a prompt is applied
|
||
// - req: chat or Responses request to mutate in place
|
||
//
|
||
// Returns:
|
||
// - *schemas.BifrostRequest: possibly modified request
|
||
// - *schemas.LLMPluginShortCircuit: always nil
|
||
// - error: resolution failure or missing prompt/version; invalid or empty template returns
|
||
// the request unchanged with a nil error
|
||
func (p *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) {
|
||
if req == nil {
|
||
return req, nil, nil
|
||
}
|
||
|
||
promptID, versionNumber, err := p.resolver.Resolve(ctx, req)
|
||
if err != nil {
|
||
p.logger.Warn("prompts plugin: failed to resolve prompt: %v", err)
|
||
return req, nil, nil
|
||
}
|
||
if promptID == "" {
|
||
return req, nil, nil
|
||
}
|
||
|
||
prompt, version, found := p.resolveVersion(promptID, versionNumber)
|
||
if !found {
|
||
p.logger.Warn("prompts plugin: prompt or version not found: promptID=%s versionNumber=%d", promptID, versionNumber)
|
||
return req, nil, nil
|
||
}
|
||
|
||
if version == nil {
|
||
p.logger.Warn("prompts plugin: prompt has no resolved version: promptID=%s", promptID)
|
||
return req, nil, nil
|
||
}
|
||
|
||
if prompt != nil && prompt.Name != "" {
|
||
ctx.SetValue(schemas.BifrostContextKeySelectedPromptID, prompt.ID)
|
||
ctx.SetValue(schemas.BifrostContextKeySelectedPromptName, prompt.Name)
|
||
}
|
||
ctx.SetValue(schemas.BifrostContextKeySelectedPromptVersion, strconv.Itoa(version.VersionNumber))
|
||
|
||
// Apply model params from the version (version params are defaults; request params win).
|
||
switch {
|
||
case req.ChatRequest != nil:
|
||
applyVersionParamsToChatRequest(version, req.ChatRequest, p.logger)
|
||
case req.ResponsesRequest != nil:
|
||
applyVersionParamsToResponsesRequest(version, req.ResponsesRequest, p.logger)
|
||
}
|
||
|
||
template, err := chatMessagesFromVersionMessages(version.Messages)
|
||
if err != nil {
|
||
p.logger.Warn("prompts plugin: failed to convert version messages to chat messages: %v", err)
|
||
return req, nil, nil
|
||
}
|
||
if len(template) == 0 {
|
||
p.logger.Warn("prompts plugin: no template messages found for prompt %s version %d", promptID, version.VersionNumber)
|
||
return req, nil, nil
|
||
}
|
||
|
||
switch {
|
||
case req.ChatRequest != nil:
|
||
mergeChatMessages(&req.ChatRequest.Input, template)
|
||
case req.ResponsesRequest != nil:
|
||
mergeResponsesMessages(&req.ResponsesRequest.Input, template)
|
||
}
|
||
|
||
return req, nil, nil
|
||
}
|
||
|
||
// PostLLMHook is a no-op; the plugin does not modify responses.
|
||
func (p *Plugin) PostLLMHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) {
|
||
return resp, bifrostErr, nil
|
||
}
|
||
|
||
// knownSyntheticChatParamKeys holds flat JSON keys that ChatParameters.UnmarshalJSON folds
// into nested structs. Re-marshaling ChatParameters does not emit them at the top level,
// but they are still understood, so applyVersionParamsToChatRequest must not copy them
// into ExtraParams.
var knownSyntheticChatParamKeys = map[string]struct{}{
	"reasoning_max_tokens": {},
	"reasoning_effort":     {},
}
|
||
|
||
// buildMergedParamsMap builds a merged map[string]interface{} where version params
|
||
// serve as defaults and request params take priority. reqParamsBytes is the JSON of
|
||
// the request's standard params (ExtraParams excluded); reqExtraParams is its ExtraParams map.
|
||
func buildMergedParamsMap(versionParams configstoreTables.ModelParams, reqParamsBytes []byte, reqExtraParams map[string]interface{}) (map[string]interface{}, error) {
|
||
merged := make(map[string]interface{}, len(versionParams))
|
||
maps.Copy(merged, versionParams)
|
||
if len(reqParamsBytes) > 0 && string(reqParamsBytes) != "null" {
|
||
var reqMap map[string]interface{}
|
||
if err := schemas.Unmarshal(reqParamsBytes, &reqMap); err != nil {
|
||
return nil, fmt.Errorf("unmarshal request params: %w", err)
|
||
}
|
||
maps.Copy(merged, reqMap)
|
||
}
|
||
maps.Copy(merged, reqExtraParams)
|
||
return merged, nil
|
||
}
|
||
|
||
// applyVersionParamsToChatRequest applies the prompt version's ModelParams to the
|
||
// chat request. Version params are defaults; params already set in the request win.
|
||
func applyVersionParamsToChatRequest(version *configstoreTables.TablePromptVersion, req *schemas.BifrostChatRequest, logger schemas.Logger) {
|
||
if len(version.ModelParams) == 0 {
|
||
return
|
||
}
|
||
|
||
var reqParamsBytes []byte
|
||
var reqExtraParams map[string]interface{}
|
||
if req.Params != nil {
|
||
b, err := schemas.Marshal(req.Params)
|
||
if err != nil {
|
||
logger.Warn("prompts plugin: failed to marshal chat request params: %v", err)
|
||
return
|
||
}
|
||
reqParamsBytes = b
|
||
reqExtraParams = req.Params.ExtraParams
|
||
}
|
||
|
||
merged, err := buildMergedParamsMap(version.ModelParams, reqParamsBytes, reqExtraParams)
|
||
if err != nil {
|
||
logger.Warn("prompts plugin: failed to build merged chat params: %v", err)
|
||
return
|
||
}
|
||
|
||
mergedJSON, err := schemas.Marshal(merged)
|
||
if err != nil {
|
||
logger.Warn("prompts plugin: failed to marshal merged chat params: %v", err)
|
||
return
|
||
}
|
||
|
||
var result schemas.ChatParameters
|
||
if err := schemas.Unmarshal(mergedJSON, &result); err != nil {
|
||
logger.Warn("prompts plugin: failed to unmarshal merged chat params: %v", err)
|
||
return
|
||
}
|
||
|
||
// Detect keys from merged that were not recognized as standard ChatParameters fields
|
||
// (i.e. they won't appear in the re-marshaled output) and put them in ExtraParams.
|
||
var recognizedMap map[string]interface{}
|
||
recognizedBytes, err := schemas.Marshal(&result)
|
||
if err != nil {
|
||
logger.Warn("prompts plugin: failed to marshal result chat params: %v", err)
|
||
return
|
||
}
|
||
if err := schemas.Unmarshal(recognizedBytes, &recognizedMap); err != nil {
|
||
logger.Warn("prompts plugin: failed to unmarshal recognized chat params: %v", err)
|
||
return
|
||
}
|
||
for k, v := range merged {
|
||
if _, ok := recognizedMap[k]; ok {
|
||
continue
|
||
}
|
||
if _, synthetic := knownSyntheticChatParamKeys[k]; synthetic {
|
||
continue
|
||
}
|
||
if result.ExtraParams == nil {
|
||
result.ExtraParams = make(map[string]interface{})
|
||
}
|
||
if _, alreadySet := result.ExtraParams[k]; !alreadySet {
|
||
result.ExtraParams[k] = v
|
||
}
|
||
}
|
||
|
||
req.Params = &result
|
||
}
|
||
|
||
// applyVersionParamsToResponsesRequest applies the prompt version's ModelParams to the
|
||
// responses request. Version params are defaults; params already set in the request win.
|
||
func applyVersionParamsToResponsesRequest(version *configstoreTables.TablePromptVersion, req *schemas.BifrostResponsesRequest, logger schemas.Logger) {
|
||
if len(version.ModelParams) == 0 {
|
||
return
|
||
}
|
||
|
||
var reqParamsBytes []byte
|
||
var reqExtraParams map[string]interface{}
|
||
if req.Params != nil {
|
||
b, err := schemas.Marshal(req.Params)
|
||
if err != nil {
|
||
logger.Warn("prompts plugin: failed to marshal responses request params: %v", err)
|
||
return
|
||
}
|
||
reqParamsBytes = b
|
||
reqExtraParams = req.Params.ExtraParams
|
||
}
|
||
|
||
merged, err := buildMergedParamsMap(version.ModelParams, reqParamsBytes, reqExtraParams)
|
||
if err != nil {
|
||
logger.Warn("prompts plugin: failed to build merged responses params: %v", err)
|
||
return
|
||
}
|
||
|
||
mergedJSON, err := schemas.Marshal(merged)
|
||
if err != nil {
|
||
logger.Warn("prompts plugin: failed to marshal merged responses params: %v", err)
|
||
return
|
||
}
|
||
|
||
var result schemas.ResponsesParameters
|
||
if err := schemas.Unmarshal(mergedJSON, &result); err != nil {
|
||
logger.Warn("prompts plugin: failed to unmarshal merged responses params: %v", err)
|
||
return
|
||
}
|
||
|
||
// Detect unrecognized keys and add them to ExtraParams.
|
||
var recognizedMap map[string]interface{}
|
||
recognizedBytes, err := schemas.Marshal(&result)
|
||
if err != nil {
|
||
logger.Warn("prompts plugin: failed to marshal result responses params: %v", err)
|
||
return
|
||
}
|
||
if err := schemas.Unmarshal(recognizedBytes, &recognizedMap); err != nil {
|
||
logger.Warn("prompts plugin: failed to unmarshal recognized responses params: %v", err)
|
||
return
|
||
}
|
||
for k, v := range merged {
|
||
if _, ok := recognizedMap[k]; ok {
|
||
continue
|
||
}
|
||
if result.ExtraParams == nil {
|
||
result.ExtraParams = make(map[string]interface{})
|
||
}
|
||
if _, alreadySet := result.ExtraParams[k]; !alreadySet {
|
||
result.ExtraParams[k] = v
|
||
}
|
||
}
|
||
|
||
req.Params = &result
|
||
}
|
||
|
||
// resolveVersion centralises the map-lookup logic shared by setPromptStreamFromVersionForTransport
|
||
// and PreLLMHook. It returns the prompt and its resolved version.
|
||
//
|
||
// If versionNumber > 0, that explicit version is loaded from versionsByPromptAndNumber (from
|
||
// x-bf-prompt-version header or a custom PromptResolver such as deployment traffic routing).
|
||
// If versionNumber == 0, the prompt's latest version is used (no header / resolver chose latest).
|
||
func (p *Plugin) resolveVersion(promptID string, versionNumber int) (
|
||
*configstoreTables.TablePrompt, *configstoreTables.TablePromptVersion, bool,
|
||
) {
|
||
p.mu.RLock()
|
||
defer p.mu.RUnlock()
|
||
|
||
prompt, ok := p.promptsByID[promptID]
|
||
if !ok || prompt == nil {
|
||
return nil, nil, false
|
||
}
|
||
if versionNumber > 0 {
|
||
byNumber, ok := p.versionsByPromptAndNumber[promptID]
|
||
if !ok {
|
||
return nil, nil, false
|
||
}
|
||
v, found := byNumber[versionNumber]
|
||
if !found || v == nil {
|
||
return nil, nil, false
|
||
}
|
||
return prompt, v, true
|
||
}
|
||
return prompt, prompt.LatestVersion, true
|
||
}
|
||
|
||
// Cleanup releases plugin resources; the prompts plugin has nothing to tear down.
|
||
func (p *Plugin) Cleanup() error {
|
||
return nil
|
||
}
|
||
|
||
// parseNumberFromContext parses a decimal integer from a string context value. Missing or
|
||
// empty values yield 0 with no error (treated as “no explicit version”).
|
||
func parseNumberFromContext(ctx *schemas.BifrostContext, key schemas.BifrostContextKey) (num int, err error) {
|
||
s, ok := ctx.Value(key).(string)
|
||
if !ok {
|
||
return 0, nil
|
||
}
|
||
s = strings.TrimSpace(s)
|
||
if s == "" {
|
||
return 0, nil
|
||
}
|
||
n, err := strconv.ParseInt(s, 10, 64)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return int(n), nil
|
||
}
|
||
|
||
// chatMessagePopulated reports whether a ChatMessage carries any meaningful content for injection.
|
||
func chatMessagePopulated(cm schemas.ChatMessage) bool {
|
||
if strings.TrimSpace(string(cm.Role)) != "" {
|
||
return true
|
||
}
|
||
if cm.Content != nil {
|
||
return true
|
||
}
|
||
if cm.Name != nil && strings.TrimSpace(*cm.Name) != "" {
|
||
return true
|
||
}
|
||
if cm.ChatToolMessage != nil {
|
||
return true
|
||
}
|
||
if cm.ChatAssistantMessage != nil {
|
||
return true
|
||
}
|
||
return false
|
||
}
|
||
|
||
// convertVersionMessagesToChatMessages unmarshals prompt-repo JSON into ChatMessage.
|
||
func convertVersionMessagesToChatMessages(data []byte) (schemas.ChatMessage, error) {
|
||
s := strings.TrimSpace(string(data))
|
||
if s == "" || s == "null" {
|
||
return schemas.ChatMessage{}, fmt.Errorf("empty message")
|
||
}
|
||
data = []byte(s)
|
||
|
||
var msg struct {
|
||
OriginalType string `json:"originalType"`
|
||
Payload json.RawMessage `json:"payload"`
|
||
}
|
||
if err := schemas.Unmarshal(data, &msg); err == nil {
|
||
ps := strings.TrimSpace(string(msg.Payload))
|
||
if ps != "" && ps != "null" {
|
||
if msg.OriginalType == "completion_result" {
|
||
var result struct {
|
||
Choices []struct {
|
||
Message *schemas.ChatMessage `json:"message"`
|
||
} `json:"choices"`
|
||
}
|
||
if err := schemas.Unmarshal([]byte(ps), &result); err == nil &&
|
||
len(result.Choices) > 0 && result.Choices[0].Message != nil {
|
||
if chatMessagePopulated(*result.Choices[0].Message) {
|
||
return *result.Choices[0].Message, nil
|
||
}
|
||
}
|
||
}
|
||
|
||
// completion_request / tool_result / legacy envelope: payload is a direct ChatMessage.
|
||
var message schemas.ChatMessage
|
||
if err := schemas.Unmarshal([]byte(ps), &message); err != nil {
|
||
return schemas.ChatMessage{}, fmt.Errorf("decoding prompt message envelope payload: %w", err)
|
||
}
|
||
if chatMessagePopulated(message) {
|
||
return message, nil
|
||
}
|
||
}
|
||
}
|
||
|
||
var chatMessage schemas.ChatMessage
|
||
if err := schemas.Unmarshal(data, &chatMessage); err != nil {
|
||
return schemas.ChatMessage{}, err
|
||
}
|
||
return chatMessage, nil
|
||
}
|
||
|
||
// chatMessagesFromVersionMessages decodes each stored row into schemas.ChatMessage, preferring
|
||
// Message bytes and falling back to MessageJSON when needed.
|
||
func chatMessagesFromVersionMessages(messages []configstoreTables.TablePromptVersionMessage) ([]schemas.ChatMessage, error) {
|
||
out := make([]schemas.ChatMessage, 0, len(messages))
|
||
for i := range messages {
|
||
row := &messages[i]
|
||
data := row.Message
|
||
if len(data) == 0 && row.MessageJSON != "" {
|
||
data = []byte(row.MessageJSON)
|
||
}
|
||
cm, err := convertVersionMessagesToChatMessages(data)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("stored prompt message is not valid chat JSON: %w", err)
|
||
}
|
||
out = append(out, cm)
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
// mergeChatMessages prepends prefix to the chat input slice (template first, then client messages).
|
||
func mergeChatMessages(dest *[]schemas.ChatMessage, prefix []schemas.ChatMessage) {
|
||
if dest == nil || len(prefix) == 0 {
|
||
return
|
||
}
|
||
cur := *dest
|
||
merged := make([]schemas.ChatMessage, 0, len(prefix)+len(cur))
|
||
merged = append(merged, prefix...)
|
||
merged = append(merged, cur...)
|
||
*dest = merged
|
||
}
|
||
|
||
// mergeResponsesMessages converts template chat messages to ResponsesMessage entries and
|
||
// prepends them before the client’s Responses input.
|
||
func mergeResponsesMessages(dest *[]schemas.ResponsesMessage, template []schemas.ChatMessage) {
|
||
if dest == nil || len(template) == 0 {
|
||
return
|
||
}
|
||
var prefix []schemas.ResponsesMessage
|
||
for i := range template {
|
||
prefix = append(prefix, template[i].ToResponsesMessages()...)
|
||
}
|
||
cur := *dest
|
||
merged := make([]schemas.ResponsesMessage, 0, len(prefix)+len(cur))
|
||
merged = append(merged, prefix...)
|
||
merged = append(merged, cur...)
|
||
*dest = merged
|
||
}
|