Files
bifrost/plugins/prompts/helpers_test.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

385 lines
13 KiB
Go

package prompts
import (
"context"
"encoding/json"
"fmt"
"sync"
"github.com/maximhq/bifrost/core/schemas"
tables "github.com/maximhq/bifrost/framework/configstore/tables"
)
// ============================================================
// MockLogger — captures log output per level for assertions.
// Follows the same pattern as plugins/governance/test_utils.go.
// ============================================================
type MockLogger struct {
mu sync.Mutex
debugs []string
infos []string
warnings []string
errors []string
}
func NewMockLogger() *MockLogger {
return &MockLogger{
debugs: make([]string, 0),
infos: make([]string, 0),
warnings: make([]string, 0),
errors: make([]string, 0),
}
}
func (l *MockLogger) Debug(format string, args ...any) {
l.mu.Lock()
defer l.mu.Unlock()
l.debugs = append(l.debugs, format)
}
func (l *MockLogger) Info(format string, args ...any) {
l.mu.Lock()
defer l.mu.Unlock()
l.infos = append(l.infos, format)
}
func (l *MockLogger) Warn(format string, args ...any) {
l.mu.Lock()
defer l.mu.Unlock()
l.warnings = append(l.warnings, format)
}
func (l *MockLogger) Error(format string, args ...any) {
l.mu.Lock()
defer l.mu.Unlock()
l.errors = append(l.errors, format)
}
func (l *MockLogger) Fatal(format string, args ...any) {
l.mu.Lock()
defer l.mu.Unlock()
l.errors = append(l.errors, format)
}
func (l *MockLogger) SetLevel(_ schemas.LogLevel) {}
func (l *MockLogger) SetOutputType(_ schemas.LoggerOutputType) {}
func (l *MockLogger) LogHTTPRequest(_ schemas.LogLevel, _ string) schemas.LogEventBuilder {
return schemas.NoopLogEvent
}
// Warned returns true if at least one warning was logged.
func (l *MockLogger) Warned() bool {
l.mu.Lock()
defer l.mu.Unlock()
return len(l.warnings) > 0
}
// ============================================================
// mockStore — satisfies InMemoryStore with controllable responses.
// ============================================================
type mockStore struct {
prompts []tables.TablePrompt
versions []tables.TablePromptVersion
err error
}
func (m *mockStore) GetPrompts(_ context.Context, _ *string) ([]tables.TablePrompt, error) {
return m.prompts, m.err
}
func (m *mockStore) GetAllPromptVersions(_ context.Context) ([]tables.TablePromptVersion, error) {
return m.versions, m.err
}
// versionsErrStore succeeds on GetPrompts but fails on GetAllPromptVersions.
type versionsErrStore struct {
prompts []tables.TablePrompt
err error
}
func (s *versionsErrStore) GetPrompts(_ context.Context, _ *string) ([]tables.TablePrompt, error) {
return s.prompts, nil
}
func (s *versionsErrStore) GetAllPromptVersions(_ context.Context) ([]tables.TablePromptVersion, error) {
return nil, s.err
}
// ============================================================
// staticResolver — returns fixed IDs; decouples PreLLMHook
// tests from HTTP header / context mechanics.
// ============================================================
type staticResolver struct {
promptID string
versionNumber int
err error
}
func (r *staticResolver) Resolve(_ *schemas.BifrostContext, _ *schemas.BifrostRequest) (string, int, error) {
return r.promptID, r.versionNumber, r.err
}
// ============================================================
// Plugin builders
// ============================================================
// newPluginWithStore builds a Plugin whose store is set but maps are empty.
// Use only for loadCache tests.
func newPluginWithStore(s InMemoryStore) *Plugin {
return &Plugin{
store: s,
logger: NewMockLogger(),
resolver: &staticResolver{},
promptsByID: make(map[string]*tables.TablePrompt),
versionsByPromptAndNumber: make(map[string]map[int]*tables.TablePromptVersion),
}
}
// newTestPlugin builds a Plugin with pre-seeded in-memory maps, bypassing Init
// and loadCache entirely. The store is nil — safe as long as no test path calls
// into the store.
func newTestPlugin(resolver PromptResolver, promptMap map[string]*tables.TablePrompt, versionMap map[string]map[int]*tables.TablePromptVersion) *Plugin {
return newTestPluginWithLogger(resolver, promptMap, versionMap, NewMockLogger())
}
// newTestPluginWithLogger is like newTestPlugin but accepts a caller-provided logger
// so tests can inspect logged warnings.
func newTestPluginWithLogger(resolver PromptResolver, promptMap map[string]*tables.TablePrompt, versionMap map[string]map[int]*tables.TablePromptVersion, log schemas.Logger) *Plugin {
if resolver == nil {
resolver = &staticResolver{}
}
if promptMap == nil {
promptMap = make(map[string]*tables.TablePrompt)
}
if versionMap == nil {
versionMap = make(map[string]map[int]*tables.TablePromptVersion)
}
return &Plugin{
store: nil,
logger: log,
resolver: resolver,
promptsByID: promptMap,
versionsByPromptAndNumber: versionMap,
}
}
// ============================================================
// Message builders
// ============================================================
// versionMsg creates a TablePromptVersionMessage in the production envelope
// format {"payload": <chat_message_json>}, matching what the frontend writes
// to the DB and what AfterFind populates into the Message field.
func versionMsg(role schemas.ChatMessageRole, text string) tables.TablePromptVersionMessage {
content := text
inner := schemas.ChatMessage{
Role: role,
Content: &schemas.ChatMessageContent{ContentStr: &content},
}
innerJSON, err := json.Marshal(inner)
if err != nil {
panic(fmt.Sprintf("versionMsg: marshal inner failed: %v", err))
}
envelope := fmt.Sprintf(`{"payload":%s}`, string(innerJSON))
return tables.TablePromptVersionMessage{
Message: tables.PromptMessage(envelope),
}
}
// versionMsgViaJSON creates a TablePromptVersionMessage that has an empty Message
// field but a populated MessageJSON field, exercising the fallback branch in
// chatMessagesFromVersionMessages.
func versionMsgViaJSON(role schemas.ChatMessageRole, text string) tables.TablePromptVersionMessage {
content := text
inner := schemas.ChatMessage{
Role: role,
Content: &schemas.ChatMessageContent{ContentStr: &content},
}
innerJSON, err := json.Marshal(inner)
if err != nil {
panic(fmt.Sprintf("versionMsgViaJSON: marshal failed: %v", err))
}
envelope := fmt.Sprintf(`{"payload":%s}`, string(innerJSON))
return tables.TablePromptVersionMessage{
Message: nil, // empty — triggers MessageJSON fallback
MessageJSON: envelope,
}
}
// makeVersion returns a TablePromptVersion with the supplied messages.
// VersionNumber is set to int(id) so tests can reference versions by their number.
func makeVersion(id uint, promptID string, isLatest bool, msgs ...tables.TablePromptVersionMessage) tables.TablePromptVersion {
return tables.TablePromptVersion{
ID: id,
PromptID: promptID,
IsLatest: isLatest,
VersionNumber: int(id),
Messages: msgs,
}
}
// makePrompt returns a TablePrompt, optionally linked to a latest version.
func makePrompt(id string, latest *tables.TablePromptVersion) tables.TablePrompt {
return tables.TablePrompt{ID: id, Name: id, LatestVersion: latest}
}
// ============================================================
// Request / context builders
// ============================================================
// chatRequest returns a BifrostRequest wrapping a ChatRequest with the given messages.
func chatRequest(msgs ...schemas.ChatMessage) *schemas.BifrostRequest {
return &schemas.BifrostRequest{
ChatRequest: &schemas.BifrostChatRequest{
Input: append([]schemas.ChatMessage{}, msgs...),
},
}
}
// userMsg returns a user-role ChatMessage with plain text content.
func userMsg(text string) schemas.ChatMessage {
t := text
return schemas.ChatMessage{
Role: schemas.ChatMessageRoleUser,
Content: &schemas.ChatMessageContent{ContentStr: &t},
}
}
// systemMsg returns a system-role ChatMessage with plain text content.
func systemMsg(text string) schemas.ChatMessage {
t := text
return schemas.ChatMessage{
Role: schemas.ChatMessageRoleSystem,
Content: &schemas.ChatMessageContent{ContentStr: &t},
}
}
// bfCtx returns a fresh BifrostContext with no deadline.
func bfCtx() *schemas.BifrostContext {
return schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
}
// versionMsgWithToolCall creates a TablePromptVersionMessage for an assistant
// message that contains a single tool call (role=assistant, tool_calls=[...]).
func versionMsgWithToolCall(callID, funcName, funcArgs string) tables.TablePromptVersionMessage {
name := funcName
id := callID
inner := schemas.ChatMessage{
Role: schemas.ChatMessageRoleAssistant,
ChatAssistantMessage: &schemas.ChatAssistantMessage{
ToolCalls: []schemas.ChatAssistantMessageToolCall{
{
ID: &id,
Function: schemas.ChatAssistantMessageToolCallFunction{
Name: &name,
Arguments: funcArgs,
},
},
},
},
}
innerJSON, err := json.Marshal(inner)
if err != nil {
panic(fmt.Sprintf("versionMsgWithToolCall: marshal failed: %v", err))
}
envelope := fmt.Sprintf(`{"payload":%s}`, string(innerJSON))
return tables.TablePromptVersionMessage{
Message: tables.PromptMessage(envelope),
}
}
// versionMsgToolResult creates a TablePromptVersionMessage for a tool-result
// message (role=tool) with the given tool_call_id and result text.
func versionMsgToolResult(callID, result string) tables.TablePromptVersionMessage {
id := callID
inner := schemas.ChatMessage{
Role: schemas.ChatMessageRoleTool,
Content: &schemas.ChatMessageContent{ContentStr: &result},
ChatToolMessage: &schemas.ChatToolMessage{ToolCallID: &id},
}
innerJSON, err := json.Marshal(inner)
if err != nil {
panic(fmt.Sprintf("versionMsgToolResult: marshal failed: %v", err))
}
envelope := fmt.Sprintf(`{"payload":%s}`, string(innerJSON))
return tables.TablePromptVersionMessage{
Message: tables.PromptMessage(envelope),
}
}
// makeVersionWithParams returns a TablePromptVersion with explicit ModelParams and messages.
// VersionNumber is set to int(id) so tests can reference versions by their number.
func makeVersionWithParams(id uint, promptID string, isLatest bool, params tables.ModelParams, msgs ...tables.TablePromptVersionMessage) tables.TablePromptVersion {
return tables.TablePromptVersion{
ID: id,
PromptID: promptID,
IsLatest: isLatest,
VersionNumber: int(id),
ModelParams: params,
Messages: msgs,
}
}
// chatRequestWithParams returns a BifrostRequest with Params pre-set.
func chatRequestWithParams(params *schemas.ChatParameters, msgs ...schemas.ChatMessage) *schemas.BifrostRequest {
return &schemas.BifrostRequest{
ChatRequest: &schemas.BifrostChatRequest{
Input: append([]schemas.ChatMessage{}, msgs...),
Params: params,
},
}
}
// chatRequestWithModel returns a BifrostRequest with the Model field pre-set.
func chatRequestWithModel(model string, msgs ...schemas.ChatMessage) *schemas.BifrostRequest {
return &schemas.BifrostRequest{
ChatRequest: &schemas.BifrostChatRequest{
Model: model,
Input: append([]schemas.ChatMessage{}, msgs...),
},
}
}
// versionMsgAssistantUIFormat creates a TablePromptVersionMessage in the format
// the Bifrost UI writes for assistant (completion_result) messages.
// The message is nested at payload.choices[0].message, matching SerializedMessage.
func versionMsgAssistantUIFormat(text string) tables.TablePromptVersionMessage {
content := text
inner := schemas.ChatMessage{
Role: schemas.ChatMessageRoleAssistant,
Content: &schemas.ChatMessageContent{ContentStr: &content},
}
innerJSON, err := json.Marshal(inner)
if err != nil {
panic(fmt.Sprintf("versionMsgAssistantUIFormat: marshal failed: %v", err))
}
payload := fmt.Sprintf(`{"id":"resp-1","choices":[{"index":0,"message":%s,"finish_reason":"stop"}]}`, string(innerJSON))
envelope := fmt.Sprintf(`{"originalType":"completion_result","payload":%s}`, payload)
return tables.TablePromptVersionMessage{
Message: tables.PromptMessage(envelope),
}
}
// ============================================================
// errTest — minimal error type for test use
// ============================================================
type errTest string
func (e errTest) Error() string { return string(e) }
// ============================================================
// Assertion helpers
// ============================================================
// msgText extracts the ContentStr from a ChatMessage, returning "" if absent.
func msgText(msg schemas.ChatMessage) string {
if msg.Content == nil || msg.Content.ContentStr == nil {
return ""
}
return *msg.Content.ContentStr
}