first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,932 @@
package prompts
import (
"context"
"encoding/json"
"testing"
"github.com/maximhq/bifrost/core/schemas"
tables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ============================================================
// InitWithResolver
// ============================================================
func TestInitWithResolver_NilStore(t *testing.T) {
_, err := InitWithResolver(context.Background(), nil, &staticResolver{}, NewMockLogger())
require.Error(t, err, "expected error for nil store")
}
func TestInitWithResolver_NilResolverFallsBackToHeader(t *testing.T) {
ms := &mockStore{}
p, err := InitWithResolver(context.Background(), ms, nil, NewMockLogger())
require.NoError(t, err)
require.NotNil(t, p)
_, ok := p.resolver.(*headerResolver)
assert.True(t, ok, "expected headerResolver, got %T", p.resolver)
}
// ============================================================
// loadCache
// ============================================================
func TestLoadCache_EmptyStore(t *testing.T) {
p := newPluginWithStore(&mockStore{})
require.NoError(t, p.loadCache(context.Background()))
assert.Empty(t, p.promptsByID)
assert.Empty(t, p.versionsByPromptAndNumber)
}
func TestLoadCache_PopulatesMaps(t *testing.T) {
v1 := makeVersion(1, "p1", true, versionMsg(schemas.ChatMessageRoleSystem, "Hello"))
v2 := makeVersion(2, "p2", true)
p1 := makePrompt("p1", &v1)
p2 := makePrompt("p2", &v2)
p := newPluginWithStore(&mockStore{
prompts: []tables.TablePrompt{p1, p2},
versions: []tables.TablePromptVersion{v1, v2},
})
require.NoError(t, p.loadCache(context.Background()))
assert.Len(t, p.promptsByID, 2)
assert.Len(t, p.versionsByPromptAndNumber, 2)
assert.NotNil(t, p.promptsByID["p1"])
assert.NotNil(t, p.versionsByPromptAndNumber["p1"][1])
}
func TestLoadCache_GetPromptsError(t *testing.T) {
p := newPluginWithStore(&mockStore{err: errTest("boom")})
err := p.loadCache(context.Background())
require.Error(t, err)
}
func TestLoadCache_GetVersionsError(t *testing.T) {
p := newPluginWithStore(&versionsErrStore{
prompts: []tables.TablePrompt{makePrompt("p1", nil)},
err: errTest("versions boom"),
})
err := p.loadCache(context.Background())
require.Error(t, err)
assert.Contains(t, err.Error(), "versions boom")
}
// ============================================================
// PreLLMHook
// ============================================================
func TestPreLLMHook_NoPromptID(t *testing.T) {
p := newTestPlugin(&staticResolver{promptID: ""}, nil, nil)
out, sc, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello")))
require.NoError(t, err)
assert.Nil(t, sc)
assert.Len(t, out.ChatRequest.Input, 1)
}
func TestPreLLMHook_PromptNotFound(t *testing.T) {
log := NewMockLogger()
p := newTestPluginWithLogger(&staticResolver{promptID: "missing"}, nil, nil, log)
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello")))
require.NoError(t, err)
assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged")
assert.True(t, log.Warned(), "expected a warning for unknown prompt")
}
func TestPreLLMHook_UseLatestVersion(t *testing.T) {
v := makeVersion(1, "p1", true,
versionMsg(schemas.ChatMessageRoleSystem, "Be helpful"),
)
prompt := makePrompt("p1", &v)
p := newTestPlugin(
&staticResolver{promptID: "p1"},
map[string]*tables.TablePrompt{"p1": &prompt},
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
)
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello")))
require.NoError(t, err)
require.Len(t, out.ChatRequest.Input, 2, "expected system prompt + user message")
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role)
assert.Equal(t, "Be helpful", msgText(out.ChatRequest.Input[0]))
assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[1].Role)
assert.Equal(t, "hello", msgText(out.ChatRequest.Input[1]))
}
func TestPreLLMHook_UseSpecificVersion(t *testing.T) {
vLatest := makeVersion(1, "p1", true,
versionMsg(schemas.ChatMessageRoleSystem, "latest system prompt"),
)
vOld := makeVersion(2, "p1", false,
versionMsg(schemas.ChatMessageRoleSystem, "old system prompt"),
)
prompt := makePrompt("p1", &vLatest)
p := newTestPlugin(
&staticResolver{promptID: "p1", versionNumber: 2},
map[string]*tables.TablePrompt{"p1": &prompt},
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &vLatest, 2: &vOld}},
)
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello")))
require.NoError(t, err)
require.Len(t, out.ChatRequest.Input, 2)
// Must use vOld, not vLatest.
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role)
assert.Equal(t, "old system prompt", msgText(out.ChatRequest.Input[0]))
}
func TestPreLLMHook_VersionNotFound(t *testing.T) {
v := makeVersion(1, "p1", true, versionMsg(schemas.ChatMessageRoleSystem, "hello"))
prompt := makePrompt("p1", &v)
log := NewMockLogger()
p := newTestPluginWithLogger(
&staticResolver{promptID: "p1", versionNumber: 99},
map[string]*tables.TablePrompt{"p1": &prompt},
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
log,
)
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi")))
require.NoError(t, err)
assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged")
assert.True(t, log.Warned(), "expected warning for missing version")
}
func TestPreLLMHook_VersionBelongsToDifferentPrompt(t *testing.T) {
v := makeVersion(1, "p2", true, versionMsg(schemas.ChatMessageRoleSystem, "wrong"))
prompt := makePrompt("p1", nil)
log := NewMockLogger()
p := newTestPluginWithLogger(
&staticResolver{promptID: "p1", versionNumber: 1},
map[string]*tables.TablePrompt{"p1": &prompt},
map[string]map[int]*tables.TablePromptVersion{"p2": {1: &v}},
log,
)
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi")))
require.NoError(t, err)
assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged")
assert.True(t, log.Warned(), "expected warning for version/prompt mismatch")
}
func TestPreLLMHook_NoLatestVersion(t *testing.T) {
prompt := makePrompt("p1", nil) // LatestVersion is nil
log := NewMockLogger()
p := newTestPluginWithLogger(
&staticResolver{promptID: "p1"},
map[string]*tables.TablePrompt{"p1": &prompt},
nil,
log,
)
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi")))
require.NoError(t, err)
assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged")
assert.True(t, log.Warned(), "expected warning for missing latest version")
}
func TestPreLLMHook_EmptyTemplate(t *testing.T) {
v := makeVersion(1, "p1", true) // no messages
prompt := makePrompt("p1", &v)
p := newTestPlugin(
&staticResolver{promptID: "p1"},
map[string]*tables.TablePrompt{"p1": &prompt},
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
)
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi")))
require.NoError(t, err)
assert.Len(t, out.ChatRequest.Input, 1)
}
func TestPreLLMHook_MultipleTemplateMessages(t *testing.T) {
v := makeVersion(1, "p1", true,
versionMsg(schemas.ChatMessageRoleSystem, "sys prompt"),
versionMsg(schemas.ChatMessageRoleUser, "example input"),
versionMsg(schemas.ChatMessageRoleAssistant, "example output"),
)
prompt := makePrompt("p1", &v)
p := newTestPlugin(
&staticResolver{promptID: "p1"},
map[string]*tables.TablePrompt{"p1": &prompt},
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
)
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("actual question")))
require.NoError(t, err)
require.Len(t, out.ChatRequest.Input, 4, "expected 3 template messages + 1 original")
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role)
assert.Equal(t, "sys prompt", msgText(out.ChatRequest.Input[0]))
assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[1].Role)
assert.Equal(t, "example input", msgText(out.ChatRequest.Input[1]))
assert.Equal(t, schemas.ChatMessageRoleAssistant, out.ChatRequest.Input[2].Role)
assert.Equal(t, "example output", msgText(out.ChatRequest.Input[2]))
// Original user message must be last, content preserved.
assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[3].Role)
assert.Equal(t, "actual question", msgText(out.ChatRequest.Input[3]))
}
func TestPreLLMHook_ResolverError(t *testing.T) {
log := NewMockLogger()
p := newTestPluginWithLogger(
&staticResolver{err: errTest("resolver failed")},
nil, nil, log,
)
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi")))
require.NoError(t, err, "PreLLMHook must not propagate resolver errors")
assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged")
assert.True(t, log.Warned(), "expected warning for resolver error")
}
func TestPreLLMHook_MessageJSON_FallbackPath(t *testing.T) {
// Messages where Message ([]byte) is nil but MessageJSON is set — the fallback
// branch in chatMessagesFromVersionMessages. This mirrors rows loaded from
// an older DB schema before AfterFind was established.
v := makeVersion(1, "p1", true,
versionMsgViaJSON(schemas.ChatMessageRoleSystem, "from json field"),
)
prompt := makePrompt("p1", &v)
p := newTestPlugin(
&staticResolver{promptID: "p1"},
map[string]*tables.TablePrompt{"p1": &prompt},
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
)
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi")))
require.NoError(t, err)
require.Len(t, out.ChatRequest.Input, 2)
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role)
assert.Equal(t, "from json field", msgText(out.ChatRequest.Input[0]))
}
func TestPreLLMHook_ResponsesRequest(t *testing.T) {
v := makeVersion(1, "p1", true,
versionMsg(schemas.ChatMessageRoleSystem, "be concise"),
)
prompt := makePrompt("p1", &v)
p := newTestPlugin(
&staticResolver{promptID: "p1"},
map[string]*tables.TablePrompt{"p1": &prompt},
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
)
userRole := schemas.ResponsesMessageRoleType("user")
req := &schemas.BifrostRequest{
ResponsesRequest: &schemas.BifrostResponsesRequest{
Input: []schemas.ResponsesMessage{{Role: &userRole}},
},
}
out, _, err := p.PreLLMHook(bfCtx(), req)
require.NoError(t, err)
// Template message(s) prepended before the original user input.
assert.Greater(t, len(out.ResponsesRequest.Input), 1, "expected template prepended before user message")
// Original user message must still be last.
last := out.ResponsesRequest.Input[len(out.ResponsesRequest.Input)-1]
assert.Equal(t, schemas.ResponsesMessageRoleType("user"), *last.Role)
}
// TestPreLLMHook_PromptSystemMsg_PlusUserInputSystemMsg verifies that when the
// prompt template contains a system message and the incoming request also starts
// with a system message, both system messages are forwarded to the model —
// the plugin's only job is prepending, not de-duplicating or filtering roles.
func TestPreLLMHook_PromptSystemMsg_PlusUserInputSystemMsg(t *testing.T) {
v := makeVersion(1, "p1", true,
versionMsg(schemas.ChatMessageRoleSystem, "prompt system"),
)
prompt := makePrompt("p1", &v)
p := newTestPlugin(
&staticResolver{promptID: "p1"},
map[string]*tables.TablePrompt{"p1": &prompt},
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
)
// Incoming request already has its own system message before the user turn.
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(
systemMsg("user-side system context"),
userMsg("actual question"),
))
require.NoError(t, err)
require.Len(t, out.ChatRequest.Input, 3, "expected prompt system + user system + user message")
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role)
assert.Equal(t, "prompt system", msgText(out.ChatRequest.Input[0]))
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[1].Role)
assert.Equal(t, "user-side system context", msgText(out.ChatRequest.Input[1]))
assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[2].Role)
assert.Equal(t, "actual question", msgText(out.ChatRequest.Input[2]))
}
// TestPreLLMHook_PromptWithToolCallMessages_PlusUserMessage verifies that when
// the prompt template contains a full tool-call turn (system → assistant with
// tool_calls → tool result) and the user sends a new message, the entire
// template is prepended and all fields (ToolCalls, ToolCallID) are preserved.
func TestPreLLMHook_PromptWithToolCallMessages_PlusUserMessage(t *testing.T) {
const callID = "call_abc123"
v := makeVersion(1, "p1", true,
versionMsg(schemas.ChatMessageRoleSystem, "you are a weather bot"),
versionMsgWithToolCall(callID, "get_weather", `{"city":"Paris"}`),
versionMsgToolResult(callID, "Sunny, 22°C"),
)
prompt := makePrompt("p1", &v)
p := newTestPlugin(
&staticResolver{promptID: "p1"},
map[string]*tables.TablePrompt{"p1": &prompt},
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
)
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("what about tomorrow?")))
require.NoError(t, err)
require.Len(t, out.ChatRequest.Input, 4, "expected system + assistant(tool_calls) + tool_result + user")
// System message from prompt.
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role)
assert.Equal(t, "you are a weather bot", msgText(out.ChatRequest.Input[0]))
// Assistant message with tool_calls must carry its ToolCalls slice.
assistantMsg := out.ChatRequest.Input[1]
assert.Equal(t, schemas.ChatMessageRoleAssistant, assistantMsg.Role)
require.NotNil(t, assistantMsg.ChatAssistantMessage, "ChatAssistantMessage must be present")
require.Len(t, assistantMsg.ChatAssistantMessage.ToolCalls, 1)
tc := assistantMsg.ChatAssistantMessage.ToolCalls[0]
require.NotNil(t, tc.ID)
assert.Equal(t, callID, *tc.ID)
require.NotNil(t, tc.Function.Name)
assert.Equal(t, "get_weather", *tc.Function.Name)
assert.Equal(t, `{"city":"Paris"}`, tc.Function.Arguments)
// Tool result message must carry the ToolCallID.
toolResultMsg := out.ChatRequest.Input[2]
assert.Equal(t, schemas.ChatMessageRoleTool, toolResultMsg.Role)
assert.Equal(t, "Sunny, 22°C", msgText(toolResultMsg))
require.NotNil(t, toolResultMsg.ChatToolMessage, "ChatToolMessage must be present")
require.NotNil(t, toolResultMsg.ChatToolMessage.ToolCallID)
assert.Equal(t, callID, *toolResultMsg.ChatToolMessage.ToolCallID)
// Original user message is last.
assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[3].Role)
assert.Equal(t, "what about tomorrow?", msgText(out.ChatRequest.Input[3]))
}
// TestPreLLMHook_MultipleSystemMessages_InPromptTemplate verifies that a prompt
// template may itself contain multiple system messages and all of them are
// prepended before the user's input in the original order.
func TestPreLLMHook_MultipleSystemMessages_InPromptTemplate(t *testing.T) {
v := makeVersion(1, "p1", true,
versionMsg(schemas.ChatMessageRoleSystem, "first system"),
versionMsg(schemas.ChatMessageRoleSystem, "second system"),
)
prompt := makePrompt("p1", &v)
p := newTestPlugin(
&staticResolver{promptID: "p1"},
map[string]*tables.TablePrompt{"p1": &prompt},
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
)
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello")))
require.NoError(t, err)
require.Len(t, out.ChatRequest.Input, 3, "expected 2 system messages + user message")
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role)
assert.Equal(t, "first system", msgText(out.ChatRequest.Input[0]))
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[1].Role)
assert.Equal(t, "second system", msgText(out.ChatRequest.Input[1]))
assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[2].Role)
assert.Equal(t, "hello", msgText(out.ChatRequest.Input[2]))
}
// ============================================================
// HTTPTransportPreHook
// ============================================================
func TestHTTPTransportPreHook_NilRequest(t *testing.T) {
p := newTestPlugin(nil, nil, nil)
resp, err := p.HTTPTransportPreHook(bfCtx(), nil)
assert.NoError(t, err)
assert.Nil(t, resp)
}
func TestHTTPTransportPreHook_SetsPromptID(t *testing.T) {
p := newTestPlugin(nil, nil, nil)
ctx := bfCtx()
req := &schemas.HTTPRequest{
Headers: map[string]string{PromptIDHeader: "my-prompt"},
}
_, _ = p.HTTPTransportPreHook(ctx, req)
got, _ := ctx.Value(PromptIDKey).(string)
assert.Equal(t, "my-prompt", got)
}
func TestHTTPTransportPreHook_SetsVersionID(t *testing.T) {
p := newTestPlugin(nil, nil, nil)
ctx := bfCtx()
req := &schemas.HTTPRequest{
Headers: map[string]string{PromptVersionHeader: "42"},
}
_, _ = p.HTTPTransportPreHook(ctx, req)
got, _ := ctx.Value(PromptVersionKey).(string)
assert.Equal(t, "42", got)
}
func TestHTTPTransportPreHook_TrimsWhitespace(t *testing.T) {
p := newTestPlugin(nil, nil, nil)
ctx := bfCtx()
req := &schemas.HTTPRequest{
Headers: map[string]string{PromptIDHeader: " padded "},
}
_, _ = p.HTTPTransportPreHook(ctx, req)
got, _ := ctx.Value(PromptIDKey).(string)
assert.Equal(t, "padded", got)
}
func TestHTTPTransportPreHook_WhitespaceOnlyNotSet(t *testing.T) {
p := newTestPlugin(nil, nil, nil)
ctx := bfCtx()
req := &schemas.HTTPRequest{
Headers: map[string]string{PromptIDHeader: " "},
}
_, _ = p.HTTPTransportPreHook(ctx, req)
assert.Nil(t, ctx.Value(PromptIDKey), "whitespace-only header must not be stored in context")
}
func TestHTTPTransportPreHook_CaseInsensitiveHeaders(t *testing.T) {
p := newTestPlugin(nil, nil, nil)
ctx := bfCtx()
// "X-Bf-Prompt-Id" is a title-case variant of the canonical "x-bf-prompt-id".
req := &schemas.HTTPRequest{
Headers: map[string]string{"X-Bf-Prompt-Id": "upper-case"},
}
_, _ = p.HTTPTransportPreHook(ctx, req)
got, _ := ctx.Value(PromptIDKey).(string)
assert.Equal(t, "upper-case", got)
}
// ============================================================
// chatMessageFromStoredJSON
// ============================================================
func TestChatMessageFromStoredJSON(t *testing.T) {
systemText := "you are helpful"
directMsg := schemas.ChatMessage{
Role: schemas.ChatMessageRoleSystem,
Content: &schemas.ChatMessageContent{ContentStr: &systemText},
}
directJSON, _ := json.Marshal(directMsg)
envelopeJSON := []byte(`{"payload":` + string(directJSON) + `}`)
tests := []struct {
name string
input []byte
wantErr bool
wantRole schemas.ChatMessageRole
wantText string
}{
{
name: "direct format",
input: directJSON,
wantRole: schemas.ChatMessageRoleSystem,
wantText: systemText,
},
{
name: "envelope format",
input: envelopeJSON,
wantRole: schemas.ChatMessageRoleSystem,
wantText: systemText,
},
{
// UI format for assistant messages: originalType=completion_result,
// payload is a BifrostChatResponse; message lives at choices[0].message.
name: "completion_result envelope (UI assistant format)",
input: []byte(`{"originalType":"completion_result","payload":{"id":"r1","choices":[{"index":0,"message":{"role":"assistant","content":"hi there"},"finish_reason":"stop"}]}}`),
wantRole: schemas.ChatMessageRoleAssistant,
wantText: "hi there",
},
{
// completion_result with no choices falls through to direct ChatMessage parse.
name: "completion_result envelope with empty choices",
input: []byte(`{"originalType":"completion_result","payload":{"id":"r1","choices":[]}}`),
wantErr: false,
wantRole: "",
wantText: "",
},
{
name: "empty bytes",
input: []byte(""),
wantErr: true,
},
{
name: "null bytes",
input: []byte("null"),
wantErr: true,
},
{
name: "whitespace only",
input: []byte(" "),
wantErr: true,
},
{
name: "malformed envelope payload",
input: []byte(`{"payload":"not-a-chat-message"}`),
wantErr: true,
},
{
// {"payload":null} — envelope path is skipped (payload is "null"),
// falls through to direct decode of the outer object as ChatMessage.
// schemas.Unmarshal succeeds on an unknown-field object → empty ChatMessage, no error.
name: "envelope with null payload falls through to direct decode",
input: []byte(`{"payload":null}`),
wantErr: false,
wantRole: "",
wantText: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := convertVersionMessagesToChatMessages(tt.input)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantRole, got.Role)
assert.Equal(t, tt.wantText, msgText(got))
})
}
}
func TestChatMessageFromStoredJSON_EnvelopeWithEmptyStringPayload(t *testing.T) {
// {"payload":""} — the payload field is a non-null, non-empty JSON string `""`.
// The envelope path attempts to unmarshal `""` (a JSON string literal) into
// schemas.ChatMessage (a struct), which fails. The error is returned directly;
// there is no further fallback.
input := []byte(`{"payload":""}`)
_, err := convertVersionMessagesToChatMessages(input)
require.Error(t, err)
assert.Contains(t, err.Error(), "decoding prompt message envelope payload")
}
// ============================================================
// parsePromptVersionNumber
// ============================================================
func TestParsePromptVersionNumber(t *testing.T) {
type want struct {
num int
specified bool
wantErr bool
}
tests := []struct {
name string
value any // stored in context; nil means don't set
want want
}{
{name: "nil — not specified", value: nil, want: want{0, false, false}},
{name: "string valid", value: "99", want: want{99, true, false}},
{name: "string empty", value: "", want: want{0, false, false}},
{name: "string whitespace", value: " ", want: want{0, false, false}},
{name: "string invalid", value: "abc", want: want{0, true, true}},
{name: "unknown type", value: struct{}{}, want: want{0, false, false}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := bfCtx()
if tt.value != nil {
ctx.SetValue(PromptVersionKey, tt.value)
}
num, err := parseNumberFromContext(ctx, PromptVersionKey)
if tt.want.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.want.num, num)
})
}
}
// ============================================================
// mergeChatMessages
// ============================================================
func TestMergeChatMessages(t *testing.T) {
t.Run("prepends prefix before existing messages", func(t *testing.T) {
dest := []schemas.ChatMessage{userMsg("original")}
prefix := []schemas.ChatMessage{systemMsg("system")}
mergeChatMessages(&dest, prefix)
require.Len(t, dest, 2)
assert.Equal(t, schemas.ChatMessageRoleSystem, dest[0].Role)
assert.Equal(t, "system", msgText(dest[0]))
assert.Equal(t, schemas.ChatMessageRoleUser, dest[1].Role)
assert.Equal(t, "original", msgText(dest[1]))
})
t.Run("nil dest is a no-op", func(t *testing.T) {
// Must not panic.
mergeChatMessages(nil, []schemas.ChatMessage{systemMsg("x")})
})
t.Run("empty prefix is a no-op", func(t *testing.T) {
dest := []schemas.ChatMessage{userMsg("only")}
mergeChatMessages(&dest, nil)
assert.Len(t, dest, 1)
assert.Equal(t, "only", msgText(dest[0]))
})
}
// ============================================================
// chatMessagesFromVersionMessages
// ============================================================
func TestChatMessagesFromVersionMessages_SingleMessage(t *testing.T) {
msg := versionMsg(schemas.ChatMessageRoleUser, "hello")
out, err := chatMessagesFromVersionMessages([]tables.TablePromptVersionMessage{msg})
require.NoError(t, err)
require.Len(t, out, 1)
assert.Equal(t, schemas.ChatMessageRoleUser, out[0].Role)
assert.Equal(t, "hello", msgText(out[0]))
}
func TestChatMessagesFromVersionMessages_MessageJSONFallback(t *testing.T) {
// Row has no Message bytes but has MessageJSON — exercises the fallback branch.
msg := versionMsgViaJSON(schemas.ChatMessageRoleAssistant, "assistant reply")
out, err := chatMessagesFromVersionMessages([]tables.TablePromptVersionMessage{msg})
require.NoError(t, err)
require.Len(t, out, 1)
assert.Equal(t, schemas.ChatMessageRoleAssistant, out[0].Role)
assert.Equal(t, "assistant reply", msgText(out[0]))
}
func TestChatMessagesFromVersionMessages_PreservesOrder(t *testing.T) {
msgs := []tables.TablePromptVersionMessage{
versionMsg(schemas.ChatMessageRoleSystem, "first"),
versionMsg(schemas.ChatMessageRoleUser, "second"),
versionMsg(schemas.ChatMessageRoleAssistant, "third"),
}
out, err := chatMessagesFromVersionMessages(msgs)
require.NoError(t, err)
require.Len(t, out, 3)
assert.Equal(t, schemas.ChatMessageRoleSystem, out[0].Role)
assert.Equal(t, "first", msgText(out[0]))
assert.Equal(t, schemas.ChatMessageRoleUser, out[1].Role)
assert.Equal(t, "second", msgText(out[1]))
assert.Equal(t, schemas.ChatMessageRoleAssistant, out[2].Role)
assert.Equal(t, "third", msgText(out[2]))
}
func TestChatMessagesFromVersionMessages_InvalidJSON(t *testing.T) {
bad := tables.TablePromptVersionMessage{Message: []byte(`not-json`)}
_, err := chatMessagesFromVersionMessages([]tables.TablePromptVersionMessage{bad})
require.Error(t, err)
}
// ============================================================
// PreLLMHook — model params merge and override
// ============================================================
// TestPreLLMHook_VersionParamsApplied_WhenRequestHasNoParams verifies that when
// the request carries no Params at all, the version's ModelParams become the
// effective parameters on the outgoing request.
func TestPreLLMHook_VersionParamsApplied_WhenRequestHasNoParams(t *testing.T) {
v := makeVersionWithParams(1, "p1", true,
tables.ModelParams{"temperature": float64(0.7)},
versionMsg(schemas.ChatMessageRoleSystem, "sys"),
)
prompt := makePrompt("p1", &v)
p := newTestPlugin(
&staticResolver{promptID: "p1"},
map[string]*tables.TablePrompt{"p1": &prompt},
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
)
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi")))
require.NoError(t, err)
require.NotNil(t, out.ChatRequest.Params, "expected Params to be set from version ModelParams")
require.NotNil(t, out.ChatRequest.Params.Temperature)
assert.InDelta(t, 0.7, *out.ChatRequest.Params.Temperature, 0.001)
}
// TestPreLLMHook_RequestParamsOverrideVersionParams verifies that when both the
// version and the request specify the same parameter, the request value wins.
func TestPreLLMHook_RequestParamsOverrideVersionParams(t *testing.T) {
reqTemp := 0.9
v := makeVersionWithParams(1, "p1", true,
tables.ModelParams{"temperature": float64(0.3)},
versionMsg(schemas.ChatMessageRoleSystem, "sys"),
)
prompt := makePrompt("p1", &v)
p := newTestPlugin(
&staticResolver{promptID: "p1"},
map[string]*tables.TablePrompt{"p1": &prompt},
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
)
req := chatRequestWithParams(&schemas.ChatParameters{Temperature: &reqTemp}, userMsg("hello"))
out, _, err := p.PreLLMHook(bfCtx(), req)
require.NoError(t, err)
require.NotNil(t, out.ChatRequest.Params)
require.NotNil(t, out.ChatRequest.Params.Temperature)
assert.InDelta(t, reqTemp, *out.ChatRequest.Params.Temperature, 0.001,
"request temperature must override version default temperature")
}
// TestPreLLMHook_RequestParamsPartialOverride verifies the mixed case: version
// sets temperature and max_completion_tokens; request overrides only temperature.
// The version's max_completion_tokens must still be applied.
func TestPreLLMHook_RequestParamsPartialOverride(t *testing.T) {
reqTemp := 0.9
maxTokens := 200
v := makeVersionWithParams(1, "p1", true,
tables.ModelParams{
"temperature": float64(0.3),
"max_completion_tokens": float64(maxTokens),
},
versionMsg(schemas.ChatMessageRoleSystem, "sys"),
)
prompt := makePrompt("p1", &v)
p := newTestPlugin(
&staticResolver{promptID: "p1"},
map[string]*tables.TablePrompt{"p1": &prompt},
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
)
req := chatRequestWithParams(&schemas.ChatParameters{Temperature: &reqTemp}, userMsg("hello"))
out, _, err := p.PreLLMHook(bfCtx(), req)
require.NoError(t, err)
require.NotNil(t, out.ChatRequest.Params)
require.NotNil(t, out.ChatRequest.Params.Temperature)
assert.InDelta(t, reqTemp, *out.ChatRequest.Params.Temperature, 0.001,
"request temperature must override version temperature")
require.NotNil(t, out.ChatRequest.Params.MaxCompletionTokens,
"version max_completion_tokens must be applied when request does not override it")
assert.Equal(t, maxTokens, *out.ChatRequest.Params.MaxCompletionTokens)
}
// ============================================================
// PreLLMHook — model field preservation
// ============================================================
// TestPreLLMHook_ModelInVersionParams_DoesNotOverrideRequestModel verifies that
// a "model" key inside a version's ModelParams (which the UI may store alongside
// temperature etc.) does NOT replace the model field on the outgoing
// BifrostChatRequest. The model chosen by the caller must always win.
func TestPreLLMHook_ModelInVersionParams_DoesNotOverrideRequestModel(t *testing.T) {
v := makeVersionWithParams(1, "p1", true,
tables.ModelParams{
"model": "openai/gpt-4o",
"temperature": float64(0.5),
},
versionMsg(schemas.ChatMessageRoleSystem, "sys"),
)
prompt := makePrompt("p1", &v)
p := newTestPlugin(
&staticResolver{promptID: "p1"},
map[string]*tables.TablePrompt{"p1": &prompt},
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
)
req := chatRequestWithModel("openai/gpt-3.5-turbo", userMsg("hi"))
out, _, err := p.PreLLMHook(bfCtx(), req)
require.NoError(t, err)
assert.Equal(t, "openai/gpt-3.5-turbo", out.ChatRequest.Model,
"request model must not be overridden by model stored in version ModelParams")
}
// ============================================================
// loadCache + PreLLMHook integration (store → cache → injection)
// ============================================================
// TestLoadCacheAndPreLLMHook_EndToEnd verifies the full pipeline:
// mockStore returns TablePrompt/TablePromptVersion structs → loadCache populates
// the in-memory maps → PreLLMHook injects the template messages correctly.
// This catches any mismatch between how loadCache builds the maps and how
// PreLLMHook reads them (e.g. pointer aliasing, LatestVersion linking).
func TestLoadCacheAndPreLLMHook_EndToEnd(t *testing.T) {
v := makeVersion(1, "p1", true,
versionMsg(schemas.ChatMessageRoleSystem, "end-to-end system"),
)
prompt := makePrompt("p1", &v)
ms := &mockStore{
prompts: []tables.TablePrompt{prompt},
versions: []tables.TablePromptVersion{v},
}
p := newPluginWithStore(ms)
require.NoError(t, p.loadCache(context.Background()))
p.resolver = &staticResolver{promptID: "p1"}
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("user msg")))
require.NoError(t, err)
require.Len(t, out.ChatRequest.Input, 2)
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role)
assert.Equal(t, "end-to-end system", msgText(out.ChatRequest.Input[0]))
assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[1].Role)
assert.Equal(t, "user msg", msgText(out.ChatRequest.Input[1]))
}
// TestLoadCacheAndPreLLMHook_SpecificVersion exercises the loadCache → PreLLMHook
// path for a version lookup by ID (not just the LatestVersion pointer).
func TestLoadCacheAndPreLLMHook_SpecificVersion(t *testing.T) {
vOld := makeVersion(2, "p1", false,
versionMsg(schemas.ChatMessageRoleSystem, "old via store"),
)
vLatest := makeVersion(3, "p1", true,
versionMsg(schemas.ChatMessageRoleSystem, "latest via store"),
)
prompt := makePrompt("p1", &vLatest)
ms := &mockStore{
prompts: []tables.TablePrompt{prompt},
versions: []tables.TablePromptVersion{vOld, vLatest},
}
p := newPluginWithStore(ms)
require.NoError(t, p.loadCache(context.Background()))
p.resolver = &staticResolver{promptID: "p1", versionNumber: 2}
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("question")))
require.NoError(t, err)
require.Len(t, out.ChatRequest.Input, 2)
assert.Equal(t, "old via store", msgText(out.ChatRequest.Input[0]))
}
// TestPreLLMHook_AssistantMessage_UIFormat verifies that assistant messages stored
// in the Bifrost UI's completion_result format (payload.choices[0].message) are
// correctly extracted and prepended to the request.
func TestPreLLMHook_AssistantMessage_UIFormat(t *testing.T) {
v := makeVersion(1, "p1", true,
versionMsg(schemas.ChatMessageRoleSystem, "be helpful"),
versionMsgAssistantUIFormat("sure, how can I help?"),
)
prompt := makePrompt("p1", &v)
p := newTestPlugin(
&staticResolver{promptID: "p1"},
map[string]*tables.TablePrompt{"p1": &prompt},
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
)
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello")))
require.NoError(t, err)
require.Len(t, out.ChatRequest.Input, 3, "expected system + assistant + user")
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role)
assert.Equal(t, "be helpful", msgText(out.ChatRequest.Input[0]))
assert.Equal(t, schemas.ChatMessageRoleAssistant, out.ChatRequest.Input[1].Role)
assert.Equal(t, "sure, how can I help?", msgText(out.ChatRequest.Input[1]))
assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[2].Role)
assert.Equal(t, "hello", msgText(out.ChatRequest.Input[2]))
}