first commit
This commit is contained in:
644
core/mcp/codemode/starlark/executecode.go
Normal file
644
core/mcp/codemode/starlark/executecode.go
Normal file
@@ -0,0 +1,644 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package starlark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
|
||||
codemcp "github.com/maximhq/bifrost/core/mcp"
|
||||
"github.com/maximhq/bifrost/core/mcp/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"go.starlark.net/starlark"
|
||||
"go.starlark.net/starlarkstruct"
|
||||
"go.starlark.net/syntax"
|
||||
)
|
||||
|
||||
// ExecutionResult represents the result of code execution
|
||||
type ExecutionResult struct {
|
||||
Result interface{} `json:"result"`
|
||||
Logs []string `json:"logs"`
|
||||
Errors *ExecutionError `json:"errors,omitempty"`
|
||||
Environment ExecutionEnvironment `json:"environment"`
|
||||
}
|
||||
|
||||
// ExecutionErrorType represents the type of execution error
|
||||
type ExecutionErrorType string
|
||||
|
||||
const (
|
||||
ExecutionErrorTypeCompile ExecutionErrorType = "compile"
|
||||
ExecutionErrorTypeSyntax ExecutionErrorType = "syntax"
|
||||
ExecutionErrorTypeRuntime ExecutionErrorType = "runtime"
|
||||
)
|
||||
|
||||
// ExecutionError represents an error during code execution
|
||||
type ExecutionError struct {
|
||||
Kind ExecutionErrorType `json:"kind"` // "compile", "syntax", or "runtime"
|
||||
Message string `json:"message"`
|
||||
Hints []string `json:"hints"`
|
||||
}
|
||||
|
||||
// ExecutionEnvironment contains information about the execution environment
|
||||
type ExecutionEnvironment struct {
|
||||
ServerKeys []string `json:"serverKeys"`
|
||||
}
|
||||
|
||||
// createExecuteToolCodeTool creates the executeToolCode tool definition for code mode.
|
||||
// This tool allows executing Python (Starlark) code in a sandboxed interpreter with access to MCP server tools.
|
||||
func (s *StarlarkCodeMode) createExecuteToolCodeTool() schemas.ChatTool {
|
||||
executeToolCodeProps := schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("code", map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Python (Starlark) code to execute. Tool calls are synchronous: result = server.tool(param=\"value\"). " +
|
||||
"Use print() for logging. Assign to 'result' variable to return a value. " +
|
||||
"Retry after fixing syntax or logic errors, especially for read-only flows. Before rerunning code that already made tool calls, inspect prior outputs and avoid replaying stateful operations. " +
|
||||
"Example: items = server.list_items()\nfor item in items:\n print(item[\"name\"])\nresult = items",
|
||||
}),
|
||||
)
|
||||
return schemas.ChatTool{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: codemcp.ToolTypeExecuteToolCode,
|
||||
Description: schemas.Ptr(
|
||||
"Executes Python code in a sandboxed Starlark interpreter with MCP server tool access. " +
|
||||
"Servers are exposed as global objects: result = serverName.toolName(param=\"value\"). " +
|
||||
"This is the final step of the four-tool code mode workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
|
||||
"If you have not already read a tool's .pyi stub in this conversation, do that before writing code. " +
|
||||
"Do NOT guess callable tool names from natural language or stale assumptions; use the exact identifier returned by listToolFiles/readToolFile. " +
|
||||
|
||||
"STARLARK DIFFERENCES FROM PYTHON — READ BEFORE WRITING CODE: " +
|
||||
"1. NO try/except/finally/raise — error handling is not supported, and tool failures cannot be caught inside Starlark. " +
|
||||
"2. NO classes — use dicts and functions. " +
|
||||
"3. NO imports, direct network access, or direct filesystem access — use MCP tools instead. " +
|
||||
"4. NO is operator — use == for comparison. " +
|
||||
"5. NO f-strings — use % formatting: \"Hello %s, count=%d\" % (name, n). " +
|
||||
"6. Each executeToolCode call runs in a FRESH ISOLATED SCOPE — no variables, functions, or state persist between calls. Re-fetch data or store it via MCP tools (e.g., SQLite, FileSystem) if needed across calls. " +
|
||||
|
||||
"SYNTAX NOTES: " +
|
||||
"• Synchronous calls — NO async/await: result = server.tool(arg=\"value\") " +
|
||||
"• Use keyword arguments: server.tool(param=\"value\") NOT server.tool({\"param\": \"value\"}) " +
|
||||
"• Access dict values with brackets: result[\"key\"] NOT result.key " +
|
||||
"• Use print() for logging/debugging " +
|
||||
"• List comprehensions: [x for x in items if x[\"active\"]] " +
|
||||
"• String escapes work normally: \"line1\\nline2\" produces a newline " +
|
||||
"• Triple-quoted strings for multiline: \"\"\"multi\\nline\"\"\" " +
|
||||
"• chr(10) for newline character, chr(9) for tab " +
|
||||
"• To return a value, assign to 'result': result = computed_value " +
|
||||
"• MCP tool calls are timeout-limited; avoid long or infinite loops " +
|
||||
|
||||
"AVAILABLE BUILTINS: print, len, range, enumerate, zip, sorted, reversed, min, max, " +
|
||||
"int, float, str, bool, list, dict, tuple, set, hasattr, getattr, type, chr, ord, any, all, hash, repr. " +
|
||||
|
||||
"RETRY POLICY: Retry after fixing syntax or logic errors, especially for read-only flows. Before rerunning code that already made tool calls, inspect prior outputs and avoid replaying stateful operations.",
|
||||
),
|
||||
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: executeToolCodeProps,
|
||||
Required: []string{"code"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// handleExecuteToolCode handles the executeToolCode tool call.
|
||||
func (s *StarlarkCodeMode) handleExecuteToolCode(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
|
||||
toolName := "unknown"
|
||||
if toolCall.Function.Name != nil {
|
||||
toolName = *toolCall.Function.Name
|
||||
}
|
||||
s.logger.Debug("%s Handling executeToolCode tool call: %s", codemcp.CodeModeLogPrefix, toolName)
|
||||
|
||||
// Parse tool arguments
|
||||
var arguments map[string]interface{}
|
||||
if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
|
||||
s.logger.Debug("%s Failed to parse tool arguments: %v", codemcp.CodeModeLogPrefix, err)
|
||||
return nil, fmt.Errorf("failed to parse tool arguments: %v", err)
|
||||
}
|
||||
|
||||
code, ok := arguments["code"].(string)
|
||||
if !ok || code == "" {
|
||||
s.logger.Debug("%s Code parameter missing or empty", codemcp.CodeModeLogPrefix)
|
||||
return nil, fmt.Errorf("code parameter is required and must be a non-empty string")
|
||||
}
|
||||
|
||||
s.logger.Debug("%s Starting code execution", codemcp.CodeModeLogPrefix)
|
||||
result := s.executeCode(ctx, code)
|
||||
s.logger.Debug("%s Code execution completed. Success: %v, Has errors: %v, Log count: %d", codemcp.CodeModeLogPrefix, result.Errors == nil, result.Errors != nil, len(result.Logs))
|
||||
|
||||
// Format response text
|
||||
var responseText string
|
||||
var executionSuccess bool = true
|
||||
if result.Errors != nil {
|
||||
s.logger.Debug("%s Formatting error response. Error kind: %s, Message length: %d, Hints count: %d", codemcp.CodeModeLogPrefix, result.Errors.Kind, len(result.Errors.Message), len(result.Errors.Hints))
|
||||
logsText := ""
|
||||
if len(result.Logs) > 0 {
|
||||
logsText = fmt.Sprintf("\n\nPrint Output:\n%s\n", strings.Join(result.Logs, "\n"))
|
||||
}
|
||||
|
||||
responseText = fmt.Sprintf(
|
||||
"Execution %s error:\n\n%s\n\nHints:\n%s%s\n\nEnvironment:\n Available server keys: %s",
|
||||
result.Errors.Kind,
|
||||
result.Errors.Message,
|
||||
strings.Join(result.Errors.Hints, "\n"),
|
||||
logsText,
|
||||
strings.Join(result.Environment.ServerKeys, ", "),
|
||||
)
|
||||
s.logger.Debug("%s Error response formatted. Response length: %d chars", codemcp.CodeModeLogPrefix, len(responseText))
|
||||
} else {
|
||||
hasLogs := len(result.Logs) > 0
|
||||
hasResult := result.Result != nil
|
||||
s.logger.Debug("%s Formatting success response. Has logs: %v, Has result: %v", codemcp.CodeModeLogPrefix, hasLogs, hasResult)
|
||||
|
||||
if !hasLogs && !hasResult {
|
||||
executionSuccess = false
|
||||
s.logger.Debug("%s Execution completed with no data (no logs, no result), marking as failure", codemcp.CodeModeLogPrefix)
|
||||
hints := []string{
|
||||
"Add print() statements throughout your code to debug and see what's happening at each step",
|
||||
"Assign the final value to 'result' variable if you want to return it: result = computed_value",
|
||||
"Check that your tool calls are actually executing and returning data",
|
||||
}
|
||||
responseText = fmt.Sprintf(
|
||||
"Execution completed but produced no data:\n\n"+
|
||||
"The code executed without errors but returned no output (no print output and no result variable).\n\n"+
|
||||
"Hints:\n%s\n\n"+
|
||||
"Environment:\n Available server keys: %s",
|
||||
strings.Join(hints, "\n"),
|
||||
strings.Join(result.Environment.ServerKeys, ", "),
|
||||
)
|
||||
s.logger.Debug("%s No-data failure response formatted. Response length: %d chars", codemcp.CodeModeLogPrefix, len(responseText))
|
||||
} else {
|
||||
if hasLogs {
|
||||
responseText = fmt.Sprintf("Print output:\n%s\n\nExecution completed successfully.",
|
||||
strings.Join(result.Logs, "\n"))
|
||||
} else {
|
||||
responseText = "Execution completed successfully."
|
||||
}
|
||||
if hasResult {
|
||||
resultJSON, err := schemas.MarshalSortedIndent(result.Result, "", " ")
|
||||
if err == nil {
|
||||
responseText += fmt.Sprintf("\nReturn value: %s", string(resultJSON))
|
||||
s.logger.Debug("%s Added return value to response (JSON length: %d chars)", codemcp.CodeModeLogPrefix, len(resultJSON))
|
||||
} else {
|
||||
s.logger.Debug("%s Failed to marshal result to JSON: %v", codemcp.CodeModeLogPrefix, err)
|
||||
}
|
||||
}
|
||||
|
||||
responseText += fmt.Sprintf("\n\nEnvironment:\n Available server keys: %s",
|
||||
strings.Join(result.Environment.ServerKeys, ", "))
|
||||
responseText += "\nNote: This is a Starlark (Python subset) environment. Use MCP tools for external interactions."
|
||||
s.logger.Debug("%s Success response formatted. Response length: %d chars, Server keys: %v", codemcp.CodeModeLogPrefix, len(responseText), result.Environment.ServerKeys)
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Debug("%s Returning tool response message. Execution success: %v", codemcp.CodeModeLogPrefix, executionSuccess)
|
||||
return createToolResponseMessage(toolCall, responseText), nil
|
||||
}
|
||||
|
||||
// executeCode executes Python (Starlark) code in a sandboxed interpreter with MCP tool bindings.
|
||||
func (s *StarlarkCodeMode) executeCode(ctx *schemas.BifrostContext, code string) ExecutionResult {
|
||||
logs := []string{}
|
||||
|
||||
s.logger.Debug("%s Starting Starlark code execution", codemcp.CodeModeLogPrefix)
|
||||
|
||||
// Step 1: Handle empty code
|
||||
trimmedCode := strings.TrimSpace(code)
|
||||
if trimmedCode == "" {
|
||||
return ExecutionResult{
|
||||
Result: nil,
|
||||
Logs: logs,
|
||||
Errors: nil,
|
||||
Environment: ExecutionEnvironment{
|
||||
ServerKeys: []string{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Build tool bindings for all connected servers
|
||||
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
|
||||
serverKeys := make([]string, 0, len(availableToolsPerClient))
|
||||
predeclared := starlark.StringDict{}
|
||||
|
||||
// Thread-safe log appender
|
||||
appendLog := func(msg string) {
|
||||
s.logMu.Lock()
|
||||
defer s.logMu.Unlock()
|
||||
logs = append(logs, msg)
|
||||
}
|
||||
|
||||
s.logger.Debug("%s GetToolPerClient returned %d clients", codemcp.CodeModeLogPrefix, len(availableToolsPerClient))
|
||||
|
||||
for clientName, tools := range availableToolsPerClient {
|
||||
client := s.clientManager.GetClientByName(clientName)
|
||||
if client == nil {
|
||||
s.logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName)
|
||||
continue
|
||||
}
|
||||
s.logger.Debug("%s [%s] Client found. IsCodeModeClient: %v, ToolCount: %d", codemcp.CodeModeLogPrefix, clientName, client.ExecutionConfig.IsCodeModeClient, len(tools))
|
||||
if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 {
|
||||
s.logger.Debug("%s [%s] Skipped: IsCodeModeClient=%v, HasTools=%v", codemcp.CodeModeLogPrefix, clientName, client.ExecutionConfig.IsCodeModeClient, len(tools) > 0)
|
||||
continue
|
||||
}
|
||||
serverKeys = append(serverKeys, clientName)
|
||||
|
||||
// Build struct with tool methods
|
||||
structMembers := starlark.StringDict{}
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Function == nil || tool.Function.Name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
originalToolName := tool.Function.Name
|
||||
parsedToolName := getCanonicalToolName(clientName, originalToolName)
|
||||
compatibilityAlias := getCompatibilityToolAlias(clientName, originalToolName)
|
||||
|
||||
s.logger.Debug("%s [%s] Binding tool: %s -> %s", codemcp.CodeModeLogPrefix, clientName, originalToolName, parsedToolName)
|
||||
|
||||
// Capture variables for closure
|
||||
capturedToolName := originalToolName
|
||||
capturedClientName := clientName
|
||||
|
||||
// Create a Starlark builtin function for this tool
|
||||
toolFunc := starlark.NewBuiltin(parsedToolName, func(thread *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
|
||||
// Convert kwargs to Go map
|
||||
goArgs := make(map[string]interface{})
|
||||
for _, kwarg := range kwargs {
|
||||
if len(kwarg) == 2 {
|
||||
key := string(kwarg[0].(starlark.String))
|
||||
value := starlarkToGo(kwarg[1])
|
||||
goArgs[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
// Also handle positional args if there's exactly one dict argument
|
||||
if len(args) == 1 && len(kwargs) == 0 {
|
||||
if dict, ok := args[0].(*starlark.Dict); ok {
|
||||
for _, item := range dict.Items() {
|
||||
if keyStr, ok := item[0].(starlark.String); ok {
|
||||
goArgs[string(keyStr)] = starlarkToGo(item[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Call the MCP tool
|
||||
result, err := s.callMCPTool(ctx, capturedClientName, capturedToolName, goArgs, appendLog)
|
||||
if err != nil {
|
||||
return starlark.None, fmt.Errorf("tool call failed: %v", err)
|
||||
}
|
||||
|
||||
// Convert result back to Starlark
|
||||
return goToStarlark(result), nil
|
||||
})
|
||||
|
||||
structMembers[parsedToolName] = toolFunc
|
||||
|
||||
if compatibilityAlias != parsedToolName && isValidStarlarkIdentifier(compatibilityAlias) {
|
||||
if _, exists := structMembers[compatibilityAlias]; !exists {
|
||||
structMembers[compatibilityAlias] = toolFunc
|
||||
s.logger.Debug("%s [%s] Added compatibility alias: %s -> %s", codemcp.CodeModeLogPrefix, clientName, compatibilityAlias, parsedToolName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a struct for this server
|
||||
serverStruct := starlarkstruct.FromStringDict(starlark.String(clientName), structMembers)
|
||||
predeclared[clientName] = serverStruct
|
||||
s.logger.Debug("%s [%s] Added server struct with %d tools", codemcp.CodeModeLogPrefix, clientName, len(structMembers))
|
||||
}
|
||||
|
||||
if len(serverKeys) > 0 {
|
||||
s.logger.Debug("%s Bound %d servers with tools: %v", codemcp.CodeModeLogPrefix, len(serverKeys), serverKeys)
|
||||
} else {
|
||||
s.logger.Debug("%s No servers available for code mode execution", codemcp.CodeModeLogPrefix)
|
||||
}
|
||||
|
||||
// Step 3: Create Starlark thread with print function and timeout
|
||||
toolExecutionTimeout := s.getToolExecutionTimeout()
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout)
|
||||
defer cancel()
|
||||
|
||||
thread := &starlark.Thread{
|
||||
Name: "codemode",
|
||||
Print: func(_ *starlark.Thread, msg string) {
|
||||
appendLog(msg)
|
||||
},
|
||||
}
|
||||
|
||||
// Set up cancellation check — watch the context and cancel the Starlark
|
||||
// thread so that infinite loops and other long-running scripts are interrupted
|
||||
// when the execution timeout fires.
|
||||
thread.SetLocal("context", timeoutCtx)
|
||||
go func() {
|
||||
<-timeoutCtx.Done()
|
||||
thread.Cancel(timeoutCtx.Err().Error())
|
||||
}()
|
||||
|
||||
// Step 4: Configure Starlark dialect options for a Python-like experience
|
||||
starlarkOpts := &syntax.FileOptions{
|
||||
TopLevelControl: true, // allow if/for/while at top level (not just inside functions)
|
||||
While: true, // enable while loops
|
||||
Set: true, // enable set() builtin
|
||||
GlobalReassign: true, // allow reassignment to top-level names
|
||||
Recursion: true, // allow recursive functions
|
||||
}
|
||||
|
||||
// Step 5: Execute the code
|
||||
globals, err := starlark.ExecFileOptions(starlarkOpts, thread, "code.star", trimmedCode, predeclared)
|
||||
|
||||
if err != nil {
|
||||
errorMessage := err.Error()
|
||||
hints := generatePythonErrorHints(errorMessage, serverKeys)
|
||||
s.logger.Debug("%s Execution failed: %s", codemcp.CodeModeLogPrefix, errorMessage)
|
||||
|
||||
errorKind := ExecutionErrorTypeRuntime
|
||||
if strings.Contains(errorMessage, "syntax error") {
|
||||
errorKind = ExecutionErrorTypeSyntax
|
||||
}
|
||||
|
||||
return ExecutionResult{
|
||||
Result: nil,
|
||||
Logs: logs,
|
||||
Errors: &ExecutionError{
|
||||
Kind: errorKind,
|
||||
Message: errorMessage,
|
||||
Hints: hints,
|
||||
},
|
||||
Environment: ExecutionEnvironment{
|
||||
ServerKeys: serverKeys,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Step 6: Extract result from globals
|
||||
var result interface{}
|
||||
if resultVal, ok := globals["result"]; ok && resultVal != starlark.None {
|
||||
result = starlarkToGo(resultVal)
|
||||
}
|
||||
|
||||
s.logger.Debug("%s Execution completed successfully", codemcp.CodeModeLogPrefix)
|
||||
return ExecutionResult{
|
||||
Result: result,
|
||||
Logs: logs,
|
||||
Errors: nil,
|
||||
Environment: ExecutionEnvironment{
|
||||
ServerKeys: serverKeys,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// callMCPTool calls an MCP tool and returns the result.
|
||||
func (s *StarlarkCodeMode) callMCPTool(ctx *schemas.BifrostContext, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) {
|
||||
// Get available tools per client
|
||||
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
|
||||
|
||||
// Find the client by name
|
||||
tools, exists := availableToolsPerClient[clientName]
|
||||
if !exists || len(tools) == 0 {
|
||||
return nil, fmt.Errorf("client not found for server name: %s", clientName)
|
||||
}
|
||||
|
||||
// Get client using a tool from this client
|
||||
var client *schemas.MCPClientState
|
||||
for _, tool := range tools {
|
||||
if tool.Function != nil && tool.Function.Name != "" {
|
||||
client = s.clientManager.GetClientForTool(tool.Function.Name)
|
||||
if client != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if client == nil {
|
||||
return nil, fmt.Errorf("client not found for server name: %s", clientName)
|
||||
}
|
||||
|
||||
// Strip the client name prefix from tool name before calling MCP server
|
||||
originalToolName := stripClientPrefix(toolName, clientName)
|
||||
|
||||
originalRequestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string)
|
||||
if !ok {
|
||||
originalRequestID = ""
|
||||
}
|
||||
|
||||
// Generate new request ID for this nested tool call
|
||||
var newRequestID string
|
||||
if s.fetchNewRequestIDFunc != nil {
|
||||
newRequestID = s.fetchNewRequestIDFunc(ctx)
|
||||
} else {
|
||||
newRequestID = fmt.Sprintf("exec_%d_%s", time.Now().UnixNano(), toolName)
|
||||
}
|
||||
|
||||
// Create new child context
|
||||
deadline, hasDeadline := ctx.Deadline()
|
||||
if !hasDeadline {
|
||||
deadline = schemas.NoDeadline
|
||||
}
|
||||
nestedCtx := schemas.NewBifrostContext(ctx, deadline)
|
||||
nestedCtx.SetValue(schemas.BifrostContextKeyRequestID, newRequestID)
|
||||
if originalRequestID != "" {
|
||||
nestedCtx.SetValue(schemas.BifrostContextKeyParentMCPRequestID, originalRequestID)
|
||||
}
|
||||
|
||||
// Marshal arguments to JSON for the tool call
|
||||
argsJSON, err := schemas.MarshalSorted(args)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal tool arguments: %v", err)
|
||||
}
|
||||
|
||||
// Build tool call for MCP request
|
||||
toolCallReq := schemas.ChatAssistantMessageToolCall{
|
||||
ID: schemas.Ptr(newRequestID),
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: schemas.Ptr(toolName),
|
||||
Arguments: string(argsJSON),
|
||||
},
|
||||
}
|
||||
|
||||
// Create BifrostMCPRequest
|
||||
mcpRequest := &schemas.BifrostMCPRequest{
|
||||
RequestType: schemas.MCPRequestTypeChatToolCall,
|
||||
ChatAssistantMessageToolCall: &toolCallReq,
|
||||
}
|
||||
|
||||
// Check if plugin pipeline is available
|
||||
if s.pluginPipelineProvider == nil {
|
||||
// Should never happen, but just in case
|
||||
s.logger.Warn("%s Plugin pipeline provider is nil", codemcp.CodeModeLogPrefix)
|
||||
return nil, fmt.Errorf("plugin pipeline provider is nil")
|
||||
}
|
||||
|
||||
// Get plugin pipeline and run hooks
|
||||
pipeline := s.pluginPipelineProvider()
|
||||
if pipeline == nil {
|
||||
// Should never happen, but just in case
|
||||
s.logger.Warn("%s Plugin pipeline is nil", codemcp.CodeModeLogPrefix)
|
||||
return nil, fmt.Errorf("plugin pipeline is nil")
|
||||
}
|
||||
defer s.releasePluginPipeline(pipeline)
|
||||
|
||||
// Run PreMCPHooks
|
||||
preReq, shortCircuit, preCount := pipeline.RunMCPPreHooks(nestedCtx, mcpRequest)
|
||||
|
||||
// Handle short-circuit cases
|
||||
if shortCircuit != nil {
|
||||
if shortCircuit.Response != nil {
|
||||
finalResp, _ := pipeline.RunMCPPostHooks(nestedCtx, shortCircuit.Response, nil, preCount)
|
||||
if finalResp != nil {
|
||||
if finalResp.ChatMessage != nil {
|
||||
return extractResultFromChatMessage(finalResp.ChatMessage), nil
|
||||
}
|
||||
if finalResp.ResponsesMessage != nil {
|
||||
result, err := extractResultFromResponsesMessage(finalResp.ResponsesMessage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result != nil {
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("plugin short-circuit returned invalid response")
|
||||
}
|
||||
if shortCircuit.Error != nil {
|
||||
pipeline.RunMCPPostHooks(nestedCtx, nil, shortCircuit.Error, preCount)
|
||||
if shortCircuit.Error.Error != nil {
|
||||
return nil, fmt.Errorf("%s", shortCircuit.Error.Error.Message)
|
||||
}
|
||||
return nil, fmt.Errorf("plugin short-circuit error")
|
||||
}
|
||||
}
|
||||
|
||||
// If pre-hooks modified the request, extract updated args
|
||||
if preReq != nil && preReq.ChatAssistantMessageToolCall != nil {
|
||||
toolCallReq = *preReq.ChatAssistantMessageToolCall
|
||||
if toolCallReq.Function.Arguments != "" {
|
||||
if err := sonic.Unmarshal([]byte(toolCallReq.Function.Arguments), &args); err != nil {
|
||||
s.logger.Warn("%s Failed to parse modified tool arguments, using original: %v", codemcp.CodeModeLogPrefix, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Execute tool
|
||||
startTime := time.Now()
|
||||
toolNameToCall := originalToolName
|
||||
|
||||
callRequest := mcp.CallToolRequest{
|
||||
Request: mcp.Request{
|
||||
Method: string(mcp.MethodToolsCall),
|
||||
},
|
||||
Params: mcp.CallToolParams{
|
||||
Name: toolNameToCall,
|
||||
Arguments: args,
|
||||
},
|
||||
Header: utils.GetHeadersForToolExecution(nestedCtx, client),
|
||||
}
|
||||
|
||||
toolExecutionTimeout := s.getToolExecutionTimeout()
|
||||
toolCtx, cancel := context.WithTimeout(nestedCtx, toolExecutionTimeout)
|
||||
defer cancel()
|
||||
|
||||
var toolResponse *mcp.CallToolResult
|
||||
var callErr error
|
||||
|
||||
if client.ExecutionConfig.AuthType == schemas.MCPAuthTypePerUserOauth {
|
||||
accessToken, err := utils.ResolvePerUserOAuthToken(nestedCtx, client, s.oauth2Provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if client.Conn == nil {
|
||||
// Per-user OAuth with no persistent connection — use a temporary connection.
|
||||
// Assign to outer toolResponse/callErr so the shared logging + post-hooks path runs.
|
||||
toolResponse, callErr = codemcp.ExecuteToolWithUserToken(toolCtx, client.ExecutionConfig, toolNameToCall, args, accessToken, s.logger)
|
||||
if callErr != nil && toolCtx.Err() == context.DeadlineExceeded {
|
||||
callErr = fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName)
|
||||
}
|
||||
} else {
|
||||
callRequest.Header = utils.BuildPerUserOAuthHeaders(callRequest.Header, accessToken)
|
||||
toolResponse, callErr = client.Conn.CallTool(toolCtx, callRequest)
|
||||
}
|
||||
} else {
|
||||
toolResponse, callErr = client.Conn.CallTool(toolCtx, callRequest)
|
||||
}
|
||||
|
||||
latency := time.Since(startTime).Milliseconds()
|
||||
|
||||
var mcpResp *schemas.BifrostMCPResponse
|
||||
var bifrostErr *schemas.BifrostError
|
||||
|
||||
if callErr != nil {
|
||||
s.logger.Debug("%s Tool call failed: %s.%s - %v", codemcp.CodeModeLogPrefix, clientName, toolName, callErr)
|
||||
appendLog(fmt.Sprintf("[TOOL] %s.%s error: %v", clientName, toolName, callErr))
|
||||
bifrostErr = &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: fmt.Sprintf("tool call failed for %s.%s: %v", clientName, toolName, callErr),
|
||||
},
|
||||
}
|
||||
} else {
|
||||
rawResult := extractTextFromMCPResponse(toolResponse, toolName)
|
||||
|
||||
if after, ok := strings.CutPrefix(rawResult, "Error: "); ok {
|
||||
errorMsg := after
|
||||
s.logger.Debug("%s Tool returned error result: %s.%s - %s", codemcp.CodeModeLogPrefix, clientName, toolName, errorMsg)
|
||||
appendLog(fmt.Sprintf("[TOOL] %s.%s error result: %s", clientName, toolName, errorMsg))
|
||||
bifrostErr = &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: errorMsg,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
mcpResp = &schemas.BifrostMCPResponse{
|
||||
ChatMessage: createToolResponseMessage(toolCallReq, rawResult),
|
||||
ExtraFields: schemas.BifrostMCPResponseExtraFields{
|
||||
ClientName: clientName,
|
||||
ToolName: originalToolName,
|
||||
Latency: latency,
|
||||
},
|
||||
}
|
||||
|
||||
resultStr := formatResultForLog(rawResult)
|
||||
logToolName := stripClientPrefix(toolName, clientName)
|
||||
logToolName = strings.ReplaceAll(logToolName, "-", "_")
|
||||
appendLog(fmt.Sprintf("[TOOL] %s.%s raw response: %s", clientName, logToolName, resultStr))
|
||||
}
|
||||
}
|
||||
|
||||
// Run post-hooks
|
||||
finalResp, finalErr := pipeline.RunMCPPostHooks(nestedCtx, mcpResp, bifrostErr, preCount)
|
||||
|
||||
if finalErr != nil {
|
||||
if finalErr.Error != nil {
|
||||
return nil, fmt.Errorf("%s", finalErr.Error.Message)
|
||||
}
|
||||
return nil, fmt.Errorf("tool execution failed")
|
||||
}
|
||||
|
||||
if finalResp == nil {
|
||||
return nil, fmt.Errorf("plugin post-hooks returned invalid response")
|
||||
}
|
||||
|
||||
if finalResp.ChatMessage != nil {
|
||||
return extractResultFromChatMessage(finalResp.ChatMessage), nil
|
||||
}
|
||||
|
||||
if finalResp.ResponsesMessage != nil {
|
||||
result, err := extractResultFromResponsesMessage(finalResp.ResponsesMessage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result != nil {
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("plugin post-hooks returned invalid response")
|
||||
}
|
||||
301
core/mcp/codemode/starlark/getdocs.go
Normal file
301
core/mcp/codemode/starlark/getdocs.go
Normal file
@@ -0,0 +1,301 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package starlark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
codemcp "github.com/maximhq/bifrost/core/mcp"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// createGetToolDocsTool creates the getToolDocs tool definition for code mode.
|
||||
// This tool provides detailed documentation for a specific tool when the compact
|
||||
// signatures from readToolFile are not sufficient to understand how to use it.
|
||||
func (s *StarlarkCodeMode) createGetToolDocsTool() schemas.ChatTool {
|
||||
getToolDocsProps := schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("server", map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The server name (e.g., 'calculator'). Use listToolFiles to see available servers.",
|
||||
}),
|
||||
schemas.KV("tool", map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The tool name (e.g., 'add'). Use readToolFile to see available tools for a server.",
|
||||
}),
|
||||
)
|
||||
return schemas.ChatTool{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: codemcp.ToolTypeGetToolDocs,
|
||||
Description: schemas.Ptr(
|
||||
"Get detailed documentation for a specific tool including full parameter descriptions, " +
|
||||
"types, and usage examples. Use this when the compact signature from readToolFile " +
|
||||
"is not sufficient to understand how to use a tool. " +
|
||||
"Requires both server name and tool name as parameters.",
|
||||
),
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: getToolDocsProps,
|
||||
Required: []string{"server", "tool"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// handleGetToolDocs handles the getToolDocs tool call.
|
||||
func (s *StarlarkCodeMode) handleGetToolDocs(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
|
||||
// Parse tool arguments
|
||||
var arguments map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse tool arguments: %v", err)
|
||||
}
|
||||
|
||||
serverName, ok := arguments["server"].(string)
|
||||
if !ok || serverName == "" {
|
||||
return nil, fmt.Errorf("server parameter is required and must be a string")
|
||||
}
|
||||
|
||||
toolName, ok := arguments["tool"].(string)
|
||||
if !ok || toolName == "" {
|
||||
return nil, fmt.Errorf("tool parameter is required and must be a string")
|
||||
}
|
||||
|
||||
// Get available tools per client
|
||||
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
|
||||
|
||||
// Find matching client
|
||||
var matchedClientName string
|
||||
var matchedTool *schemas.ChatTool
|
||||
|
||||
serverNameLower := strings.ToLower(serverName)
|
||||
for clientName, tools := range availableToolsPerClient {
|
||||
client := s.clientManager.GetClientByName(clientName)
|
||||
if client == nil {
|
||||
s.logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName)
|
||||
continue
|
||||
}
|
||||
if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
clientNameLower := strings.ToLower(clientName)
|
||||
if clientNameLower == serverNameLower {
|
||||
matchedClientName = clientName
|
||||
|
||||
// Find the specific tool
|
||||
for i, tool := range tools {
|
||||
if tool.Function != nil {
|
||||
if matchesToolReference(toolName, clientName, tool.Function.Name) {
|
||||
matchedTool = &tools[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Handle server not found
|
||||
if matchedClientName == "" {
|
||||
var availableServers []string
|
||||
for name := range availableToolsPerClient {
|
||||
client := s.clientManager.GetClientByName(name)
|
||||
if client != nil && client.ExecutionConfig.IsCodeModeClient {
|
||||
availableServers = append(availableServers, name)
|
||||
}
|
||||
}
|
||||
errorMsg := fmt.Sprintf("Server '%s' not found. Available servers are:\n", serverName)
|
||||
for _, sn := range availableServers {
|
||||
errorMsg += fmt.Sprintf(" - %s\n", sn)
|
||||
}
|
||||
return createToolResponseMessage(toolCall, errorMsg), nil
|
||||
}
|
||||
|
||||
// Handle tool not found
|
||||
if matchedTool == nil {
|
||||
tools := availableToolsPerClient[matchedClientName]
|
||||
var availableTools []string
|
||||
for _, tool := range tools {
|
||||
if tool.Function != nil {
|
||||
availableTools = append(availableTools, getCanonicalToolName(matchedClientName, tool.Function.Name))
|
||||
}
|
||||
}
|
||||
errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. Available tools are:\n", toolName, matchedClientName)
|
||||
for _, t := range availableTools {
|
||||
errorMsg += fmt.Sprintf(" - %s\n", t)
|
||||
}
|
||||
return createToolResponseMessage(toolCall, errorMsg), nil
|
||||
}
|
||||
|
||||
// Generate detailed documentation using generateTypeDefinitions
|
||||
docContent := generateTypeDefinitions(matchedClientName, []schemas.ChatTool{*matchedTool}, true)
|
||||
|
||||
return createToolResponseMessage(toolCall, docContent), nil
|
||||
}
|
||||
|
||||
// generateTypeDefinitions generates Python documentation with docstrings from ChatTool schemas.
|
||||
func generateTypeDefinitions(clientName string, tools []schemas.ChatTool, isToolLevel bool) string {
|
||||
var sb strings.Builder
|
||||
|
||||
// Write comprehensive header
|
||||
sb.WriteString("# ============================================================================\n")
|
||||
if isToolLevel && len(tools) == 1 && tools[0].Function != nil {
|
||||
sb.WriteString(fmt.Sprintf("# Documentation for %s.%s tool\n", clientName, getCanonicalToolName(clientName, tools[0].Function.Name)))
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf("# Documentation for %s MCP server\n", clientName))
|
||||
}
|
||||
sb.WriteString("# ============================================================================\n")
|
||||
sb.WriteString("#\n")
|
||||
if isToolLevel && len(tools) == 1 {
|
||||
sb.WriteString("# This file contains Python documentation for a specific tool on this MCP server.\n")
|
||||
} else {
|
||||
sb.WriteString("# This file contains Python documentation for all tools available on this MCP server.\n")
|
||||
}
|
||||
sb.WriteString("#\n")
|
||||
sb.WriteString("# USAGE INSTRUCTIONS:\n")
|
||||
sb.WriteString(fmt.Sprintf("# Call tools using: result = %s.tool_name(param=value)\n", clientName))
|
||||
sb.WriteString("# No async/await needed - calls are synchronous.\n")
|
||||
sb.WriteString("#\n")
|
||||
sb.WriteString("# STARLARK DIFFERENCE FROM PYTHON:\n")
|
||||
sb.WriteString("# for/if/while at top level MUST be inside a function.\n")
|
||||
sb.WriteString("# Wrap loops: def main(): for x in items: ... then result = main()\n")
|
||||
sb.WriteString("#\n")
|
||||
sb.WriteString("# CRITICAL - HANDLING RESPONSES:\n")
|
||||
sb.WriteString("# Tool responses are dicts. To avoid runtime errors:\n")
|
||||
sb.WriteString("# 1. Use print(result) to inspect the response structure first\n")
|
||||
sb.WriteString("# 2. Access dict values with brackets: result[\"key\"] NOT result.key\n")
|
||||
sb.WriteString("# 3. Use .get() for safe access: result.get(\"key\", default)\n")
|
||||
sb.WriteString("#\n")
|
||||
sb.WriteString("# Common error: \"key not found\" or \"has no attribute\"\n")
|
||||
sb.WriteString("# Fix: Use print() to see actual structure, then use result[\"key\"] or .get()\n")
|
||||
sb.WriteString("# ============================================================================\n\n")
|
||||
|
||||
// Generate function definitions for each tool
|
||||
for _, tool := range tools {
|
||||
if tool.Function == nil || tool.Function.Name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
originalToolName := tool.Function.Name
|
||||
toolName := getCanonicalToolName(clientName, originalToolName)
|
||||
description := ""
|
||||
if tool.Function.Description != nil {
|
||||
description = *tool.Function.Description
|
||||
}
|
||||
|
||||
// Generate function signature
|
||||
params := formatPythonParams(tool.Function.Parameters)
|
||||
sb.WriteString(fmt.Sprintf("def %s(%s) -> dict:\n", toolName, params))
|
||||
|
||||
// Generate docstring
|
||||
sb.WriteString(" \"\"\"\n")
|
||||
if description != "" {
|
||||
sb.WriteString(fmt.Sprintf(" %s\n", description))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// Args section
|
||||
if tool.Function.Parameters != nil && tool.Function.Parameters.Properties != nil {
|
||||
props := tool.Function.Parameters.Properties
|
||||
required := make(map[string]bool)
|
||||
if tool.Function.Parameters.Required != nil {
|
||||
for _, req := range tool.Function.Parameters.Required {
|
||||
required[req] = true
|
||||
}
|
||||
}
|
||||
|
||||
if props.Len() > 0 {
|
||||
sb.WriteString(" Args:\n")
|
||||
|
||||
// Sort properties for consistent output
|
||||
propNames := make([]string, 0, props.Len())
|
||||
props.Range(func(name string, _ interface{}) bool {
|
||||
propNames = append(propNames, name)
|
||||
return true
|
||||
})
|
||||
for i := 0; i < len(propNames)-1; i++ {
|
||||
for j := i + 1; j < len(propNames); j++ {
|
||||
if propNames[i] > propNames[j] {
|
||||
propNames[i], propNames[j] = propNames[j], propNames[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, propName := range propNames {
|
||||
prop, _ := props.Get(propName)
|
||||
propMap, ok := prop.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
pyType := jsonSchemaToPython(propMap)
|
||||
propDesc := ""
|
||||
if desc, ok := propMap["description"].(string); ok && desc != "" {
|
||||
propDesc = desc
|
||||
} else {
|
||||
propDesc = fmt.Sprintf("%s parameter", propName)
|
||||
}
|
||||
|
||||
requiredNote := ""
|
||||
if required[propName] {
|
||||
requiredNote = " (required)"
|
||||
} else {
|
||||
requiredNote = " (optional)"
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf(" %s (%s): %s%s\n", propName, pyType, propDesc, requiredNote))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Returns section
|
||||
sb.WriteString(" Returns:\n")
|
||||
sb.WriteString(" dict: Response from the tool. Structure varies by tool.\n")
|
||||
sb.WriteString(" Use print(result) to inspect the actual structure.\n")
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Example section
|
||||
sb.WriteString(" Example:\n")
|
||||
sb.WriteString(fmt.Sprintf(" result = %s.%s(%s)\n", clientName, toolName, getExampleParams(tool.Function.Parameters)))
|
||||
sb.WriteString(" print(result) # Always inspect response first!\n")
|
||||
sb.WriteString(" value = result.get(\"key\", default) # Safe access\n")
|
||||
sb.WriteString(" \"\"\"\n")
|
||||
sb.WriteString(" ...\n\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// getExampleParams generates example parameter usage for a function.
|
||||
func getExampleParams(params *schemas.ToolFunctionParameters) string {
|
||||
if params == nil || params.Properties == nil || params.Properties.Len() == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
required := make(map[string]bool)
|
||||
if params.Required != nil {
|
||||
for _, req := range params.Required {
|
||||
required[req] = true
|
||||
}
|
||||
}
|
||||
|
||||
keys := params.Properties.Keys()
|
||||
|
||||
// Get first required param as example
|
||||
for _, name := range keys {
|
||||
if required[name] {
|
||||
return fmt.Sprintf("%s=\"...\"", name)
|
||||
}
|
||||
}
|
||||
|
||||
// If no required, get first param
|
||||
if len(keys) > 0 {
|
||||
return fmt.Sprintf("%s=\"...\"", keys[0])
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
23
core/mcp/codemode/starlark/init.go
Normal file
23
core/mcp/codemode/starlark/init.go
Normal file
@@ -0,0 +1,23 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package starlark
|
||||
|
||||
import "github.com/maximhq/bifrost/core/schemas"
|
||||
|
||||
// noopLogger is a no-op implementation of schemas.Logger used as a fallback
|
||||
// when no logger is provided.
|
||||
type noopLogger struct{}
|
||||
|
||||
func (noopLogger) Debug(string, ...any) {}
|
||||
func (noopLogger) Info(string, ...any) {}
|
||||
func (noopLogger) Warn(string, ...any) {}
|
||||
func (noopLogger) Error(string, ...any) {}
|
||||
func (noopLogger) Fatal(string, ...any) {}
|
||||
func (noopLogger) SetLevel(schemas.LogLevel) {}
|
||||
func (noopLogger) SetOutputType(schemas.LoggerOutputType) {}
|
||||
func (noopLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
|
||||
return schemas.NoopLogEvent
|
||||
}
|
||||
|
||||
// defaultLogger is used when nil is passed to NewStarlarkCodeMode.
|
||||
var defaultLogger schemas.Logger = noopLogger{}
|
||||
231
core/mcp/codemode/starlark/listfiles.go
Normal file
231
core/mcp/codemode/starlark/listfiles.go
Normal file
@@ -0,0 +1,231 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package starlark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
codemcp "github.com/maximhq/bifrost/core/mcp"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// createListToolFilesTool creates the listToolFiles tool definition for code mode.
|
||||
// This tool allows listing all available virtual .pyi stub files for connected MCP servers.
|
||||
// The description is dynamically generated based on the configured CodeModeBindingLevel.
|
||||
func (s *StarlarkCodeMode) createListToolFilesTool() schemas.ChatTool {
|
||||
bindingLevel := s.GetBindingLevel()
|
||||
var description string
|
||||
|
||||
if bindingLevel == schemas.CodeModeBindingLevelServer {
|
||||
description = "Returns a tree structure listing all virtual .pyi stub files available for connected MCP servers. " +
|
||||
"Each server has a corresponding file (e.g., servers/<serverName>.pyi) that contains compact Python signatures for all tools in that server. " +
|
||||
"Safe workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
|
||||
"Use readToolFile before executeToolCode to read a specific server file and confirm exact callable tool names and parameters. " +
|
||||
"Use getToolDocs if you need detailed documentation for a specific tool. " +
|
||||
"In code, access tools via: server_name.tool_name(param=value). " +
|
||||
"The server names used in code correspond to the human-readable names shown in this listing. " +
|
||||
"This tool is generic and works with any set of servers connected at runtime. " +
|
||||
"Always check this tool whenever you are unsure about what tools you have available or if you want to verify available servers and their tools."
|
||||
} else {
|
||||
description = "Returns a tree structure listing all virtual .pyi stub files available for connected MCP servers, organized by individual tool. " +
|
||||
"Each tool has a corresponding file (e.g., servers/<serverName>/<toolName>.pyi) that contains compact Python signatures for that specific tool. " +
|
||||
"The <toolName> shown in each filename is the exact canonical identifier exposed in executeToolCode. " +
|
||||
"Safe workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
|
||||
"Use readToolFile before executeToolCode to confirm the exact signature and parameters for the tool you want to call. " +
|
||||
"Use getToolDocs if you need detailed documentation for a specific tool. " +
|
||||
"In code, access tools via: server_name.tool_name(param=value). " +
|
||||
"The server names used in code correspond to the human-readable names shown in this listing. " +
|
||||
"This tool is generic and works with any set of servers connected at runtime. " +
|
||||
"Always check this tool whenever you are unsure about what tools you have available or if you want to verify available servers and their tools."
|
||||
}
|
||||
|
||||
return schemas.ChatTool{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: codemcp.ToolTypeListToolFiles,
|
||||
Description: schemas.Ptr(description),
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: schemas.NewOrderedMap(),
|
||||
Required: []string{},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// handleListToolFiles handles the listToolFiles tool call.
|
||||
// It builds a tree structure listing all virtual .pyi files available for code mode clients.
|
||||
func (s *StarlarkCodeMode) handleListToolFiles(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
|
||||
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
|
||||
|
||||
if len(availableToolsPerClient) == 0 {
|
||||
responseText := "No servers are currently connected. There are no virtual .pyi files available. " +
|
||||
"Please ensure servers are connected before using this tool."
|
||||
return createToolResponseMessage(toolCall, responseText), nil
|
||||
}
|
||||
|
||||
// Get the code mode binding level
|
||||
bindingLevel := s.GetBindingLevel()
|
||||
|
||||
// Build file list based on binding level
|
||||
var files []string
|
||||
codeModeServerCount := 0
|
||||
|
||||
for clientName, tools := range availableToolsPerClient {
|
||||
client := s.clientManager.GetClientByName(clientName)
|
||||
if client == nil {
|
||||
s.logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName)
|
||||
continue
|
||||
}
|
||||
if !client.ExecutionConfig.IsCodeModeClient {
|
||||
continue
|
||||
}
|
||||
codeModeServerCount++
|
||||
|
||||
if bindingLevel == schemas.CodeModeBindingLevelServer {
|
||||
// Server-level: one file per server
|
||||
files = append(files, fmt.Sprintf("servers/%s.pyi", clientName))
|
||||
} else {
|
||||
// Tool-level: one file per tool
|
||||
for _, tool := range tools {
|
||||
if tool.Function != nil && tool.Function.Name != "" {
|
||||
toolName := getCanonicalToolName(clientName, tool.Function.Name)
|
||||
if err := validateNormalizedToolName(toolName); err != nil {
|
||||
s.logger.Warn("%s Skipping tool '%s' from client '%s': %v", codemcp.CodeModeLogPrefix, tool.Function.Name, clientName, err)
|
||||
continue
|
||||
}
|
||||
toolFileName := fmt.Sprintf("servers/%s/%s.pyi", clientName, toolName)
|
||||
files = append(files, toolFileName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if codeModeServerCount == 0 {
|
||||
responseText := "Servers are connected but none are configured for code mode. " +
|
||||
"There are no virtual .pyi files available."
|
||||
return createToolResponseMessage(toolCall, responseText), nil
|
||||
}
|
||||
|
||||
// Build tree structure from file list
|
||||
responseText := buildListToolFilesResponse(files, bindingLevel)
|
||||
return createToolResponseMessage(toolCall, responseText), nil
|
||||
}
|
||||
|
||||
func buildListToolFilesResponse(files []string, bindingLevel schemas.CodeModeBindingLevel) string {
|
||||
tree := buildVFSTree(files)
|
||||
if tree == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
header := []string{
|
||||
"# Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode",
|
||||
}
|
||||
|
||||
if bindingLevel == schemas.CodeModeBindingLevelServer {
|
||||
header = append(header, "# Read the server .pyi file before executeToolCode to confirm exact tool names and parameters.")
|
||||
} else {
|
||||
header = append(header,
|
||||
"# Filenames below use the exact canonical tool identifiers available in executeToolCode.",
|
||||
"# Still call readToolFile before executeToolCode to confirm parameters and return shape.",
|
||||
)
|
||||
}
|
||||
|
||||
return strings.Join(append(header, "", tree), "\n")
|
||||
}
|
||||
|
||||
// VFS tree node structure for building hierarchical file structure
|
||||
type treeNode struct {
|
||||
isDirectory bool
|
||||
children map[string]*treeNode
|
||||
}
|
||||
|
||||
// buildVFSTree creates a hierarchical tree structure from a flat list of file paths.
|
||||
func buildVFSTree(files []string) string {
|
||||
if len(files) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
root := &treeNode{
|
||||
isDirectory: true,
|
||||
children: make(map[string]*treeNode),
|
||||
}
|
||||
|
||||
// Parse all files and build tree structure
|
||||
for _, file := range files {
|
||||
parts := strings.Split(file, "/")
|
||||
current := root
|
||||
|
||||
// Create all intermediate directories and final file
|
||||
for i, part := range parts {
|
||||
if _, exists := current.children[part]; !exists {
|
||||
current.children[part] = &treeNode{
|
||||
isDirectory: i < len(parts)-1, // Last part is file, not directory
|
||||
children: make(map[string]*treeNode),
|
||||
}
|
||||
}
|
||||
current = current.children[part]
|
||||
}
|
||||
}
|
||||
|
||||
// Render tree structure with proper indentation
|
||||
var lines []string
|
||||
renderTreeNode(root, "", &lines, true)
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// renderTreeNode recursively renders a tree node and its children with proper indentation.
|
||||
func renderTreeNode(node *treeNode, indent string, lines *[]string, isRoot bool) {
|
||||
// Get sorted keys for consistent output
|
||||
var keys []string
|
||||
for key := range node.children {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
// Simple bubble sort for small lists (good enough for this use case)
|
||||
for i := 0; i < len(keys); i++ {
|
||||
for j := i + 1; j < len(keys); j++ {
|
||||
if keys[j] < keys[i] {
|
||||
keys[i], keys[j] = keys[j], keys[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
child := node.children[key]
|
||||
|
||||
// Format the line
|
||||
var line string
|
||||
if isRoot {
|
||||
// Root level - no indentation
|
||||
if child.isDirectory {
|
||||
line = key + "/"
|
||||
} else {
|
||||
line = key
|
||||
}
|
||||
} else {
|
||||
// Non-root levels - add indentation
|
||||
if child.isDirectory {
|
||||
line = indent + key + "/"
|
||||
} else {
|
||||
line = indent + key
|
||||
}
|
||||
}
|
||||
|
||||
*lines = append(*lines, line)
|
||||
|
||||
// Recurse into children
|
||||
if child.isDirectory && len(child.children) > 0 {
|
||||
var nextIndent string
|
||||
if isRoot {
|
||||
nextIndent = " "
|
||||
} else {
|
||||
nextIndent = indent + " "
|
||||
}
|
||||
renderTreeNode(child, nextIndent, lines, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
443
core/mcp/codemode/starlark/readfile.go
Normal file
443
core/mcp/codemode/starlark/readfile.go
Normal file
@@ -0,0 +1,443 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package starlark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
codemcp "github.com/maximhq/bifrost/core/mcp"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// createReadToolFileTool creates the readToolFile tool definition for code mode.
|
||||
// This tool allows reading virtual .pyi stub files for specific MCP servers/tools,
|
||||
// generating Python type stubs from the server's tool schemas.
|
||||
func (s *StarlarkCodeMode) createReadToolFileTool() schemas.ChatTool {
|
||||
bindingLevel := s.GetBindingLevel()
|
||||
|
||||
var fileNameDescription, toolDescription string
|
||||
|
||||
if bindingLevel == schemas.CodeModeBindingLevelServer {
|
||||
fileNameDescription = "The virtual filename from listToolFiles in format: servers/<serverName>.pyi (e.g., 'servers/calculator.pyi')"
|
||||
toolDescription = "Reads a virtual .pyi stub file for a specific MCP server, returning compact Python function signatures " +
|
||||
"for all tools available on that server. The fileName should be in format servers/<serverName>.pyi as listed by listToolFiles. " +
|
||||
"The function performs case-insensitive matching and removes the .pyi extension. " +
|
||||
"This is the authoritative source for the exact callable tool names and parameters to use in executeToolCode. " +
|
||||
"Each tool can be accessed in code via: serverName.tool_name(param=value). " +
|
||||
"If the compact signature is not enough to understand a tool, use getToolDocs for detailed documentation. " +
|
||||
"Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
|
||||
"IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " +
|
||||
"do NOT call this tool again with startLine/endLine - you already have the complete file."
|
||||
} else {
|
||||
fileNameDescription = "The virtual filename from listToolFiles in format: servers/<serverName>/<toolName>.pyi (e.g., 'servers/calculator/add.pyi')"
|
||||
toolDescription = "Reads a virtual .pyi stub file for a specific tool, returning its compact Python function signature. " +
|
||||
"The fileName should be in format servers/<serverName>/<toolName>.pyi as listed by listToolFiles. " +
|
||||
"The function performs case-insensitive matching and removes the .pyi extension. " +
|
||||
"This is the authoritative source for the exact callable tool name and arguments to use in executeToolCode. " +
|
||||
"The tool can be accessed in code via: serverName.tool_name(param=value) using the def name shown in the file. " +
|
||||
"If the compact signature is not enough to understand the tool, use getToolDocs for detailed documentation. " +
|
||||
"Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
|
||||
"IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " +
|
||||
"do NOT call this tool again with startLine/endLine - you already have the complete file."
|
||||
}
|
||||
|
||||
readToolFileProps := schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("fileName", map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": fileNameDescription,
|
||||
}),
|
||||
schemas.KV("startLine", map[string]interface{}{
|
||||
"type": "number",
|
||||
"description": "Optional 1-based starting line number for partial file read. Usually not needed - omit to read the entire file. Files are typically small (under 50 lines).",
|
||||
}),
|
||||
schemas.KV("endLine", map[string]interface{}{
|
||||
"type": "number",
|
||||
"description": "Optional 1-based ending line number for partial file read. Usually not needed - omit to read the entire file. Will be clamped to actual file size if too large.",
|
||||
}),
|
||||
)
|
||||
return schemas.ChatTool{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: codemcp.ToolTypeReadToolFile,
|
||||
Description: schemas.Ptr(toolDescription),
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: readToolFileProps,
|
||||
Required: []string{"fileName"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// handleReadToolFile handles the readToolFile tool call.
|
||||
func (s *StarlarkCodeMode) handleReadToolFile(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
|
||||
// Parse tool arguments
|
||||
var arguments map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse tool arguments: %v", err)
|
||||
}
|
||||
|
||||
fileName, ok := arguments["fileName"].(string)
|
||||
if !ok || fileName == "" {
|
||||
return nil, fmt.Errorf("fileName parameter is required and must be a string")
|
||||
}
|
||||
|
||||
// Parse the file path to extract server name and optional tool name
|
||||
serverName, toolName, isToolLevel := parseVFSFilePath(fileName)
|
||||
|
||||
// Get available tools per client
|
||||
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
|
||||
|
||||
// Find matching client
|
||||
var matchedClientName string
|
||||
var matchedTools []schemas.ChatTool
|
||||
matchCount := 0
|
||||
|
||||
for clientName, tools := range availableToolsPerClient {
|
||||
client := s.clientManager.GetClientByName(clientName)
|
||||
if client == nil {
|
||||
s.logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName)
|
||||
continue
|
||||
}
|
||||
if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
clientNameLower := strings.ToLower(clientName)
|
||||
serverNameLower := strings.ToLower(serverName)
|
||||
|
||||
if clientNameLower == serverNameLower {
|
||||
matchCount++
|
||||
if matchCount > 1 {
|
||||
// Multiple matches found
|
||||
errorMsg := fmt.Sprintf("Multiple servers match filename '%s':\n", fileName)
|
||||
for name := range availableToolsPerClient {
|
||||
if strings.ToLower(name) == serverNameLower {
|
||||
errorMsg += fmt.Sprintf(" - %s\n", name)
|
||||
}
|
||||
}
|
||||
errorMsg += "\nPlease use a more specific filename. Use the exact display name from listToolFiles to avoid ambiguity."
|
||||
return createToolResponseMessage(toolCall, errorMsg), nil
|
||||
}
|
||||
|
||||
matchedClientName = clientName
|
||||
|
||||
if isToolLevel {
|
||||
// Tool-level: filter to specific tool
|
||||
var foundTool *schemas.ChatTool
|
||||
for i, tool := range tools {
|
||||
if tool.Function != nil {
|
||||
if matchesToolReference(toolName, clientName, tool.Function.Name) {
|
||||
foundTool = &tools[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if foundTool == nil {
|
||||
availableTools := make([]string, 0)
|
||||
for _, tool := range tools {
|
||||
if tool.Function != nil {
|
||||
availableTools = append(availableTools, getCanonicalToolName(clientName, tool.Function.Name))
|
||||
}
|
||||
}
|
||||
errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. Available tools in this server are:\n", toolName, clientName)
|
||||
for _, t := range availableTools {
|
||||
errorMsg += fmt.Sprintf(" - servers/%s/%s.pyi\n", clientName, t)
|
||||
}
|
||||
return createToolResponseMessage(toolCall, errorMsg), nil
|
||||
}
|
||||
|
||||
matchedTools = []schemas.ChatTool{*foundTool}
|
||||
} else {
|
||||
// Server-level: use all tools
|
||||
matchedTools = tools
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if matchedClientName == "" {
|
||||
// Build helpful error message with available files
|
||||
bindingLevel := s.GetBindingLevel()
|
||||
var availableFiles []string
|
||||
|
||||
for name := range availableToolsPerClient {
|
||||
if bindingLevel == schemas.CodeModeBindingLevelServer {
|
||||
availableFiles = append(availableFiles, fmt.Sprintf("servers/%s.pyi", name))
|
||||
} else {
|
||||
client := s.clientManager.GetClientByName(name)
|
||||
if client != nil && client.ExecutionConfig.IsCodeModeClient {
|
||||
if tools, ok := availableToolsPerClient[name]; ok {
|
||||
for _, tool := range tools {
|
||||
if tool.Function != nil {
|
||||
availableFiles = append(availableFiles, fmt.Sprintf("servers/%s/%s.pyi", name, getCanonicalToolName(name, tool.Function.Name)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
errorMsg := fmt.Sprintf("No server found matching '%s'. Available virtual files are:\n", serverName)
|
||||
for _, f := range availableFiles {
|
||||
errorMsg += fmt.Sprintf(" - %s\n", f)
|
||||
}
|
||||
return createToolResponseMessage(toolCall, errorMsg), nil
|
||||
}
|
||||
|
||||
// Generate compact Python signatures
|
||||
fileContent := generateCompactSignatures(matchedClientName, matchedTools, isToolLevel)
|
||||
lines := strings.Split(fileContent, "\n")
|
||||
totalLines := len(lines)
|
||||
|
||||
// Prepend total lines info so LLM knows the file size upfront
|
||||
fileContent = fmt.Sprintf("# Total lines: %d (this is the complete file, no need to paginate)\n%s", totalLines+1, fileContent)
|
||||
// Recalculate lines after prepending
|
||||
lines = strings.Split(fileContent, "\n")
|
||||
totalLines = len(lines)
|
||||
|
||||
// Handle line slicing if provided
|
||||
var startLine, endLine *int
|
||||
if sl, ok := arguments["startLine"].(float64); ok {
|
||||
slInt := int(sl)
|
||||
startLine = &slInt
|
||||
}
|
||||
if el, ok := arguments["endLine"].(float64); ok {
|
||||
elInt := int(el)
|
||||
endLine = &elInt
|
||||
}
|
||||
|
||||
if startLine != nil || endLine != nil {
|
||||
start := 1
|
||||
if startLine != nil {
|
||||
start = *startLine
|
||||
}
|
||||
end := totalLines
|
||||
if endLine != nil {
|
||||
end = *endLine
|
||||
}
|
||||
|
||||
// Clamp values to valid range instead of erroring
|
||||
// This handles cases where LLM requests more lines than exist
|
||||
if start < 1 {
|
||||
start = 1
|
||||
}
|
||||
if start > totalLines {
|
||||
start = totalLines
|
||||
}
|
||||
if end < 1 {
|
||||
end = 1
|
||||
}
|
||||
if end > totalLines {
|
||||
end = totalLines
|
||||
}
|
||||
if start > end {
|
||||
// If start > end after clamping, just return the start line
|
||||
end = start
|
||||
}
|
||||
|
||||
// Slice lines (convert to 0-based indexing)
|
||||
selectedLines := lines[start-1 : end]
|
||||
fileContent = strings.Join(selectedLines, "\n")
|
||||
}
|
||||
|
||||
return createToolResponseMessage(toolCall, fileContent), nil
|
||||
}
|
||||
|
||||
// parseVFSFilePath parses a VFS file path and extracts the server name and optional tool name.
|
||||
func parseVFSFilePath(fileName string) (serverName, toolName string, isToolLevel bool) {
|
||||
// Remove .pyi extension
|
||||
basePath := strings.TrimSuffix(fileName, ".pyi")
|
||||
|
||||
// Remove "servers/" prefix if present
|
||||
basePath = strings.TrimPrefix(basePath, "servers/")
|
||||
|
||||
// Defensive validation: reject paths with path traversal attempts
|
||||
if strings.Contains(basePath, "..") {
|
||||
// Return empty to indicate invalid path
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// Check for path separator
|
||||
parts := strings.Split(basePath, "/")
|
||||
if len(parts) == 2 {
|
||||
// Tool-level: "serverName/toolName"
|
||||
// Validate that tool name doesn't contain additional path separators or traversal
|
||||
if parts[1] == "" || strings.Contains(parts[1], "/") || strings.Contains(parts[1], "..") {
|
||||
// Invalid tool name, treat as server-level
|
||||
return parts[0], "", false
|
||||
}
|
||||
return parts[0], parts[1], true
|
||||
}
|
||||
// Server-level: "serverName"
|
||||
// Validate server name doesn't contain path separators or traversal
|
||||
if strings.Contains(basePath, "/") || strings.Contains(basePath, "..") {
|
||||
// Invalid path
|
||||
return "", "", false
|
||||
}
|
||||
return basePath, "", false
|
||||
}
|
||||
|
||||
// generateCompactSignatures generates compact Python function signatures for tools.
|
||||
func generateCompactSignatures(clientName string, tools []schemas.ChatTool, isToolLevel bool) string {
|
||||
var sb strings.Builder
|
||||
|
||||
// Minimal header
|
||||
if isToolLevel && len(tools) == 1 && tools[0].Function != nil {
|
||||
toolName := getCanonicalToolName(clientName, tools[0].Function.Name)
|
||||
sb.WriteString(fmt.Sprintf("# %s.%s tool\n", clientName, toolName))
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf("# %s server tools\n", clientName))
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("# Usage: %s.tool_name(param=value)\n", clientName))
|
||||
sb.WriteString("# The def names below are the exact callable names to use in executeToolCode.\n")
|
||||
sb.WriteString("# Read this file before executeToolCode to confirm parameters and return shape.\n")
|
||||
sb.WriteString(fmt.Sprintf("# For detailed docs: use getToolDocs(server=\"%s\", tool=\"tool_name\")\n", clientName))
|
||||
sb.WriteString("# Note: Descriptions may be truncated. Use getToolDocs for full details.\n\n")
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Function == nil || tool.Function.Name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
toolName := getCanonicalToolName(clientName, tool.Function.Name)
|
||||
|
||||
// Format inline parameters in Python style
|
||||
params := formatPythonParams(tool.Function.Parameters)
|
||||
|
||||
// Get description (truncate if too long)
|
||||
desc := ""
|
||||
if tool.Function.Description != nil && *tool.Function.Description != "" {
|
||||
desc = *tool.Function.Description
|
||||
// Truncate long descriptions to first sentence or 80 chars
|
||||
if idx := strings.Index(desc, ". "); idx > 0 && idx < 80 {
|
||||
desc = desc[:idx+1]
|
||||
} else if len(desc) > 80 {
|
||||
desc = desc[:77] + "..."
|
||||
}
|
||||
}
|
||||
|
||||
// Write Python signature: def tool_name(param: type, param: type = None) -> dict: # description
|
||||
if desc != "" {
|
||||
sb.WriteString(fmt.Sprintf("def %s(%s) -> dict: # %s\n", toolName, params, desc))
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf("def %s(%s) -> dict\n", toolName, params))
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// formatPythonParams formats tool parameters as Python function parameters.
|
||||
func formatPythonParams(params *schemas.ToolFunctionParameters) string {
|
||||
if params == nil || params.Properties == nil || params.Properties.Len() == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
props := params.Properties
|
||||
required := make(map[string]bool)
|
||||
if params.Required != nil {
|
||||
for _, req := range params.Required {
|
||||
required[req] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Sort properties: required first, then optional, alphabetically within each group
|
||||
requiredNames := make([]string, 0)
|
||||
optionalNames := make([]string, 0)
|
||||
props.Range(func(name string, _ interface{}) bool {
|
||||
if required[name] {
|
||||
requiredNames = append(requiredNames, name)
|
||||
} else {
|
||||
optionalNames = append(optionalNames, name)
|
||||
}
|
||||
return true
|
||||
})
|
||||
// Simple alphabetical sort for each group
|
||||
for i := 0; i < len(requiredNames)-1; i++ {
|
||||
for j := i + 1; j < len(requiredNames); j++ {
|
||||
if requiredNames[i] > requiredNames[j] {
|
||||
requiredNames[i], requiredNames[j] = requiredNames[j], requiredNames[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(optionalNames)-1; i++ {
|
||||
for j := i + 1; j < len(optionalNames); j++ {
|
||||
if optionalNames[i] > optionalNames[j] {
|
||||
optionalNames[i], optionalNames[j] = optionalNames[j], optionalNames[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
parts := make([]string, 0, props.Len())
|
||||
|
||||
// Add required params first
|
||||
for _, propName := range requiredNames {
|
||||
prop, _ := props.Get(propName)
|
||||
propMap, ok := prop.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
pyType := jsonSchemaToPython(propMap)
|
||||
parts = append(parts, fmt.Sprintf("%s: %s", propName, pyType))
|
||||
}
|
||||
|
||||
// Add optional params with default None
|
||||
for _, propName := range optionalNames {
|
||||
prop, _ := props.Get(propName)
|
||||
propMap, ok := prop.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
pyType := jsonSchemaToPython(propMap)
|
||||
parts = append(parts, fmt.Sprintf("%s: %s = None", propName, pyType))
|
||||
}
|
||||
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
// jsonSchemaToPython converts a JSON Schema type definition to a Python type string.
|
||||
func jsonSchemaToPython(prop map[string]interface{}) string {
|
||||
// Check for enum first - takes precedence over type to show allowed values
|
||||
if enum, ok := prop["enum"].([]interface{}); ok && len(enum) > 0 {
|
||||
enumStrs := make([]string, 0, len(enum))
|
||||
for _, e := range enum {
|
||||
enumStrs = append(enumStrs, fmt.Sprintf("%q", e))
|
||||
}
|
||||
return "Literal[" + strings.Join(enumStrs, ", ") + "]"
|
||||
}
|
||||
|
||||
// Check for const (single fixed value)
|
||||
if constVal, ok := prop["const"]; ok {
|
||||
return fmt.Sprintf("Literal[%q]", constVal)
|
||||
}
|
||||
|
||||
// Fall back to type-based conversion
|
||||
if typeVal, ok := prop["type"].(string); ok {
|
||||
switch typeVal {
|
||||
case "string":
|
||||
return "str"
|
||||
case "number":
|
||||
return "float"
|
||||
case "integer":
|
||||
return "int"
|
||||
case "boolean":
|
||||
return "bool"
|
||||
case "array":
|
||||
itemsType := "Any"
|
||||
if items, ok := prop["items"].(map[string]interface{}); ok {
|
||||
itemsType = jsonSchemaToPython(items)
|
||||
}
|
||||
return fmt.Sprintf("list[%s]", itemsType)
|
||||
case "object":
|
||||
return "dict"
|
||||
case "null":
|
||||
return "None"
|
||||
}
|
||||
}
|
||||
|
||||
return "Any"
|
||||
}
|
||||
175
core/mcp/codemode/starlark/starlark.go
Normal file
175
core/mcp/codemode/starlark/starlark.go
Normal file
@@ -0,0 +1,175 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
// Package starlark provides a Starlark-based implementation of the CodeMode interface.
|
||||
// Starlark is a Python-like language designed for configuration and embedded scripting.
|
||||
// See https://github.com/google/starlark-go for more information.
|
||||
package starlark
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/mcp"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// StarlarkCodeMode implements the CodeMode interface using a Starlark interpreter.
|
||||
// It provides a sandboxed Python-like execution environment with access to MCP tools.
|
||||
type StarlarkCodeMode struct {
|
||||
// Configuration (atomic for thread-safe updates)
|
||||
bindingLevel atomic.Value // schemas.CodeModeBindingLevel
|
||||
toolExecutionTimeout atomic.Value // time.Duration
|
||||
|
||||
// Dependencies
|
||||
clientManager mcp.ClientManager
|
||||
pluginPipelineProvider func() mcp.PluginPipeline
|
||||
releasePluginPipeline func(pipeline mcp.PluginPipeline)
|
||||
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string
|
||||
oauth2Provider schemas.OAuth2Provider
|
||||
|
||||
// Logger for this instance
|
||||
logger schemas.Logger
|
||||
|
||||
// Mutex for protecting logs during concurrent execution
|
||||
logMu sync.Mutex
|
||||
}
|
||||
|
||||
// NewStarlarkCodeMode creates a new Starlark-based CodeMode implementation.
|
||||
//
|
||||
// Parameters:
|
||||
// - config: Configuration for the code mode (binding level, timeouts). Can be nil for defaults.
|
||||
// - logger: Logger instance for this code mode. Can be nil.
|
||||
//
|
||||
// Returns:
|
||||
// - *StarlarkCodeMode: A new Starlark code mode instance
|
||||
//
|
||||
// Note: Dependencies must be set via SetDependencies before the CodeMode can execute tools.
|
||||
// This allows the CodeMode to be created before the MCPManager, avoiding circular dependencies.
|
||||
func NewStarlarkCodeMode(config *mcp.CodeModeConfig, logger schemas.Logger) *StarlarkCodeMode {
|
||||
if config == nil {
|
||||
config = mcp.DefaultCodeModeConfig()
|
||||
}
|
||||
|
||||
if config.BindingLevel == "" {
|
||||
config.BindingLevel = schemas.CodeModeBindingLevelServer
|
||||
}
|
||||
|
||||
if config.ToolExecutionTimeout <= 0 {
|
||||
config.ToolExecutionTimeout = schemas.DefaultToolExecutionTimeout
|
||||
}
|
||||
|
||||
if logger == nil {
|
||||
logger = defaultLogger
|
||||
}
|
||||
|
||||
s := &StarlarkCodeMode{
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Initialize atomic values
|
||||
s.bindingLevel.Store(config.BindingLevel)
|
||||
s.toolExecutionTimeout.Store(config.ToolExecutionTimeout)
|
||||
|
||||
s.logger.Info("%s Starlark code mode initialized with binding level: %s, timeout: %v",
|
||||
mcp.CodeModeLogPrefix, config.BindingLevel, config.ToolExecutionTimeout)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// SetDependencies sets the dependencies required for code execution.
|
||||
// This must be called after the MCPManager is created, as the dependencies
|
||||
// include the ClientManager (which is the MCPManager itself).
|
||||
func (s *StarlarkCodeMode) SetDependencies(deps *mcp.CodeModeDependencies) {
|
||||
if deps != nil {
|
||||
s.clientManager = deps.ClientManager
|
||||
s.pluginPipelineProvider = deps.PluginPipelineProvider
|
||||
s.releasePluginPipeline = deps.ReleasePluginPipeline
|
||||
s.fetchNewRequestIDFunc = deps.FetchNewRequestIDFunc
|
||||
s.oauth2Provider = deps.OAuth2Provider
|
||||
}
|
||||
}
|
||||
|
||||
// GetTools returns the code mode meta-tools for Starlark execution.
|
||||
// These tools allow LLMs to discover, read, and execute code against MCP servers.
|
||||
func (s *StarlarkCodeMode) GetTools() []schemas.ChatTool {
|
||||
return []schemas.ChatTool{
|
||||
s.createListToolFilesTool(),
|
||||
s.createReadToolFileTool(),
|
||||
s.createGetToolDocsTool(),
|
||||
s.createExecuteToolCodeTool(),
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteTool handles a code mode tool call.
|
||||
// It dispatches to the appropriate handler based on the tool name.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for tool execution
|
||||
// - toolCall: The tool call to execute
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.ChatMessage: The tool response message
|
||||
// - error: Any error that occurred during execution
|
||||
func (s *StarlarkCodeMode) ExecuteTool(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
|
||||
if toolCall.Function.Name == nil {
|
||||
return nil, fmt.Errorf("tool call missing function name")
|
||||
}
|
||||
|
||||
toolName := *toolCall.Function.Name
|
||||
|
||||
switch toolName {
|
||||
case mcp.ToolTypeListToolFiles:
|
||||
return s.handleListToolFiles(ctx, toolCall)
|
||||
case mcp.ToolTypeReadToolFile:
|
||||
return s.handleReadToolFile(ctx, toolCall)
|
||||
case mcp.ToolTypeGetToolDocs:
|
||||
return s.handleGetToolDocs(ctx, toolCall)
|
||||
case mcp.ToolTypeExecuteToolCode:
|
||||
return s.handleExecuteToolCode(ctx, toolCall)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown code mode tool: %s", toolName)
|
||||
}
|
||||
}
|
||||
|
||||
// IsCodeModeTool returns true if the given tool name is a code mode tool.
|
||||
func (s *StarlarkCodeMode) IsCodeModeTool(toolName string) bool {
|
||||
return mcp.IsCodeModeTool(toolName)
|
||||
}
|
||||
|
||||
// GetBindingLevel returns the current code mode binding level.
|
||||
func (s *StarlarkCodeMode) GetBindingLevel() schemas.CodeModeBindingLevel {
|
||||
val := s.bindingLevel.Load()
|
||||
if val == nil {
|
||||
return schemas.CodeModeBindingLevelServer
|
||||
}
|
||||
return val.(schemas.CodeModeBindingLevel)
|
||||
}
|
||||
|
||||
// UpdateConfig updates the code mode configuration atomically.
|
||||
func (s *StarlarkCodeMode) UpdateConfig(config *mcp.CodeModeConfig) {
|
||||
if config == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if config.BindingLevel != "" {
|
||||
s.bindingLevel.Store(config.BindingLevel)
|
||||
}
|
||||
|
||||
if config.ToolExecutionTimeout > 0 {
|
||||
s.toolExecutionTimeout.Store(config.ToolExecutionTimeout)
|
||||
}
|
||||
|
||||
s.logger.Info("%s Starlark code mode configuration updated: binding level=%s, timeout=%v",
|
||||
mcp.CodeModeLogPrefix, config.BindingLevel, config.ToolExecutionTimeout)
|
||||
}
|
||||
|
||||
// getToolExecutionTimeout returns the current tool execution timeout.
|
||||
func (s *StarlarkCodeMode) getToolExecutionTimeout() time.Duration {
|
||||
val := s.toolExecutionTimeout.Load()
|
||||
if val == nil {
|
||||
return schemas.DefaultToolExecutionTimeout
|
||||
}
|
||||
return val.(time.Duration)
|
||||
}
|
||||
999
core/mcp/codemode/starlark/starlark_test.go
Normal file
999
core/mcp/codemode/starlark/starlark_test.go
Normal file
@@ -0,0 +1,999 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package starlark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
codemcp "github.com/maximhq/bifrost/core/mcp"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"go.starlark.net/starlark"
|
||||
"go.starlark.net/syntax"
|
||||
)
|
||||
|
||||
type testClientManager struct {
|
||||
clients map[string]*schemas.MCPClientState
|
||||
tools map[string][]schemas.ChatTool
|
||||
}
|
||||
|
||||
func (m *testClientManager) GetClientForTool(toolName string) *schemas.MCPClientState {
|
||||
for clientName, tools := range m.tools {
|
||||
for _, tool := range tools {
|
||||
if tool.Function != nil && tool.Function.Name == toolName {
|
||||
return m.clients[clientName]
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testClientManager) GetClientByName(clientName string) *schemas.MCPClientState {
|
||||
return m.clients[clientName]
|
||||
}
|
||||
|
||||
func (m *testClientManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool {
|
||||
return m.tools
|
||||
}
|
||||
|
||||
func TestStarlarkToGo(t *testing.T) {
|
||||
t.Run("Convert None", func(t *testing.T) {
|
||||
result := starlarkToGo(starlark.None)
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert Bool", func(t *testing.T) {
|
||||
result := starlarkToGo(starlark.Bool(true))
|
||||
if result != true {
|
||||
t.Errorf("Expected true, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert Int", func(t *testing.T) {
|
||||
result := starlarkToGo(starlark.MakeInt(42))
|
||||
if result != int64(42) {
|
||||
t.Errorf("Expected 42, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert Float", func(t *testing.T) {
|
||||
result := starlarkToGo(starlark.Float(3.14))
|
||||
if result != 3.14 {
|
||||
t.Errorf("Expected 3.14, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert String", func(t *testing.T) {
|
||||
result := starlarkToGo(starlark.String("hello"))
|
||||
if result != "hello" {
|
||||
t.Errorf("Expected 'hello', got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert List", func(t *testing.T) {
|
||||
list := starlark.NewList([]starlark.Value{
|
||||
starlark.MakeInt(1),
|
||||
starlark.MakeInt(2),
|
||||
starlark.MakeInt(3),
|
||||
})
|
||||
result := starlarkToGo(list)
|
||||
arr, ok := result.([]interface{})
|
||||
if !ok {
|
||||
t.Errorf("Expected []interface{}, got %T", result)
|
||||
}
|
||||
if len(arr) != 3 {
|
||||
t.Errorf("Expected length 3, got %d", len(arr))
|
||||
}
|
||||
if arr[0] != int64(1) {
|
||||
t.Errorf("Expected first element 1, got %v", arr[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert Dict", func(t *testing.T) {
|
||||
dict := starlark.NewDict(2)
|
||||
dict.SetKey(starlark.String("key1"), starlark.String("value1"))
|
||||
dict.SetKey(starlark.String("key2"), starlark.MakeInt(42))
|
||||
|
||||
result := starlarkToGo(dict)
|
||||
m, ok := result.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Errorf("Expected map[string]interface{}, got %T", result)
|
||||
}
|
||||
if m["key1"] != "value1" {
|
||||
t.Errorf("Expected key1='value1', got %v", m["key1"])
|
||||
}
|
||||
if m["key2"] != int64(42) {
|
||||
t.Errorf("Expected key2=42, got %v", m["key2"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGoToStarlark(t *testing.T) {
|
||||
t.Run("Convert nil", func(t *testing.T) {
|
||||
result := goToStarlark(nil)
|
||||
if result != starlark.None {
|
||||
t.Errorf("Expected None, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert bool", func(t *testing.T) {
|
||||
result := goToStarlark(true)
|
||||
if result != starlark.Bool(true) {
|
||||
t.Errorf("Expected True, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert int", func(t *testing.T) {
|
||||
result := goToStarlark(42)
|
||||
expected := starlark.MakeInt(42)
|
||||
if result.String() != expected.String() {
|
||||
t.Errorf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert float64", func(t *testing.T) {
|
||||
result := goToStarlark(3.14)
|
||||
if result != starlark.Float(3.14) {
|
||||
t.Errorf("Expected 3.14, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert string", func(t *testing.T) {
|
||||
result := goToStarlark("hello")
|
||||
if result != starlark.String("hello") {
|
||||
t.Errorf("Expected 'hello', got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert slice", func(t *testing.T) {
|
||||
result := goToStarlark([]interface{}{1, "two", 3.0})
|
||||
list, ok := result.(*starlark.List)
|
||||
if !ok {
|
||||
t.Errorf("Expected *starlark.List, got %T", result)
|
||||
}
|
||||
if list.Len() != 3 {
|
||||
t.Errorf("Expected length 3, got %d", list.Len())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert map", func(t *testing.T) {
|
||||
result := goToStarlark(map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": 42,
|
||||
})
|
||||
dict, ok := result.(*starlark.Dict)
|
||||
if !ok {
|
||||
t.Errorf("Expected *starlark.Dict, got %T", result)
|
||||
}
|
||||
val, found, _ := dict.Get(starlark.String("key1"))
|
||||
if !found {
|
||||
t.Errorf("Expected key1 to exist")
|
||||
}
|
||||
if val != starlark.String("value1") {
|
||||
t.Errorf("Expected value1, got %v", val)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetCanonicalToolName(t *testing.T) {
|
||||
if got := getCanonicalToolName("github", "github-SEARCH_REPOS"); got != "search_repos" {
|
||||
t.Fatalf("expected canonical tool name search_repos, got %q", got)
|
||||
}
|
||||
|
||||
if got := getCanonicalToolName("math", "math-123Add!"); got != "_123add" {
|
||||
t.Fatalf("expected canonical tool name _123add, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchesToolReferenceSupportsCanonicalAndLegacyNames(t *testing.T) {
|
||||
clientName := "github"
|
||||
originalToolName := "github-SEARCH_REPOS"
|
||||
|
||||
testCases := []string{
|
||||
"search_repos",
|
||||
"SEARCH_REPOS",
|
||||
}
|
||||
|
||||
for _, toolRef := range testCases {
|
||||
if !matchesToolReference(toolRef, clientName, originalToolName) {
|
||||
t.Fatalf("expected %q to match %q", toolRef, originalToolName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleListToolFilesUsesCanonicalToolIdentifiers(t *testing.T) {
|
||||
mode := NewStarlarkCodeMode(&codemcp.CodeModeConfig{
|
||||
BindingLevel: schemas.CodeModeBindingLevelTool,
|
||||
ToolExecutionTimeout: time.Second,
|
||||
}, nil)
|
||||
|
||||
clientName := "github"
|
||||
mode.clientManager = &testClientManager{
|
||||
clients: map[string]*schemas.MCPClientState{
|
||||
clientName: {
|
||||
Name: clientName,
|
||||
ExecutionConfig: &schemas.MCPClientConfig{
|
||||
Name: clientName,
|
||||
IsCodeModeClient: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
tools: map[string][]schemas.ChatTool{
|
||||
clientName: {
|
||||
{
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "github-SEARCH_REPOS",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msg, err := mode.handleListToolFiles(context.Background(), schemas.ChatAssistantMessageToolCall{
|
||||
ID: schemas.Ptr("tool-call-1"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("handleListToolFiles returned error: %v", err)
|
||||
}
|
||||
|
||||
if msg == nil || msg.Content == nil || msg.Content.ContentStr == nil {
|
||||
t.Fatal("expected tool response content")
|
||||
}
|
||||
|
||||
content := *msg.Content.ContentStr
|
||||
if !strings.Contains(content, "search_repos.pyi") {
|
||||
t.Fatalf("expected canonical tool file path in response, got:\n%s", content)
|
||||
}
|
||||
if strings.Contains(content, "SEARCH_REPOS.pyi") {
|
||||
t.Fatalf("did not expect raw uppercase tool file path in response, got:\n%s", content)
|
||||
}
|
||||
if !strings.Contains(content, "readToolFile before executeToolCode") {
|
||||
t.Fatalf("expected workflow guidance in response, got:\n%s", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeneratePythonErrorHints(t *testing.T) {
|
||||
serverKeys := []string{"calculator", "weather"}
|
||||
|
||||
t.Run("Undefined variable hint", func(t *testing.T) {
|
||||
hints := generatePythonErrorHints("name 'foo' is not defined", serverKeys)
|
||||
if len(hints) == 0 {
|
||||
t.Error("Expected hints, got none")
|
||||
}
|
||||
found := false
|
||||
for _, hint := range hints {
|
||||
if strings.Contains(hint, "Variable 'foo' is not defined.") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected exact undefined variable hint for foo, got: %v", hints)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Syntax error hint", func(t *testing.T) {
|
||||
hints := generatePythonErrorHints("syntax error at line 5", serverKeys)
|
||||
if len(hints) == 0 {
|
||||
t.Error("Expected hints, got none")
|
||||
}
|
||||
found := false
|
||||
for _, hint := range hints {
|
||||
if containsAny(hint, "syntax", "indentation", "colon") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected hint about syntax error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Attribute error hint", func(t *testing.T) {
|
||||
hints := generatePythonErrorHints("'dict' object has no attribute 'foo'", serverKeys)
|
||||
if len(hints) == 0 {
|
||||
t.Error("Expected hints, got none")
|
||||
}
|
||||
found := false
|
||||
for _, hint := range hints {
|
||||
if containsAny(hint, "attribute", "brackets", "key") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected hint about attribute access")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func containsAny(s string, substrs ...string) bool {
|
||||
for _, sub := range substrs {
|
||||
if containsIgnoreCase(s, sub) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsIgnoreCase(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && (containsIgnoreCase(s[1:], substr) || (len(s) >= len(substr) && equalFold(s[:len(substr)], substr))))
|
||||
}
|
||||
|
||||
func equalFold(a, b string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(a); i++ {
|
||||
ca, cb := a[i], b[i]
|
||||
if ca >= 'A' && ca <= 'Z' {
|
||||
ca += 'a' - 'A'
|
||||
}
|
||||
if cb >= 'A' && cb <= 'Z' {
|
||||
cb += 'a' - 'A'
|
||||
}
|
||||
if ca != cb {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestExtractResultFromResponsesMessage(t *testing.T) {
|
||||
t.Run("Extract error from ResponsesMessage", func(t *testing.T) {
|
||||
errorMsg := "Tool is not allowed by security policy: dangerous_tool"
|
||||
msg := &schemas.ResponsesMessage{
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
Error: &errorMsg,
|
||||
},
|
||||
}
|
||||
|
||||
result, err := extractResultFromResponsesMessage(msg)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error, got nil")
|
||||
}
|
||||
if err.Error() != errorMsg {
|
||||
t.Errorf("Expected error message '%s', got '%s'", errorMsg, err.Error())
|
||||
}
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result when error is present, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Extract string output from ResponsesMessage", func(t *testing.T) {
|
||||
outputStr := "success result"
|
||||
msg := &schemas.ResponsesMessage{
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
Output: &schemas.ResponsesToolMessageOutputStruct{
|
||||
ResponsesToolCallOutputStr: &outputStr,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := extractResultFromResponsesMessage(msg)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if result != outputStr {
|
||||
t.Errorf("Expected result '%s', got '%v'", outputStr, result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Extract JSON output from ResponsesMessage", func(t *testing.T) {
|
||||
outputStr := `{"status": "success", "data": "test"}`
|
||||
msg := &schemas.ResponsesMessage{
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
Output: &schemas.ResponsesToolMessageOutputStruct{
|
||||
ResponsesToolCallOutputStr: &outputStr,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := extractResultFromResponsesMessage(msg)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
resultMap, ok := result.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Errorf("Expected map, got %T", result)
|
||||
}
|
||||
|
||||
if resultMap["status"] != "success" {
|
||||
t.Errorf("Expected status 'success', got '%v'", resultMap["status"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Extract from ResponsesFunctionToolCallOutputBlocks", func(t *testing.T) {
|
||||
text1 := "First block"
|
||||
text2 := "Second block"
|
||||
msg := &schemas.ResponsesMessage{
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
Output: &schemas.ResponsesToolMessageOutputStruct{
|
||||
ResponsesFunctionToolCallOutputBlocks: []schemas.ResponsesMessageContentBlock{
|
||||
{Text: &text1},
|
||||
{Text: &text2},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := extractResultFromResponsesMessage(msg)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expectedResult := "First block\nSecond block"
|
||||
if result != expectedResult {
|
||||
t.Errorf("Expected result '%s', got '%v'", expectedResult, result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Extract JSON from ResponsesFunctionToolCallOutputBlocks", func(t *testing.T) {
|
||||
jsonText := `{"key": "value"}`
|
||||
msg := &schemas.ResponsesMessage{
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
Output: &schemas.ResponsesToolMessageOutputStruct{
|
||||
ResponsesFunctionToolCallOutputBlocks: []schemas.ResponsesMessageContentBlock{
|
||||
{Text: &jsonText},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := extractResultFromResponsesMessage(msg)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
resultMap, ok := result.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Errorf("Expected map, got %T", result)
|
||||
}
|
||||
|
||||
if resultMap["key"] != "value" {
|
||||
t.Errorf("Expected key 'value', got '%v'", resultMap["key"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Handle nil message", func(t *testing.T) {
|
||||
result, err := extractResultFromResponsesMessage(nil)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for nil message, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Handle message without ResponsesToolMessage", func(t *testing.T) {
|
||||
msg := &schemas.ResponsesMessage{}
|
||||
|
||||
result, err := extractResultFromResponsesMessage(msg)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for message without tool message, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Handle empty error string (should not error)", func(t *testing.T) {
|
||||
emptyError := ""
|
||||
msg := &schemas.ResponsesMessage{
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
Error: &emptyError,
|
||||
},
|
||||
}
|
||||
|
||||
result, err := extractResultFromResponsesMessage(msg)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for empty error string, got: %v", err)
|
||||
}
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for empty error string, got %v", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractResultFromChatMessage(t *testing.T) {
|
||||
t.Run("Extract string from ChatMessage", func(t *testing.T) {
|
||||
content := "test result"
|
||||
msg := &schemas.ChatMessage{
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: &content,
|
||||
},
|
||||
}
|
||||
|
||||
result := extractResultFromChatMessage(msg)
|
||||
if result != content {
|
||||
t.Errorf("Expected result '%s', got '%v'", content, result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Extract JSON from ChatMessage", func(t *testing.T) {
|
||||
content := `{"status": "ok"}`
|
||||
msg := &schemas.ChatMessage{
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: &content,
|
||||
},
|
||||
}
|
||||
|
||||
result := extractResultFromChatMessage(msg)
|
||||
resultMap, ok := result.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Errorf("Expected map, got %T", result)
|
||||
}
|
||||
|
||||
if resultMap["status"] != "ok" {
|
||||
t.Errorf("Expected status 'ok', got '%v'", resultMap["status"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Handle nil ChatMessage", func(t *testing.T) {
|
||||
result := extractResultFromChatMessage(nil)
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for nil message, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Handle ChatMessage without Content", func(t *testing.T) {
|
||||
msg := &schemas.ChatMessage{}
|
||||
result := extractResultFromChatMessage(msg)
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for message without content, got %v", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFormatResultForLog(t *testing.T) {
|
||||
t.Run("Format nil result", func(t *testing.T) {
|
||||
result := formatResultForLog(nil)
|
||||
if result != "null" {
|
||||
t.Errorf("Expected 'null', got '%s'", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Format string result", func(t *testing.T) {
|
||||
result := formatResultForLog("test string")
|
||||
if result != `"test string"` {
|
||||
t.Errorf("Expected '\"test string\"', got '%s'", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Format map result", func(t *testing.T) {
|
||||
input := map[string]interface{}{"key": "value"}
|
||||
result := formatResultForLog(input)
|
||||
|
||||
// Parse it back to verify it's valid JSON
|
||||
var parsed map[string]interface{}
|
||||
err := sonic.Unmarshal([]byte(result), &parsed)
|
||||
if err != nil {
|
||||
t.Errorf("Result is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if parsed["key"] != "value" {
|
||||
t.Errorf("Expected key 'value', got '%v'", parsed["key"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Truncate long result", func(t *testing.T) {
|
||||
longString := ""
|
||||
for i := 0; i < 300; i++ {
|
||||
longString += "a"
|
||||
}
|
||||
|
||||
result := formatResultForLog(longString)
|
||||
if len(result) > 200 {
|
||||
// Should be truncated to around 200 chars (plus quotes and ellipsis)
|
||||
t.Logf("Result length: %d (truncated as expected)", len(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// starlarkOpts returns the FileOptions used by the code mode executor.
|
||||
// Kept in sync with executecode.go to test the same dialect configuration.
|
||||
func starlarkOpts() *syntax.FileOptions {
|
||||
return &syntax.FileOptions{
|
||||
TopLevelControl: true,
|
||||
While: true,
|
||||
Set: true,
|
||||
GlobalReassign: true,
|
||||
Recursion: true,
|
||||
}
|
||||
}
|
||||
|
||||
// execStarlark is a test helper that executes Starlark code with our dialect options
|
||||
// and returns the globals and any error.
|
||||
func execStarlark(code string) (starlark.StringDict, error) {
|
||||
thread := &starlark.Thread{Name: "test"}
|
||||
return starlark.ExecFileOptions(starlarkOpts(), thread, "test.star", code, nil)
|
||||
}
|
||||
|
||||
func TestStarlarkDialectOptions(t *testing.T) {
|
||||
t.Run("Top-level for loop", func(t *testing.T) {
|
||||
code := `
|
||||
items = []
|
||||
for i in range(3):
|
||||
items.append(i)
|
||||
result = items
|
||||
`
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("Top-level for loop should work, got error: %v", err)
|
||||
}
|
||||
resultVal := globals["result"]
|
||||
list, ok := resultVal.(*starlark.List)
|
||||
if !ok {
|
||||
t.Fatalf("Expected list, got %T", resultVal)
|
||||
}
|
||||
if list.Len() != 3 {
|
||||
t.Errorf("Expected 3 items, got %d", list.Len())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Top-level if statement", func(t *testing.T) {
|
||||
code := `
|
||||
x = 10
|
||||
if x > 5:
|
||||
result = "big"
|
||||
else:
|
||||
result = "small"
|
||||
`
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("Top-level if should work, got error: %v", err)
|
||||
}
|
||||
if globals["result"] != starlark.String("big") {
|
||||
t.Errorf("Expected 'big', got %v", globals["result"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Top-level while loop", func(t *testing.T) {
|
||||
code := `
|
||||
count = 0
|
||||
while count < 5:
|
||||
count += 1
|
||||
result = count
|
||||
`
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("Top-level while loop should work, got error: %v", err)
|
||||
}
|
||||
resultVal := globals["result"]
|
||||
if resultVal.String() != "5" {
|
||||
t.Errorf("Expected 5, got %v", resultVal)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("While loop inside function", func(t *testing.T) {
|
||||
code := `
|
||||
def countdown(n):
|
||||
items = []
|
||||
while n > 0:
|
||||
items.append(n)
|
||||
n -= 1
|
||||
return items
|
||||
result = countdown(3)
|
||||
`
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("While in function should work, got error: %v", err)
|
||||
}
|
||||
list := globals["result"].(*starlark.List)
|
||||
if list.Len() != 3 {
|
||||
t.Errorf("Expected 3 items, got %d", list.Len())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("set() builtin", func(t *testing.T) {
|
||||
code := `
|
||||
s = set([1, 2, 3, 2, 1])
|
||||
result = len(s)
|
||||
`
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("set() should work, got error: %v", err)
|
||||
}
|
||||
if globals["result"].String() != "3" {
|
||||
t.Errorf("Expected 3 unique items, got %v", globals["result"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Global variable reassignment", func(t *testing.T) {
|
||||
code := `
|
||||
x = 1
|
||||
x = x + 1
|
||||
x = x * 3
|
||||
result = x
|
||||
`
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("Global reassignment should work, got error: %v", err)
|
||||
}
|
||||
if globals["result"].String() != "6" {
|
||||
t.Errorf("Expected 6, got %v", globals["result"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Recursive function", func(t *testing.T) {
|
||||
code := `
|
||||
def factorial(n):
|
||||
if n <= 1:
|
||||
return 1
|
||||
return n * factorial(n - 1)
|
||||
result = factorial(5)
|
||||
`
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("Recursion should work, got error: %v", err)
|
||||
}
|
||||
if globals["result"].String() != "120" {
|
||||
t.Errorf("Expected 120, got %v", globals["result"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStarlarkStringEscapePreservation(t *testing.T) {
|
||||
t.Run("Backslash-n in string literal preserved", func(t *testing.T) {
|
||||
// Simulate what happens after JSON deserialization:
|
||||
// Model writes: {"code": "msg = \"hello\\nworld\""}
|
||||
// sonic.Unmarshal produces: msg = "hello\nworld" (where \n is two chars: \ + n)
|
||||
// Starlark should interpret \n as newline escape inside the string
|
||||
code := "msg = \"hello\\nworld\"\nresult = msg"
|
||||
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("String with \\n escape should work, got error: %v", err)
|
||||
}
|
||||
resultStr := string(globals["result"].(starlark.String))
|
||||
if resultStr != "hello\nworld" {
|
||||
t.Errorf("Expected 'hello<newline>world', got %q", resultStr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Multiple escape sequences in strings", func(t *testing.T) {
|
||||
code := "msg = \"col1\\tcol2\\nrow1\\trow2\"\nresult = msg"
|
||||
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("String with multiple escapes should work, got error: %v", err)
|
||||
}
|
||||
resultStr := string(globals["result"].(starlark.String))
|
||||
if resultStr != "col1\tcol2\nrow1\trow2" {
|
||||
t.Errorf("Expected tab/newline escapes, got %q", resultStr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Newline join pattern", func(t *testing.T) {
|
||||
// This is the exact pattern that failed 7 times in benchmarks
|
||||
code := `
|
||||
def main():
|
||||
lines = ["line1", "line2", "line3"]
|
||||
content = "\n".join(lines)
|
||||
return content
|
||||
result = main()
|
||||
`
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("Newline join pattern should work, got error: %v", err)
|
||||
}
|
||||
resultStr := string(globals["result"].(starlark.String))
|
||||
if resultStr != "line1\nline2\nline3" {
|
||||
t.Errorf("Expected joined lines, got %q", resultStr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("chr() for newline", func(t *testing.T) {
|
||||
code := `
|
||||
nl = chr(10)
|
||||
result = "hello" + nl + "world"
|
||||
`
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("chr(10) should work, got error: %v", err)
|
||||
}
|
||||
resultStr := string(globals["result"].(starlark.String))
|
||||
if resultStr != "hello\nworld" {
|
||||
t.Errorf("Expected 'hello<newline>world', got %q", resultStr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Triple-quoted strings", func(t *testing.T) {
|
||||
code := "result = \"\"\"line1\nline2\nline3\"\"\""
|
||||
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("Triple-quoted string should work, got error: %v", err)
|
||||
}
|
||||
resultStr := string(globals["result"].(starlark.String))
|
||||
if resultStr != "line1\nline2\nline3" {
|
||||
t.Errorf("Expected multiline string, got %q", resultStr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Raw string preserves backslash", func(t *testing.T) {
|
||||
code := "result = r\"hello\\nworld\""
|
||||
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("Raw string should work, got error: %v", err)
|
||||
}
|
||||
resultStr := string(globals["result"].(starlark.String))
|
||||
// Raw string: \n stays as two characters \ and n
|
||||
if resultStr != "hello\\nworld" {
|
||||
t.Errorf("Expected literal backslash-n, got %q", resultStr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("JSON deserialization then Starlark execution", func(t *testing.T) {
|
||||
// End-to-end: simulate the exact flow from model JSON → sonic.Unmarshal → Starlark
|
||||
jsonArgs := `{"code": "lines = [\"a\", \"b\", \"c\"]\nresult = \"\\n\".join(lines)"}`
|
||||
|
||||
var arguments map[string]interface{}
|
||||
err := sonic.Unmarshal([]byte(jsonArgs), &arguments)
|
||||
if err != nil {
|
||||
t.Fatalf("JSON unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
code := arguments["code"].(string)
|
||||
|
||||
globals, starlarkErr := execStarlark(code)
|
||||
if starlarkErr != nil {
|
||||
t.Fatalf("Starlark execution failed: %v", starlarkErr)
|
||||
}
|
||||
resultStr := string(globals["result"].(starlark.String))
|
||||
if resultStr != "a\nb\nc" {
|
||||
t.Errorf("Expected 'a<newline>b<newline>c', got %q", resultStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStarlarkUnsupportedFeatures(t *testing.T) {
|
||||
t.Run("try/except rejected", func(t *testing.T) {
|
||||
code := `
|
||||
def main():
|
||||
try:
|
||||
x = 1
|
||||
except:
|
||||
x = 0
|
||||
result = main()
|
||||
`
|
||||
_, err := execStarlark(code)
|
||||
if err == nil {
|
||||
t.Fatal("try/except should be rejected by Starlark")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "got try") {
|
||||
t.Errorf("Expected 'got try' in error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("raise rejected", func(t *testing.T) {
|
||||
code := `raise ValueError("test")`
|
||||
|
||||
_, err := execStarlark(code)
|
||||
if err == nil {
|
||||
t.Fatal("raise should be rejected by Starlark")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("class rejected", func(t *testing.T) {
|
||||
code := `
|
||||
class Foo:
|
||||
pass
|
||||
`
|
||||
_, err := execStarlark(code)
|
||||
if err == nil {
|
||||
t.Fatal("class should be rejected by Starlark")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("import rejected", func(t *testing.T) {
|
||||
code := `import json`
|
||||
|
||||
_, err := execStarlark(code)
|
||||
if err == nil {
|
||||
t.Fatal("import should be rejected by Starlark")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGeneratePythonErrorHintsNewCases(t *testing.T) {
|
||||
serverKeys := []string{"Github", "SqLite"}
|
||||
|
||||
t.Run("try/except hint", func(t *testing.T) {
|
||||
hints := generatePythonErrorHints("code.star:3:9: got try, want primary expression", serverKeys)
|
||||
if len(hints) == 0 {
|
||||
t.Fatal("Expected hints for try/except error")
|
||||
}
|
||||
found := false
|
||||
for _, hint := range hints {
|
||||
if containsAny(hint, "try/except", "exception handling") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected hint about try/except not being supported, got: %v", hints)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("except hint", func(t *testing.T) {
|
||||
hints := generatePythonErrorHints("code.star:5:9: got except, want primary expression", serverKeys)
|
||||
if len(hints) == 0 {
|
||||
t.Fatal("Expected hints for except error")
|
||||
}
|
||||
found := false
|
||||
for _, hint := range hints {
|
||||
if containsAny(hint, "try/except", "exception handling") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected hint about exception handling, got: %v", hints)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("finally hint", func(t *testing.T) {
|
||||
hints := generatePythonErrorHints("code.star:7:9: got finally, want primary expression", serverKeys)
|
||||
if len(hints) == 0 {
|
||||
t.Fatal("Expected hints for finally error")
|
||||
}
|
||||
found := false
|
||||
for _, hint := range hints {
|
||||
if containsAny(hint, "try/except", "exception handling") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected hint about exception handling, got: %v", hints)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("raise hint", func(t *testing.T) {
|
||||
hints := generatePythonErrorHints("code.star:2:1: got raise, want primary expression", serverKeys)
|
||||
if len(hints) == 0 {
|
||||
t.Fatal("Expected hints for raise error")
|
||||
}
|
||||
found := false
|
||||
for _, hint := range hints {
|
||||
if containsAny(hint, "try/except", "exception handling") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected hint about exception handling, got: %v", hints)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Undefined variable includes scope hint", func(t *testing.T) {
|
||||
hints := generatePythonErrorHints("code.star:3:17: undefined: commits_n8n", serverKeys)
|
||||
if len(hints) == 0 {
|
||||
t.Fatal("Expected hints for undefined variable")
|
||||
}
|
||||
foundVar := false
|
||||
foundScope := false
|
||||
for _, hint := range hints {
|
||||
if strings.Contains(hint, "Variable 'commits_n8n' is not defined.") {
|
||||
foundVar = true
|
||||
}
|
||||
if containsAny(hint, "fresh scope", "persist") {
|
||||
foundScope = true
|
||||
}
|
||||
}
|
||||
if !foundVar {
|
||||
t.Errorf("Expected exact undefined variable hint for commits_n8n, got: %v", hints)
|
||||
}
|
||||
if !foundScope {
|
||||
t.Errorf("Expected scope persistence hint, got: %v", hints)
|
||||
}
|
||||
})
|
||||
}
|
||||
443
core/mcp/codemode/starlark/utils.go
Normal file
443
core/mcp/codemode/starlark/utils.go
Normal file
@@ -0,0 +1,443 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package starlark
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"go.starlark.net/starlark"
|
||||
"go.starlark.net/starlarkstruct"
|
||||
)
|
||||
|
||||
// starlarkToGo converts a Starlark value to a Go value
|
||||
func starlarkToGo(v starlark.Value) interface{} {
|
||||
switch val := v.(type) {
|
||||
case starlark.NoneType:
|
||||
return nil
|
||||
case starlark.Bool:
|
||||
return bool(val)
|
||||
case starlark.Int:
|
||||
if i, ok := val.Int64(); ok {
|
||||
return i
|
||||
}
|
||||
if i, ok := val.Uint64(); ok {
|
||||
return i
|
||||
}
|
||||
return val.String()
|
||||
case starlark.Float:
|
||||
return float64(val)
|
||||
case starlark.String:
|
||||
return string(val)
|
||||
case *starlark.List:
|
||||
result := make([]interface{}, val.Len())
|
||||
for i := 0; i < val.Len(); i++ {
|
||||
result[i] = starlarkToGo(val.Index(i))
|
||||
}
|
||||
return result
|
||||
case starlark.Tuple:
|
||||
result := make([]interface{}, len(val))
|
||||
for i, item := range val {
|
||||
result[i] = starlarkToGo(item)
|
||||
}
|
||||
return result
|
||||
case *starlark.Dict:
|
||||
result := make(map[string]interface{})
|
||||
for _, item := range val.Items() {
|
||||
if keyStr, ok := item[0].(starlark.String); ok {
|
||||
result[string(keyStr)] = starlarkToGo(item[1])
|
||||
} else {
|
||||
// Use string representation for non-string keys
|
||||
result[item[0].String()] = starlarkToGo(item[1])
|
||||
}
|
||||
}
|
||||
return result
|
||||
case *starlarkstruct.Struct:
|
||||
result := make(map[string]interface{})
|
||||
for _, name := range val.AttrNames() {
|
||||
if attrVal, err := val.Attr(name); err == nil {
|
||||
result[name] = starlarkToGo(attrVal)
|
||||
}
|
||||
}
|
||||
return result
|
||||
default:
|
||||
return val.String()
|
||||
}
|
||||
}
|
||||
|
||||
// goToStarlark converts a Go value to a Starlark value
|
||||
func goToStarlark(v interface{}) starlark.Value {
|
||||
if v == nil {
|
||||
return starlark.None
|
||||
}
|
||||
|
||||
switch val := v.(type) {
|
||||
case bool:
|
||||
return starlark.Bool(val)
|
||||
case int:
|
||||
return starlark.MakeInt(val)
|
||||
case int64:
|
||||
return starlark.MakeInt64(val)
|
||||
case uint64:
|
||||
return starlark.MakeUint64(val)
|
||||
case float64:
|
||||
return starlark.Float(val)
|
||||
case string:
|
||||
return starlark.String(val)
|
||||
case []interface{}:
|
||||
items := make([]starlark.Value, len(val))
|
||||
for i, item := range val {
|
||||
items[i] = goToStarlark(item)
|
||||
}
|
||||
return starlark.NewList(items)
|
||||
case map[string]interface{}:
|
||||
dict := starlark.NewDict(len(val))
|
||||
for k, v := range val {
|
||||
dict.SetKey(starlark.String(k), goToStarlark(v))
|
||||
}
|
||||
return dict
|
||||
default:
|
||||
// Try to marshal to JSON and parse as a generic structure
|
||||
if jsonBytes, err := schemas.MarshalSorted(val); err == nil {
|
||||
var generic interface{}
|
||||
if schemas.Unmarshal(jsonBytes, &generic) == nil {
|
||||
return goToStarlark(generic)
|
||||
}
|
||||
}
|
||||
return starlark.String(fmt.Sprintf("%v", val))
|
||||
}
|
||||
}
|
||||
|
||||
// extractResultFromChatMessage extracts the result from a chat message and parses it as JSON if possible.
|
||||
func extractResultFromChatMessage(msg *schemas.ChatMessage) interface{} {
|
||||
if msg == nil || msg.Content == nil || msg.Content.ContentStr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
rawResult := *msg.Content.ContentStr
|
||||
|
||||
var finalResult interface{}
|
||||
if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil {
|
||||
return rawResult
|
||||
}
|
||||
|
||||
return finalResult
|
||||
}
|
||||
|
||||
// extractResultFromResponsesMessage extracts the result or error from a ResponsesMessage.
|
||||
func extractResultFromResponsesMessage(msg *schemas.ResponsesMessage) (interface{}, error) {
|
||||
if msg == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if msg.ResponsesToolMessage != nil {
|
||||
if msg.ResponsesToolMessage.Error != nil && *msg.ResponsesToolMessage.Error != "" {
|
||||
return nil, fmt.Errorf("%s", *msg.ResponsesToolMessage.Error)
|
||||
}
|
||||
|
||||
if msg.ResponsesToolMessage.Output != nil {
|
||||
if msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil {
|
||||
rawResult := *msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr
|
||||
|
||||
var finalResult interface{}
|
||||
if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil {
|
||||
return rawResult, nil
|
||||
}
|
||||
return finalResult, nil
|
||||
}
|
||||
|
||||
if len(msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks) > 0 {
|
||||
var textParts []string
|
||||
for _, block := range msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks {
|
||||
if block.Text != nil {
|
||||
textParts = append(textParts, *block.Text)
|
||||
}
|
||||
}
|
||||
if len(textParts) > 0 {
|
||||
result := strings.Join(textParts, "\n")
|
||||
var finalResult interface{}
|
||||
if err := sonic.Unmarshal([]byte(result), &finalResult); err != nil {
|
||||
return result, nil
|
||||
}
|
||||
return finalResult, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// formatResultForLog formats a result value for logging purposes.
|
||||
func formatResultForLog(result interface{}) string {
|
||||
var resultStr string
|
||||
if result == nil {
|
||||
resultStr = "null"
|
||||
} else if resultBytes, err := schemas.MarshalSorted(result); err == nil {
|
||||
resultStr = string(resultBytes)
|
||||
} else {
|
||||
resultStr = fmt.Sprintf("%v", result)
|
||||
}
|
||||
return resultStr
|
||||
}
|
||||
|
||||
// generatePythonErrorHints generates helpful hints for Python/Starlark errors.
|
||||
func generatePythonErrorHints(errorMessage string, serverKeys []string) []string {
|
||||
hints := []string{}
|
||||
|
||||
if strings.Contains(errorMessage, "got try") || strings.Contains(errorMessage, "got except") ||
|
||||
strings.Contains(errorMessage, "got finally") || strings.Contains(errorMessage, "got raise") {
|
||||
hints = append(hints, "Starlark does NOT support try/except/finally/raise — there is no exception handling.")
|
||||
hints = append(hints, "Instead, check return values for errors:")
|
||||
hints = append(hints, " result = server.tool(param=\"value\")")
|
||||
hints = append(hints, " if result == None or (type(result) == \"dict\" and \"error\" in result):")
|
||||
hints = append(hints, " print(\"Error:\", result)")
|
||||
} else if strings.Contains(errorMessage, "undefined") || strings.Contains(errorMessage, "not defined") {
|
||||
var undefinedVar string
|
||||
if match := regexp.MustCompile(`name ['"]([^'"]+)['"] is not defined`).FindStringSubmatch(errorMessage); len(match) > 1 {
|
||||
undefinedVar = match[1]
|
||||
} else if match := regexp.MustCompile(`undefined:\s*([A-Za-z_][A-Za-z0-9_]*)`).FindStringSubmatch(errorMessage); len(match) > 1 {
|
||||
undefinedVar = match[1]
|
||||
} else if match := regexp.MustCompile(`([A-Za-z_][A-Za-z0-9_]*)[^A-Za-z0-9_]+(?:undefined|not defined)`).FindStringSubmatch(errorMessage); len(match) > 1 {
|
||||
undefinedVar = match[1]
|
||||
}
|
||||
if undefinedVar != "" {
|
||||
hints = append(hints, fmt.Sprintf("Variable '%s' is not defined.", undefinedVar))
|
||||
hints = append(hints, "Note: Each executeToolCode call runs in a fresh scope — no variables persist between calls.")
|
||||
if len(serverKeys) > 0 {
|
||||
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
|
||||
hints = append(hints, "Access tools using: server_name.tool_name(param=\"value\")")
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(errorMessage, "not within a function") {
|
||||
hints = append(hints, "Starlark requires for/if/while statements to be inside functions at the top level.")
|
||||
hints = append(hints, "Wrap your code in a function, then call it:")
|
||||
hints = append(hints, " def fetch_all():")
|
||||
hints = append(hints, " results = []")
|
||||
hints = append(hints, " for id in ids:")
|
||||
hints = append(hints, " results.append(server.get(id=id))")
|
||||
hints = append(hints, " return results")
|
||||
hints = append(hints, " result = fetch_all()")
|
||||
} else if strings.Contains(errorMessage, "syntax error") {
|
||||
hints = append(hints, "Python syntax error detected.")
|
||||
hints = append(hints, "Check for proper indentation (use spaces, not tabs).")
|
||||
hints = append(hints, "Ensure colons after if/for/def statements.")
|
||||
hints = append(hints, "Check for matching parentheses and brackets.")
|
||||
} else if strings.Contains(errorMessage, "has no") && strings.Contains(errorMessage, "attribute") {
|
||||
hints = append(hints, "You're trying to access an attribute that doesn't exist.")
|
||||
hints = append(hints, "Use dict access syntax: result[\"key\"] instead of result.key")
|
||||
hints = append(hints, "Use print(result) to see the actual structure.")
|
||||
if len(serverKeys) > 0 {
|
||||
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
|
||||
}
|
||||
} else if strings.Contains(errorMessage, "not callable") {
|
||||
hints = append(hints, "You're trying to call something that is not a function.")
|
||||
hints = append(hints, "Ensure you're using the correct tool name.")
|
||||
if len(serverKeys) > 0 {
|
||||
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
|
||||
}
|
||||
hints = append(hints, "Use readToolFile to see available tools for a server.")
|
||||
} else if strings.Contains(errorMessage, "key") && strings.Contains(errorMessage, "not found") {
|
||||
hints = append(hints, "Dictionary key not found.")
|
||||
hints = append(hints, "Use print() to inspect the dict structure before accessing keys.")
|
||||
hints = append(hints, "Use .get(\"key\", default) for safe access.")
|
||||
} else {
|
||||
hints = append(hints, "Check the error message above for details.")
|
||||
if len(serverKeys) > 0 {
|
||||
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
|
||||
}
|
||||
hints = append(hints, "Use: result = server_name.tool_name(param=\"value\")")
|
||||
hints = append(hints, "Access dict values with brackets: result[\"key\"]")
|
||||
}
|
||||
|
||||
return hints
|
||||
}
|
||||
|
||||
// extractTextFromMCPResponse extracts text content from an MCP tool response.
|
||||
func extractTextFromMCPResponse(toolResponse *mcp.CallToolResult, toolName string) string {
|
||||
if toolResponse == nil {
|
||||
return fmt.Sprintf("MCP tool '%s' executed successfully", toolName)
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
for _, contentBlock := range toolResponse.Content {
|
||||
// Handle typed content
|
||||
switch content := contentBlock.(type) {
|
||||
case mcp.TextContent:
|
||||
result.WriteString(content.Text)
|
||||
case mcp.ImageContent:
|
||||
result.WriteString(fmt.Sprintf("[Image Response: %s, MIME: %s]\n", content.Data, content.MIMEType))
|
||||
case mcp.AudioContent:
|
||||
result.WriteString(fmt.Sprintf("[Audio Response: %s, MIME: %s]\n", content.Data, content.MIMEType))
|
||||
case mcp.EmbeddedResource:
|
||||
result.WriteString(fmt.Sprintf("[Embedded Resource Response: %s]\n", content.Type))
|
||||
default:
|
||||
// Fallback: try to extract from map structure
|
||||
if jsonBytes, err := schemas.MarshalSorted(contentBlock); err == nil {
|
||||
var contentMap map[string]interface{}
|
||||
if json.Unmarshal(jsonBytes, &contentMap) == nil {
|
||||
if text, ok := contentMap["text"].(string); ok {
|
||||
result.WriteString(fmt.Sprintf("[Text Response: %s]\n", text))
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Final fallback: serialize as JSON
|
||||
result.WriteString(string(jsonBytes))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if result.Len() > 0 {
|
||||
return strings.TrimSpace(result.String())
|
||||
}
|
||||
return fmt.Sprintf("MCP tool '%s' executed successfully", toolName)
|
||||
}
|
||||
|
||||
// createToolResponseMessage creates a tool response message with the execution result.
|
||||
func createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, responseText string) *schemas.ChatMessage {
|
||||
return &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleTool,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: &responseText,
|
||||
},
|
||||
ChatToolMessage: &schemas.ChatToolMessage{
|
||||
ToolCallID: toolCall.ID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// parseToolName normalizes a raw tool name into a Starlark-compatible identifier.
|
||||
func parseToolName(toolName string) string {
|
||||
if toolName == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
runes := []rune(toolName)
|
||||
|
||||
// Process first character - must be letter, underscore, or dollar sign
|
||||
if len(runes) > 0 {
|
||||
first := runes[0]
|
||||
if unicode.IsLetter(first) || first == '_' || first == '$' {
|
||||
result.WriteRune(unicode.ToLower(first))
|
||||
} else {
|
||||
// If first char is invalid, prefix with underscore
|
||||
result.WriteRune('_')
|
||||
if unicode.IsDigit(first) {
|
||||
result.WriteRune(first)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process remaining characters
|
||||
for i := 1; i < len(runes); i++ {
|
||||
r := runes[i]
|
||||
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' {
|
||||
result.WriteRune(unicode.ToLower(r))
|
||||
} else if unicode.IsSpace(r) || r == '-' {
|
||||
// Replace spaces and hyphens with single underscore
|
||||
// Avoid consecutive underscores
|
||||
if result.Len() > 0 && result.String()[result.Len()-1] != '_' {
|
||||
result.WriteRune('_')
|
||||
}
|
||||
}
|
||||
// Skip other invalid characters
|
||||
}
|
||||
|
||||
parsed := result.String()
|
||||
|
||||
// Remove trailing underscores
|
||||
parsed = strings.TrimRight(parsed, "_")
|
||||
|
||||
// Ensure we have at least one character
|
||||
if parsed == "" {
|
||||
return "tool"
|
||||
}
|
||||
|
||||
return parsed
|
||||
}
|
||||
|
||||
// getCanonicalToolName returns the exact callable tool identifier exposed in Starlark.
|
||||
func getCanonicalToolName(clientName, originalToolName string) string {
|
||||
return parseToolName(stripClientPrefix(originalToolName, clientName))
|
||||
}
|
||||
|
||||
// getCompatibilityToolAlias returns the case-preserving alias derived from the raw tool name.
|
||||
// This is used as a compatibility alias when the raw name is still a valid Starlark identifier.
|
||||
func getCompatibilityToolAlias(clientName, originalToolName string) string {
|
||||
return strings.ReplaceAll(stripClientPrefix(originalToolName, clientName), "-", "_")
|
||||
}
|
||||
|
||||
// matchesToolReference reports whether the requested tool name matches any supported identifier form.
|
||||
// We accept the canonical callable name plus legacy display forms for backward compatibility.
|
||||
func matchesToolReference(requestedToolName, clientName, originalToolName string) bool {
|
||||
requested := strings.ToLower(requestedToolName)
|
||||
if requested == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
candidates := []string{
|
||||
getCanonicalToolName(clientName, originalToolName),
|
||||
getCompatibilityToolAlias(clientName, originalToolName),
|
||||
stripClientPrefix(originalToolName, clientName),
|
||||
}
|
||||
|
||||
for _, candidate := range candidates {
|
||||
if candidate != "" && requested == strings.ToLower(candidate) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isValidStarlarkIdentifier reports whether name can be used directly in Starlark code.
|
||||
func isValidStarlarkIdentifier(name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
runes := []rune(name)
|
||||
first := runes[0]
|
||||
if !unicode.IsLetter(first) && first != '_' && first != '$' {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, r := range runes[1:] {
|
||||
if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' && r != '$' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// validateNormalizedToolName validates a normalized tool name to prevent path traversal.
|
||||
func validateNormalizedToolName(normalizedName string) error {
|
||||
if normalizedName == "" {
|
||||
return fmt.Errorf("tool name cannot be empty after normalization")
|
||||
}
|
||||
if strings.Contains(normalizedName, "/") {
|
||||
return fmt.Errorf("tool name cannot contain '/' (path separator) after normalization: %s", normalizedName)
|
||||
}
|
||||
if strings.Contains(normalizedName, "..") {
|
||||
return fmt.Errorf("tool name cannot contain '..' (path traversal) after normalization: %s", normalizedName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// stripClientPrefix removes the client name prefix from a tool name.
|
||||
func stripClientPrefix(prefixedToolName, clientName string) string {
|
||||
prefix := clientName + "-"
|
||||
if strings.HasPrefix(prefixedToolName, prefix) {
|
||||
return strings.TrimPrefix(prefixedToolName, prefix)
|
||||
}
|
||||
// If prefix doesn't match, return as-is (shouldn't happen, but be safe)
|
||||
return prefixedToolName
|
||||
}
|
||||
Reference in New Issue
Block a user