first commit
This commit is contained in:
932
plugins/prompts/plugin_test.go
Normal file
932
plugins/prompts/plugin_test.go
Normal file
@@ -0,0 +1,932 @@
|
||||
package prompts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
tables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
// InitWithResolver
|
||||
// ============================================================
|
||||
|
||||
func TestInitWithResolver_NilStore(t *testing.T) {
|
||||
_, err := InitWithResolver(context.Background(), nil, &staticResolver{}, NewMockLogger())
|
||||
require.Error(t, err, "expected error for nil store")
|
||||
}
|
||||
|
||||
func TestInitWithResolver_NilResolverFallsBackToHeader(t *testing.T) {
|
||||
ms := &mockStore{}
|
||||
p, err := InitWithResolver(context.Background(), ms, nil, NewMockLogger())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, p)
|
||||
_, ok := p.resolver.(*headerResolver)
|
||||
assert.True(t, ok, "expected headerResolver, got %T", p.resolver)
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// loadCache
|
||||
// ============================================================
|
||||
|
||||
func TestLoadCache_EmptyStore(t *testing.T) {
|
||||
p := newPluginWithStore(&mockStore{})
|
||||
require.NoError(t, p.loadCache(context.Background()))
|
||||
assert.Empty(t, p.promptsByID)
|
||||
assert.Empty(t, p.versionsByPromptAndNumber)
|
||||
}
|
||||
|
||||
func TestLoadCache_PopulatesMaps(t *testing.T) {
|
||||
v1 := makeVersion(1, "p1", true, versionMsg(schemas.ChatMessageRoleSystem, "Hello"))
|
||||
v2 := makeVersion(2, "p2", true)
|
||||
p1 := makePrompt("p1", &v1)
|
||||
p2 := makePrompt("p2", &v2)
|
||||
|
||||
p := newPluginWithStore(&mockStore{
|
||||
prompts: []tables.TablePrompt{p1, p2},
|
||||
versions: []tables.TablePromptVersion{v1, v2},
|
||||
})
|
||||
|
||||
require.NoError(t, p.loadCache(context.Background()))
|
||||
assert.Len(t, p.promptsByID, 2)
|
||||
assert.Len(t, p.versionsByPromptAndNumber, 2)
|
||||
assert.NotNil(t, p.promptsByID["p1"])
|
||||
assert.NotNil(t, p.versionsByPromptAndNumber["p1"][1])
|
||||
}
|
||||
|
||||
func TestLoadCache_GetPromptsError(t *testing.T) {
|
||||
p := newPluginWithStore(&mockStore{err: errTest("boom")})
|
||||
err := p.loadCache(context.Background())
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestLoadCache_GetVersionsError(t *testing.T) {
|
||||
p := newPluginWithStore(&versionsErrStore{
|
||||
prompts: []tables.TablePrompt{makePrompt("p1", nil)},
|
||||
err: errTest("versions boom"),
|
||||
})
|
||||
err := p.loadCache(context.Background())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "versions boom")
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// PreLLMHook
|
||||
// ============================================================
|
||||
|
||||
func TestPreLLMHook_NoPromptID(t *testing.T) {
|
||||
p := newTestPlugin(&staticResolver{promptID: ""}, nil, nil)
|
||||
out, sc, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello")))
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, sc)
|
||||
assert.Len(t, out.ChatRequest.Input, 1)
|
||||
}
|
||||
|
||||
func TestPreLLMHook_PromptNotFound(t *testing.T) {
|
||||
log := NewMockLogger()
|
||||
p := newTestPluginWithLogger(&staticResolver{promptID: "missing"}, nil, nil, log)
|
||||
|
||||
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello")))
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged")
|
||||
assert.True(t, log.Warned(), "expected a warning for unknown prompt")
|
||||
}
|
||||
|
||||
func TestPreLLMHook_UseLatestVersion(t *testing.T) {
|
||||
v := makeVersion(1, "p1", true,
|
||||
versionMsg(schemas.ChatMessageRoleSystem, "Be helpful"),
|
||||
)
|
||||
prompt := makePrompt("p1", &v)
|
||||
|
||||
p := newTestPlugin(
|
||||
&staticResolver{promptID: "p1"},
|
||||
map[string]*tables.TablePrompt{"p1": &prompt},
|
||||
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
|
||||
)
|
||||
|
||||
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello")))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out.ChatRequest.Input, 2, "expected system prompt + user message")
|
||||
|
||||
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role)
|
||||
assert.Equal(t, "Be helpful", msgText(out.ChatRequest.Input[0]))
|
||||
|
||||
assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[1].Role)
|
||||
assert.Equal(t, "hello", msgText(out.ChatRequest.Input[1]))
|
||||
}
|
||||
|
||||
func TestPreLLMHook_UseSpecificVersion(t *testing.T) {
|
||||
vLatest := makeVersion(1, "p1", true,
|
||||
versionMsg(schemas.ChatMessageRoleSystem, "latest system prompt"),
|
||||
)
|
||||
vOld := makeVersion(2, "p1", false,
|
||||
versionMsg(schemas.ChatMessageRoleSystem, "old system prompt"),
|
||||
)
|
||||
prompt := makePrompt("p1", &vLatest)
|
||||
|
||||
p := newTestPlugin(
|
||||
&staticResolver{promptID: "p1", versionNumber: 2},
|
||||
map[string]*tables.TablePrompt{"p1": &prompt},
|
||||
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &vLatest, 2: &vOld}},
|
||||
)
|
||||
|
||||
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello")))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out.ChatRequest.Input, 2)
|
||||
|
||||
// Must use vOld, not vLatest.
|
||||
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role)
|
||||
assert.Equal(t, "old system prompt", msgText(out.ChatRequest.Input[0]))
|
||||
}
|
||||
|
||||
func TestPreLLMHook_VersionNotFound(t *testing.T) {
|
||||
v := makeVersion(1, "p1", true, versionMsg(schemas.ChatMessageRoleSystem, "hello"))
|
||||
prompt := makePrompt("p1", &v)
|
||||
log := NewMockLogger()
|
||||
|
||||
p := newTestPluginWithLogger(
|
||||
&staticResolver{promptID: "p1", versionNumber: 99},
|
||||
map[string]*tables.TablePrompt{"p1": &prompt},
|
||||
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
|
||||
log,
|
||||
)
|
||||
|
||||
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi")))
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged")
|
||||
assert.True(t, log.Warned(), "expected warning for missing version")
|
||||
}
|
||||
|
||||
func TestPreLLMHook_VersionBelongsToDifferentPrompt(t *testing.T) {
|
||||
v := makeVersion(1, "p2", true, versionMsg(schemas.ChatMessageRoleSystem, "wrong"))
|
||||
prompt := makePrompt("p1", nil)
|
||||
log := NewMockLogger()
|
||||
|
||||
p := newTestPluginWithLogger(
|
||||
&staticResolver{promptID: "p1", versionNumber: 1},
|
||||
map[string]*tables.TablePrompt{"p1": &prompt},
|
||||
map[string]map[int]*tables.TablePromptVersion{"p2": {1: &v}},
|
||||
log,
|
||||
)
|
||||
|
||||
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi")))
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged")
|
||||
assert.True(t, log.Warned(), "expected warning for version/prompt mismatch")
|
||||
}
|
||||
|
||||
func TestPreLLMHook_NoLatestVersion(t *testing.T) {
|
||||
prompt := makePrompt("p1", nil) // LatestVersion is nil
|
||||
log := NewMockLogger()
|
||||
|
||||
p := newTestPluginWithLogger(
|
||||
&staticResolver{promptID: "p1"},
|
||||
map[string]*tables.TablePrompt{"p1": &prompt},
|
||||
nil,
|
||||
log,
|
||||
)
|
||||
|
||||
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi")))
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged")
|
||||
assert.True(t, log.Warned(), "expected warning for missing latest version")
|
||||
}
|
||||
|
||||
func TestPreLLMHook_EmptyTemplate(t *testing.T) {
|
||||
v := makeVersion(1, "p1", true) // no messages
|
||||
prompt := makePrompt("p1", &v)
|
||||
|
||||
p := newTestPlugin(
|
||||
&staticResolver{promptID: "p1"},
|
||||
map[string]*tables.TablePrompt{"p1": &prompt},
|
||||
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
|
||||
)
|
||||
|
||||
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi")))
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, out.ChatRequest.Input, 1)
|
||||
}
|
||||
|
||||
func TestPreLLMHook_MultipleTemplateMessages(t *testing.T) {
|
||||
v := makeVersion(1, "p1", true,
|
||||
versionMsg(schemas.ChatMessageRoleSystem, "sys prompt"),
|
||||
versionMsg(schemas.ChatMessageRoleUser, "example input"),
|
||||
versionMsg(schemas.ChatMessageRoleAssistant, "example output"),
|
||||
)
|
||||
prompt := makePrompt("p1", &v)
|
||||
|
||||
p := newTestPlugin(
|
||||
&staticResolver{promptID: "p1"},
|
||||
map[string]*tables.TablePrompt{"p1": &prompt},
|
||||
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
|
||||
)
|
||||
|
||||
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("actual question")))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out.ChatRequest.Input, 4, "expected 3 template messages + 1 original")
|
||||
|
||||
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role)
|
||||
assert.Equal(t, "sys prompt", msgText(out.ChatRequest.Input[0]))
|
||||
|
||||
assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[1].Role)
|
||||
assert.Equal(t, "example input", msgText(out.ChatRequest.Input[1]))
|
||||
|
||||
assert.Equal(t, schemas.ChatMessageRoleAssistant, out.ChatRequest.Input[2].Role)
|
||||
assert.Equal(t, "example output", msgText(out.ChatRequest.Input[2]))
|
||||
|
||||
// Original user message must be last, content preserved.
|
||||
assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[3].Role)
|
||||
assert.Equal(t, "actual question", msgText(out.ChatRequest.Input[3]))
|
||||
}
|
||||
|
||||
func TestPreLLMHook_ResolverError(t *testing.T) {
|
||||
log := NewMockLogger()
|
||||
p := newTestPluginWithLogger(
|
||||
&staticResolver{err: errTest("resolver failed")},
|
||||
nil, nil, log,
|
||||
)
|
||||
|
||||
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi")))
|
||||
require.NoError(t, err, "PreLLMHook must not propagate resolver errors")
|
||||
assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged")
|
||||
assert.True(t, log.Warned(), "expected warning for resolver error")
|
||||
}
|
||||
|
||||
func TestPreLLMHook_MessageJSON_FallbackPath(t *testing.T) {
|
||||
// Messages where Message ([]byte) is nil but MessageJSON is set — the fallback
|
||||
// branch in chatMessagesFromVersionMessages. This mirrors rows loaded from
|
||||
// an older DB schema before AfterFind was established.
|
||||
v := makeVersion(1, "p1", true,
|
||||
versionMsgViaJSON(schemas.ChatMessageRoleSystem, "from json field"),
|
||||
)
|
||||
prompt := makePrompt("p1", &v)
|
||||
|
||||
p := newTestPlugin(
|
||||
&staticResolver{promptID: "p1"},
|
||||
map[string]*tables.TablePrompt{"p1": &prompt},
|
||||
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
|
||||
)
|
||||
|
||||
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi")))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out.ChatRequest.Input, 2)
|
||||
|
||||
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role)
|
||||
assert.Equal(t, "from json field", msgText(out.ChatRequest.Input[0]))
|
||||
}
|
||||
|
||||
func TestPreLLMHook_ResponsesRequest(t *testing.T) {
|
||||
v := makeVersion(1, "p1", true,
|
||||
versionMsg(schemas.ChatMessageRoleSystem, "be concise"),
|
||||
)
|
||||
prompt := makePrompt("p1", &v)
|
||||
|
||||
p := newTestPlugin(
|
||||
&staticResolver{promptID: "p1"},
|
||||
map[string]*tables.TablePrompt{"p1": &prompt},
|
||||
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
|
||||
)
|
||||
|
||||
userRole := schemas.ResponsesMessageRoleType("user")
|
||||
req := &schemas.BifrostRequest{
|
||||
ResponsesRequest: &schemas.BifrostResponsesRequest{
|
||||
Input: []schemas.ResponsesMessage{{Role: &userRole}},
|
||||
},
|
||||
}
|
||||
|
||||
out, _, err := p.PreLLMHook(bfCtx(), req)
|
||||
require.NoError(t, err)
|
||||
// Template message(s) prepended before the original user input.
|
||||
assert.Greater(t, len(out.ResponsesRequest.Input), 1, "expected template prepended before user message")
|
||||
// Original user message must still be last.
|
||||
last := out.ResponsesRequest.Input[len(out.ResponsesRequest.Input)-1]
|
||||
assert.Equal(t, schemas.ResponsesMessageRoleType("user"), *last.Role)
|
||||
}
|
||||
|
||||
// TestPreLLMHook_PromptSystemMsg_PlusUserInputSystemMsg verifies that when the
|
||||
// prompt template contains a system message and the incoming request also starts
|
||||
// with a system message, both system messages are forwarded to the model —
|
||||
// the plugin's only job is prepending, not de-duplicating or filtering roles.
|
||||
func TestPreLLMHook_PromptSystemMsg_PlusUserInputSystemMsg(t *testing.T) {
|
||||
v := makeVersion(1, "p1", true,
|
||||
versionMsg(schemas.ChatMessageRoleSystem, "prompt system"),
|
||||
)
|
||||
prompt := makePrompt("p1", &v)
|
||||
|
||||
p := newTestPlugin(
|
||||
&staticResolver{promptID: "p1"},
|
||||
map[string]*tables.TablePrompt{"p1": &prompt},
|
||||
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
|
||||
)
|
||||
|
||||
// Incoming request already has its own system message before the user turn.
|
||||
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(
|
||||
systemMsg("user-side system context"),
|
||||
userMsg("actual question"),
|
||||
))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out.ChatRequest.Input, 3, "expected prompt system + user system + user message")
|
||||
|
||||
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role)
|
||||
assert.Equal(t, "prompt system", msgText(out.ChatRequest.Input[0]))
|
||||
|
||||
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[1].Role)
|
||||
assert.Equal(t, "user-side system context", msgText(out.ChatRequest.Input[1]))
|
||||
|
||||
assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[2].Role)
|
||||
assert.Equal(t, "actual question", msgText(out.ChatRequest.Input[2]))
|
||||
}
|
||||
|
||||
// TestPreLLMHook_PromptWithToolCallMessages_PlusUserMessage verifies that when
|
||||
// the prompt template contains a full tool-call turn (system → assistant with
|
||||
// tool_calls → tool result) and the user sends a new message, the entire
|
||||
// template is prepended and all fields (ToolCalls, ToolCallID) are preserved.
|
||||
func TestPreLLMHook_PromptWithToolCallMessages_PlusUserMessage(t *testing.T) {
|
||||
const callID = "call_abc123"
|
||||
v := makeVersion(1, "p1", true,
|
||||
versionMsg(schemas.ChatMessageRoleSystem, "you are a weather bot"),
|
||||
versionMsgWithToolCall(callID, "get_weather", `{"city":"Paris"}`),
|
||||
versionMsgToolResult(callID, "Sunny, 22°C"),
|
||||
)
|
||||
prompt := makePrompt("p1", &v)
|
||||
|
||||
p := newTestPlugin(
|
||||
&staticResolver{promptID: "p1"},
|
||||
map[string]*tables.TablePrompt{"p1": &prompt},
|
||||
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
|
||||
)
|
||||
|
||||
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("what about tomorrow?")))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out.ChatRequest.Input, 4, "expected system + assistant(tool_calls) + tool_result + user")
|
||||
|
||||
// System message from prompt.
|
||||
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role)
|
||||
assert.Equal(t, "you are a weather bot", msgText(out.ChatRequest.Input[0]))
|
||||
|
||||
// Assistant message with tool_calls must carry its ToolCalls slice.
|
||||
assistantMsg := out.ChatRequest.Input[1]
|
||||
assert.Equal(t, schemas.ChatMessageRoleAssistant, assistantMsg.Role)
|
||||
require.NotNil(t, assistantMsg.ChatAssistantMessage, "ChatAssistantMessage must be present")
|
||||
require.Len(t, assistantMsg.ChatAssistantMessage.ToolCalls, 1)
|
||||
tc := assistantMsg.ChatAssistantMessage.ToolCalls[0]
|
||||
require.NotNil(t, tc.ID)
|
||||
assert.Equal(t, callID, *tc.ID)
|
||||
require.NotNil(t, tc.Function.Name)
|
||||
assert.Equal(t, "get_weather", *tc.Function.Name)
|
||||
assert.Equal(t, `{"city":"Paris"}`, tc.Function.Arguments)
|
||||
|
||||
// Tool result message must carry the ToolCallID.
|
||||
toolResultMsg := out.ChatRequest.Input[2]
|
||||
assert.Equal(t, schemas.ChatMessageRoleTool, toolResultMsg.Role)
|
||||
assert.Equal(t, "Sunny, 22°C", msgText(toolResultMsg))
|
||||
require.NotNil(t, toolResultMsg.ChatToolMessage, "ChatToolMessage must be present")
|
||||
require.NotNil(t, toolResultMsg.ChatToolMessage.ToolCallID)
|
||||
assert.Equal(t, callID, *toolResultMsg.ChatToolMessage.ToolCallID)
|
||||
|
||||
// Original user message is last.
|
||||
assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[3].Role)
|
||||
assert.Equal(t, "what about tomorrow?", msgText(out.ChatRequest.Input[3]))
|
||||
}
|
||||
|
||||
// TestPreLLMHook_MultipleSystemMessages_InPromptTemplate verifies that a prompt
|
||||
// template may itself contain multiple system messages and all of them are
|
||||
// prepended before the user's input in the original order.
|
||||
func TestPreLLMHook_MultipleSystemMessages_InPromptTemplate(t *testing.T) {
|
||||
v := makeVersion(1, "p1", true,
|
||||
versionMsg(schemas.ChatMessageRoleSystem, "first system"),
|
||||
versionMsg(schemas.ChatMessageRoleSystem, "second system"),
|
||||
)
|
||||
prompt := makePrompt("p1", &v)
|
||||
|
||||
p := newTestPlugin(
|
||||
&staticResolver{promptID: "p1"},
|
||||
map[string]*tables.TablePrompt{"p1": &prompt},
|
||||
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
|
||||
)
|
||||
|
||||
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello")))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out.ChatRequest.Input, 3, "expected 2 system messages + user message")
|
||||
|
||||
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role)
|
||||
assert.Equal(t, "first system", msgText(out.ChatRequest.Input[0]))
|
||||
|
||||
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[1].Role)
|
||||
assert.Equal(t, "second system", msgText(out.ChatRequest.Input[1]))
|
||||
|
||||
assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[2].Role)
|
||||
assert.Equal(t, "hello", msgText(out.ChatRequest.Input[2]))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// HTTPTransportPreHook
|
||||
// ============================================================
|
||||
|
||||
func TestHTTPTransportPreHook_NilRequest(t *testing.T) {
|
||||
p := newTestPlugin(nil, nil, nil)
|
||||
resp, err := p.HTTPTransportPreHook(bfCtx(), nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, resp)
|
||||
}
|
||||
|
||||
func TestHTTPTransportPreHook_SetsPromptID(t *testing.T) {
|
||||
p := newTestPlugin(nil, nil, nil)
|
||||
ctx := bfCtx()
|
||||
req := &schemas.HTTPRequest{
|
||||
Headers: map[string]string{PromptIDHeader: "my-prompt"},
|
||||
}
|
||||
|
||||
_, _ = p.HTTPTransportPreHook(ctx, req)
|
||||
|
||||
got, _ := ctx.Value(PromptIDKey).(string)
|
||||
assert.Equal(t, "my-prompt", got)
|
||||
}
|
||||
|
||||
func TestHTTPTransportPreHook_SetsVersionID(t *testing.T) {
|
||||
p := newTestPlugin(nil, nil, nil)
|
||||
ctx := bfCtx()
|
||||
req := &schemas.HTTPRequest{
|
||||
Headers: map[string]string{PromptVersionHeader: "42"},
|
||||
}
|
||||
|
||||
_, _ = p.HTTPTransportPreHook(ctx, req)
|
||||
|
||||
got, _ := ctx.Value(PromptVersionKey).(string)
|
||||
assert.Equal(t, "42", got)
|
||||
}
|
||||
|
||||
func TestHTTPTransportPreHook_TrimsWhitespace(t *testing.T) {
|
||||
p := newTestPlugin(nil, nil, nil)
|
||||
ctx := bfCtx()
|
||||
req := &schemas.HTTPRequest{
|
||||
Headers: map[string]string{PromptIDHeader: " padded "},
|
||||
}
|
||||
|
||||
_, _ = p.HTTPTransportPreHook(ctx, req)
|
||||
|
||||
got, _ := ctx.Value(PromptIDKey).(string)
|
||||
assert.Equal(t, "padded", got)
|
||||
}
|
||||
|
||||
func TestHTTPTransportPreHook_WhitespaceOnlyNotSet(t *testing.T) {
|
||||
p := newTestPlugin(nil, nil, nil)
|
||||
ctx := bfCtx()
|
||||
req := &schemas.HTTPRequest{
|
||||
Headers: map[string]string{PromptIDHeader: " "},
|
||||
}
|
||||
|
||||
_, _ = p.HTTPTransportPreHook(ctx, req)
|
||||
|
||||
assert.Nil(t, ctx.Value(PromptIDKey), "whitespace-only header must not be stored in context")
|
||||
}
|
||||
|
||||
func TestHTTPTransportPreHook_CaseInsensitiveHeaders(t *testing.T) {
|
||||
p := newTestPlugin(nil, nil, nil)
|
||||
ctx := bfCtx()
|
||||
// "X-Bf-Prompt-Id" is a title-case variant of the canonical "x-bf-prompt-id".
|
||||
req := &schemas.HTTPRequest{
|
||||
Headers: map[string]string{"X-Bf-Prompt-Id": "upper-case"},
|
||||
}
|
||||
|
||||
_, _ = p.HTTPTransportPreHook(ctx, req)
|
||||
|
||||
got, _ := ctx.Value(PromptIDKey).(string)
|
||||
assert.Equal(t, "upper-case", got)
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// chatMessageFromStoredJSON
|
||||
// ============================================================
|
||||
|
||||
func TestChatMessageFromStoredJSON(t *testing.T) {
|
||||
systemText := "you are helpful"
|
||||
directMsg := schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleSystem,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: &systemText},
|
||||
}
|
||||
directJSON, _ := json.Marshal(directMsg)
|
||||
envelopeJSON := []byte(`{"payload":` + string(directJSON) + `}`)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
wantErr bool
|
||||
wantRole schemas.ChatMessageRole
|
||||
wantText string
|
||||
}{
|
||||
{
|
||||
name: "direct format",
|
||||
input: directJSON,
|
||||
wantRole: schemas.ChatMessageRoleSystem,
|
||||
wantText: systemText,
|
||||
},
|
||||
{
|
||||
name: "envelope format",
|
||||
input: envelopeJSON,
|
||||
wantRole: schemas.ChatMessageRoleSystem,
|
||||
wantText: systemText,
|
||||
},
|
||||
{
|
||||
// UI format for assistant messages: originalType=completion_result,
|
||||
// payload is a BifrostChatResponse; message lives at choices[0].message.
|
||||
name: "completion_result envelope (UI assistant format)",
|
||||
input: []byte(`{"originalType":"completion_result","payload":{"id":"r1","choices":[{"index":0,"message":{"role":"assistant","content":"hi there"},"finish_reason":"stop"}]}}`),
|
||||
wantRole: schemas.ChatMessageRoleAssistant,
|
||||
wantText: "hi there",
|
||||
},
|
||||
{
|
||||
// completion_result with no choices falls through to direct ChatMessage parse.
|
||||
name: "completion_result envelope with empty choices",
|
||||
input: []byte(`{"originalType":"completion_result","payload":{"id":"r1","choices":[]}}`),
|
||||
wantErr: false,
|
||||
wantRole: "",
|
||||
wantText: "",
|
||||
},
|
||||
{
|
||||
name: "empty bytes",
|
||||
input: []byte(""),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "null bytes",
|
||||
input: []byte("null"),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "whitespace only",
|
||||
input: []byte(" "),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "malformed envelope payload",
|
||||
input: []byte(`{"payload":"not-a-chat-message"}`),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
// {"payload":null} — envelope path is skipped (payload is "null"),
|
||||
// falls through to direct decode of the outer object as ChatMessage.
|
||||
// schemas.Unmarshal succeeds on an unknown-field object → empty ChatMessage, no error.
|
||||
name: "envelope with null payload falls through to direct decode",
|
||||
input: []byte(`{"payload":null}`),
|
||||
wantErr: false,
|
||||
wantRole: "",
|
||||
wantText: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := convertVersionMessagesToChatMessages(tt.input)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantRole, got.Role)
|
||||
assert.Equal(t, tt.wantText, msgText(got))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatMessageFromStoredJSON_EnvelopeWithEmptyStringPayload(t *testing.T) {
|
||||
// {"payload":""} — the payload field is a non-null, non-empty JSON string `""`.
|
||||
// The envelope path attempts to unmarshal `""` (a JSON string literal) into
|
||||
// schemas.ChatMessage (a struct), which fails. The error is returned directly;
|
||||
// there is no further fallback.
|
||||
input := []byte(`{"payload":""}`)
|
||||
_, err := convertVersionMessagesToChatMessages(input)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "decoding prompt message envelope payload")
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// parsePromptVersionNumber
|
||||
// ============================================================
|
||||
|
||||
func TestParsePromptVersionNumber(t *testing.T) {
|
||||
type want struct {
|
||||
num int
|
||||
specified bool
|
||||
wantErr bool
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value any // stored in context; nil means don't set
|
||||
want want
|
||||
}{
|
||||
{name: "nil — not specified", value: nil, want: want{0, false, false}},
|
||||
{name: "string valid", value: "99", want: want{99, true, false}},
|
||||
{name: "string empty", value: "", want: want{0, false, false}},
|
||||
{name: "string whitespace", value: " ", want: want{0, false, false}},
|
||||
{name: "string invalid", value: "abc", want: want{0, true, true}},
|
||||
{name: "unknown type", value: struct{}{}, want: want{0, false, false}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := bfCtx()
|
||||
if tt.value != nil {
|
||||
ctx.SetValue(PromptVersionKey, tt.value)
|
||||
}
|
||||
|
||||
num, err := parseNumberFromContext(ctx, PromptVersionKey)
|
||||
|
||||
if tt.want.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.want.num, num)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// mergeChatMessages
|
||||
// ============================================================
|
||||
|
||||
func TestMergeChatMessages(t *testing.T) {
|
||||
t.Run("prepends prefix before existing messages", func(t *testing.T) {
|
||||
dest := []schemas.ChatMessage{userMsg("original")}
|
||||
prefix := []schemas.ChatMessage{systemMsg("system")}
|
||||
mergeChatMessages(&dest, prefix)
|
||||
|
||||
require.Len(t, dest, 2)
|
||||
assert.Equal(t, schemas.ChatMessageRoleSystem, dest[0].Role)
|
||||
assert.Equal(t, "system", msgText(dest[0]))
|
||||
assert.Equal(t, schemas.ChatMessageRoleUser, dest[1].Role)
|
||||
assert.Equal(t, "original", msgText(dest[1]))
|
||||
})
|
||||
|
||||
t.Run("nil dest is a no-op", func(t *testing.T) {
|
||||
// Must not panic.
|
||||
mergeChatMessages(nil, []schemas.ChatMessage{systemMsg("x")})
|
||||
})
|
||||
|
||||
t.Run("empty prefix is a no-op", func(t *testing.T) {
|
||||
dest := []schemas.ChatMessage{userMsg("only")}
|
||||
mergeChatMessages(&dest, nil)
|
||||
assert.Len(t, dest, 1)
|
||||
assert.Equal(t, "only", msgText(dest[0]))
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// chatMessagesFromVersionMessages
|
||||
// ============================================================
|
||||
|
||||
func TestChatMessagesFromVersionMessages_SingleMessage(t *testing.T) {
|
||||
msg := versionMsg(schemas.ChatMessageRoleUser, "hello")
|
||||
out, err := chatMessagesFromVersionMessages([]tables.TablePromptVersionMessage{msg})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out, 1)
|
||||
assert.Equal(t, schemas.ChatMessageRoleUser, out[0].Role)
|
||||
assert.Equal(t, "hello", msgText(out[0]))
|
||||
}
|
||||
|
||||
func TestChatMessagesFromVersionMessages_MessageJSONFallback(t *testing.T) {
|
||||
// Row has no Message bytes but has MessageJSON — exercises the fallback branch.
|
||||
msg := versionMsgViaJSON(schemas.ChatMessageRoleAssistant, "assistant reply")
|
||||
out, err := chatMessagesFromVersionMessages([]tables.TablePromptVersionMessage{msg})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out, 1)
|
||||
assert.Equal(t, schemas.ChatMessageRoleAssistant, out[0].Role)
|
||||
assert.Equal(t, "assistant reply", msgText(out[0]))
|
||||
}
|
||||
|
||||
func TestChatMessagesFromVersionMessages_PreservesOrder(t *testing.T) {
|
||||
msgs := []tables.TablePromptVersionMessage{
|
||||
versionMsg(schemas.ChatMessageRoleSystem, "first"),
|
||||
versionMsg(schemas.ChatMessageRoleUser, "second"),
|
||||
versionMsg(schemas.ChatMessageRoleAssistant, "third"),
|
||||
}
|
||||
out, err := chatMessagesFromVersionMessages(msgs)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out, 3)
|
||||
assert.Equal(t, schemas.ChatMessageRoleSystem, out[0].Role)
|
||||
assert.Equal(t, "first", msgText(out[0]))
|
||||
assert.Equal(t, schemas.ChatMessageRoleUser, out[1].Role)
|
||||
assert.Equal(t, "second", msgText(out[1]))
|
||||
assert.Equal(t, schemas.ChatMessageRoleAssistant, out[2].Role)
|
||||
assert.Equal(t, "third", msgText(out[2]))
|
||||
}
|
||||
|
||||
func TestChatMessagesFromVersionMessages_InvalidJSON(t *testing.T) {
|
||||
bad := tables.TablePromptVersionMessage{Message: []byte(`not-json`)}
|
||||
_, err := chatMessagesFromVersionMessages([]tables.TablePromptVersionMessage{bad})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// PreLLMHook — model params merge and override
|
||||
// ============================================================
|
||||
|
||||
// TestPreLLMHook_VersionParamsApplied_WhenRequestHasNoParams verifies that when
|
||||
// the request carries no Params at all, the version's ModelParams become the
|
||||
// effective parameters on the outgoing request.
|
||||
func TestPreLLMHook_VersionParamsApplied_WhenRequestHasNoParams(t *testing.T) {
|
||||
v := makeVersionWithParams(1, "p1", true,
|
||||
tables.ModelParams{"temperature": float64(0.7)},
|
||||
versionMsg(schemas.ChatMessageRoleSystem, "sys"),
|
||||
)
|
||||
prompt := makePrompt("p1", &v)
|
||||
|
||||
p := newTestPlugin(
|
||||
&staticResolver{promptID: "p1"},
|
||||
map[string]*tables.TablePrompt{"p1": &prompt},
|
||||
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
|
||||
)
|
||||
|
||||
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi")))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, out.ChatRequest.Params, "expected Params to be set from version ModelParams")
|
||||
require.NotNil(t, out.ChatRequest.Params.Temperature)
|
||||
assert.InDelta(t, 0.7, *out.ChatRequest.Params.Temperature, 0.001)
|
||||
}
|
||||
|
||||
// TestPreLLMHook_RequestParamsOverrideVersionParams verifies that when both the
|
||||
// version and the request specify the same parameter, the request value wins.
|
||||
func TestPreLLMHook_RequestParamsOverrideVersionParams(t *testing.T) {
|
||||
reqTemp := 0.9
|
||||
v := makeVersionWithParams(1, "p1", true,
|
||||
tables.ModelParams{"temperature": float64(0.3)},
|
||||
versionMsg(schemas.ChatMessageRoleSystem, "sys"),
|
||||
)
|
||||
prompt := makePrompt("p1", &v)
|
||||
|
||||
p := newTestPlugin(
|
||||
&staticResolver{promptID: "p1"},
|
||||
map[string]*tables.TablePrompt{"p1": &prompt},
|
||||
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
|
||||
)
|
||||
|
||||
req := chatRequestWithParams(&schemas.ChatParameters{Temperature: &reqTemp}, userMsg("hello"))
|
||||
out, _, err := p.PreLLMHook(bfCtx(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, out.ChatRequest.Params)
|
||||
require.NotNil(t, out.ChatRequest.Params.Temperature)
|
||||
assert.InDelta(t, reqTemp, *out.ChatRequest.Params.Temperature, 0.001,
|
||||
"request temperature must override version default temperature")
|
||||
}
|
||||
|
||||
// TestPreLLMHook_RequestParamsPartialOverride verifies the mixed case: version
|
||||
// sets temperature and max_completion_tokens; request overrides only temperature.
|
||||
// The version's max_completion_tokens must still be applied.
|
||||
func TestPreLLMHook_RequestParamsPartialOverride(t *testing.T) {
|
||||
reqTemp := 0.9
|
||||
maxTokens := 200
|
||||
v := makeVersionWithParams(1, "p1", true,
|
||||
tables.ModelParams{
|
||||
"temperature": float64(0.3),
|
||||
"max_completion_tokens": float64(maxTokens),
|
||||
},
|
||||
versionMsg(schemas.ChatMessageRoleSystem, "sys"),
|
||||
)
|
||||
prompt := makePrompt("p1", &v)
|
||||
|
||||
p := newTestPlugin(
|
||||
&staticResolver{promptID: "p1"},
|
||||
map[string]*tables.TablePrompt{"p1": &prompt},
|
||||
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
|
||||
)
|
||||
|
||||
req := chatRequestWithParams(&schemas.ChatParameters{Temperature: &reqTemp}, userMsg("hello"))
|
||||
out, _, err := p.PreLLMHook(bfCtx(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, out.ChatRequest.Params)
|
||||
require.NotNil(t, out.ChatRequest.Params.Temperature)
|
||||
assert.InDelta(t, reqTemp, *out.ChatRequest.Params.Temperature, 0.001,
|
||||
"request temperature must override version temperature")
|
||||
require.NotNil(t, out.ChatRequest.Params.MaxCompletionTokens,
|
||||
"version max_completion_tokens must be applied when request does not override it")
|
||||
assert.Equal(t, maxTokens, *out.ChatRequest.Params.MaxCompletionTokens)
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// PreLLMHook — model field preservation
|
||||
// ============================================================
|
||||
|
||||
// TestPreLLMHook_ModelInVersionParams_DoesNotOverrideRequestModel verifies that
|
||||
// a "model" key inside a version's ModelParams (which the UI may store alongside
|
||||
// temperature etc.) does NOT replace the model field on the outgoing
|
||||
// BifrostChatRequest. The model chosen by the caller must always win.
|
||||
func TestPreLLMHook_ModelInVersionParams_DoesNotOverrideRequestModel(t *testing.T) {
|
||||
v := makeVersionWithParams(1, "p1", true,
|
||||
tables.ModelParams{
|
||||
"model": "openai/gpt-4o",
|
||||
"temperature": float64(0.5),
|
||||
},
|
||||
versionMsg(schemas.ChatMessageRoleSystem, "sys"),
|
||||
)
|
||||
prompt := makePrompt("p1", &v)
|
||||
|
||||
p := newTestPlugin(
|
||||
&staticResolver{promptID: "p1"},
|
||||
map[string]*tables.TablePrompt{"p1": &prompt},
|
||||
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
|
||||
)
|
||||
|
||||
req := chatRequestWithModel("openai/gpt-3.5-turbo", userMsg("hi"))
|
||||
out, _, err := p.PreLLMHook(bfCtx(), req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/gpt-3.5-turbo", out.ChatRequest.Model,
|
||||
"request model must not be overridden by model stored in version ModelParams")
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// loadCache + PreLLMHook integration (store → cache → injection)
|
||||
// ============================================================
|
||||
|
||||
// TestLoadCacheAndPreLLMHook_EndToEnd verifies the full pipeline:
|
||||
// mockStore returns TablePrompt/TablePromptVersion structs → loadCache populates
|
||||
// the in-memory maps → PreLLMHook injects the template messages correctly.
|
||||
// This catches any mismatch between how loadCache builds the maps and how
|
||||
// PreLLMHook reads them (e.g. pointer aliasing, LatestVersion linking).
|
||||
func TestLoadCacheAndPreLLMHook_EndToEnd(t *testing.T) {
|
||||
v := makeVersion(1, "p1", true,
|
||||
versionMsg(schemas.ChatMessageRoleSystem, "end-to-end system"),
|
||||
)
|
||||
prompt := makePrompt("p1", &v)
|
||||
|
||||
ms := &mockStore{
|
||||
prompts: []tables.TablePrompt{prompt},
|
||||
versions: []tables.TablePromptVersion{v},
|
||||
}
|
||||
|
||||
p := newPluginWithStore(ms)
|
||||
require.NoError(t, p.loadCache(context.Background()))
|
||||
|
||||
p.resolver = &staticResolver{promptID: "p1"}
|
||||
|
||||
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("user msg")))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out.ChatRequest.Input, 2)
|
||||
|
||||
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role)
|
||||
assert.Equal(t, "end-to-end system", msgText(out.ChatRequest.Input[0]))
|
||||
assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[1].Role)
|
||||
assert.Equal(t, "user msg", msgText(out.ChatRequest.Input[1]))
|
||||
}
|
||||
|
||||
// TestLoadCacheAndPreLLMHook_SpecificVersion exercises the loadCache → PreLLMHook
|
||||
// path for a version lookup by ID (not just the LatestVersion pointer).
|
||||
func TestLoadCacheAndPreLLMHook_SpecificVersion(t *testing.T) {
|
||||
vOld := makeVersion(2, "p1", false,
|
||||
versionMsg(schemas.ChatMessageRoleSystem, "old via store"),
|
||||
)
|
||||
vLatest := makeVersion(3, "p1", true,
|
||||
versionMsg(schemas.ChatMessageRoleSystem, "latest via store"),
|
||||
)
|
||||
prompt := makePrompt("p1", &vLatest)
|
||||
|
||||
ms := &mockStore{
|
||||
prompts: []tables.TablePrompt{prompt},
|
||||
versions: []tables.TablePromptVersion{vOld, vLatest},
|
||||
}
|
||||
|
||||
p := newPluginWithStore(ms)
|
||||
require.NoError(t, p.loadCache(context.Background()))
|
||||
|
||||
p.resolver = &staticResolver{promptID: "p1", versionNumber: 2}
|
||||
|
||||
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("question")))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out.ChatRequest.Input, 2)
|
||||
assert.Equal(t, "old via store", msgText(out.ChatRequest.Input[0]))
|
||||
}
|
||||
|
||||
// TestPreLLMHook_AssistantMessage_UIFormat verifies that assistant messages stored
|
||||
// in the Bifrost UI's completion_result format (payload.choices[0].message) are
|
||||
// correctly extracted and prepended to the request.
|
||||
func TestPreLLMHook_AssistantMessage_UIFormat(t *testing.T) {
|
||||
v := makeVersion(1, "p1", true,
|
||||
versionMsg(schemas.ChatMessageRoleSystem, "be helpful"),
|
||||
versionMsgAssistantUIFormat("sure, how can I help?"),
|
||||
)
|
||||
prompt := makePrompt("p1", &v)
|
||||
|
||||
p := newTestPlugin(
|
||||
&staticResolver{promptID: "p1"},
|
||||
map[string]*tables.TablePrompt{"p1": &prompt},
|
||||
map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}},
|
||||
)
|
||||
|
||||
out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello")))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out.ChatRequest.Input, 3, "expected system + assistant + user")
|
||||
|
||||
assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role)
|
||||
assert.Equal(t, "be helpful", msgText(out.ChatRequest.Input[0]))
|
||||
|
||||
assert.Equal(t, schemas.ChatMessageRoleAssistant, out.ChatRequest.Input[1].Role)
|
||||
assert.Equal(t, "sure, how can I help?", msgText(out.ChatRequest.Input[1]))
|
||||
|
||||
assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[2].Role)
|
||||
assert.Equal(t, "hello", msgText(out.ChatRequest.Input[2]))
|
||||
}
|
||||
Reference in New Issue
Block a user