first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,644 @@
//go:build !tinygo && !wasm
package starlark
import (
"context"
"fmt"
"strings"
"time"
"github.com/bytedance/sonic"
"github.com/mark3labs/mcp-go/mcp"
codemcp "github.com/maximhq/bifrost/core/mcp"
"github.com/maximhq/bifrost/core/mcp/utils"
"github.com/maximhq/bifrost/core/schemas"
"go.starlark.net/starlark"
"go.starlark.net/starlarkstruct"
"go.starlark.net/syntax"
)
// ExecutionResult represents the result of code execution
type ExecutionResult struct {
Result interface{} `json:"result"`
Logs []string `json:"logs"`
Errors *ExecutionError `json:"errors,omitempty"`
Environment ExecutionEnvironment `json:"environment"`
}
// ExecutionErrorType represents the type of execution error
type ExecutionErrorType string
const (
ExecutionErrorTypeCompile ExecutionErrorType = "compile"
ExecutionErrorTypeSyntax ExecutionErrorType = "syntax"
ExecutionErrorTypeRuntime ExecutionErrorType = "runtime"
)
// ExecutionError represents an error during code execution
type ExecutionError struct {
Kind ExecutionErrorType `json:"kind"` // "compile", "syntax", or "runtime"
Message string `json:"message"`
Hints []string `json:"hints"`
}
// ExecutionEnvironment contains information about the execution environment
type ExecutionEnvironment struct {
ServerKeys []string `json:"serverKeys"`
}
// createExecuteToolCodeTool creates the executeToolCode tool definition for code mode.
// This tool allows executing Python (Starlark) code in a sandboxed interpreter with access to MCP server tools.
func (s *StarlarkCodeMode) createExecuteToolCodeTool() schemas.ChatTool {
executeToolCodeProps := schemas.NewOrderedMapFromPairs(
schemas.KV("code", map[string]interface{}{
"type": "string",
"description": "Python (Starlark) code to execute. Tool calls are synchronous: result = server.tool(param=\"value\"). " +
"Use print() for logging. Assign to 'result' variable to return a value. " +
"Retry after fixing syntax or logic errors, especially for read-only flows. Before rerunning code that already made tool calls, inspect prior outputs and avoid replaying stateful operations. " +
"Example: items = server.list_items()\nfor item in items:\n print(item[\"name\"])\nresult = items",
}),
)
return schemas.ChatTool{
Type: schemas.ChatToolTypeFunction,
Function: &schemas.ChatToolFunction{
Name: codemcp.ToolTypeExecuteToolCode,
Description: schemas.Ptr(
"Executes Python code in a sandboxed Starlark interpreter with MCP server tool access. " +
"Servers are exposed as global objects: result = serverName.toolName(param=\"value\"). " +
"This is the final step of the four-tool code mode workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
"If you have not already read a tool's .pyi stub in this conversation, do that before writing code. " +
"Do NOT guess callable tool names from natural language or stale assumptions; use the exact identifier returned by listToolFiles/readToolFile. " +
"STARLARK DIFFERENCES FROM PYTHON — READ BEFORE WRITING CODE: " +
"1. NO try/except/finally/raise — error handling is not supported, and tool failures cannot be caught inside Starlark. " +
"2. NO classes — use dicts and functions. " +
"3. NO imports, direct network access, or direct filesystem access — use MCP tools instead. " +
"4. NO is operator — use == for comparison. " +
"5. NO f-strings — use % formatting: \"Hello %s, count=%d\" % (name, n). " +
"6. Each executeToolCode call runs in a FRESH ISOLATED SCOPE — no variables, functions, or state persist between calls. Re-fetch data or store it via MCP tools (e.g., SQLite, FileSystem) if needed across calls. " +
"SYNTAX NOTES: " +
"• Synchronous calls — NO async/await: result = server.tool(arg=\"value\") " +
"• Use keyword arguments: server.tool(param=\"value\") NOT server.tool({\"param\": \"value\"}) " +
"• Access dict values with brackets: result[\"key\"] NOT result.key " +
"• Use print() for logging/debugging " +
"• List comprehensions: [x for x in items if x[\"active\"]] " +
"• String escapes work normally: \"line1\\nline2\" produces a newline " +
"• Triple-quoted strings for multiline: \"\"\"multi\\nline\"\"\" " +
"• chr(10) for newline character, chr(9) for tab " +
"• To return a value, assign to 'result': result = computed_value " +
"• MCP tool calls are timeout-limited; avoid long or infinite loops " +
"AVAILABLE BUILTINS: print, len, range, enumerate, zip, sorted, reversed, min, max, " +
"int, float, str, bool, list, dict, tuple, set, hasattr, getattr, type, chr, ord, any, all, hash, repr. " +
"RETRY POLICY: Retry after fixing syntax or logic errors, especially for read-only flows. Before rerunning code that already made tool calls, inspect prior outputs and avoid replaying stateful operations.",
),
Parameters: &schemas.ToolFunctionParameters{
Type: "object",
Properties: executeToolCodeProps,
Required: []string{"code"},
},
},
}
}
// handleExecuteToolCode handles the executeToolCode tool call.
func (s *StarlarkCodeMode) handleExecuteToolCode(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
toolName := "unknown"
if toolCall.Function.Name != nil {
toolName = *toolCall.Function.Name
}
s.logger.Debug("%s Handling executeToolCode tool call: %s", codemcp.CodeModeLogPrefix, toolName)
// Parse tool arguments
var arguments map[string]interface{}
if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
s.logger.Debug("%s Failed to parse tool arguments: %v", codemcp.CodeModeLogPrefix, err)
return nil, fmt.Errorf("failed to parse tool arguments: %v", err)
}
code, ok := arguments["code"].(string)
if !ok || code == "" {
s.logger.Debug("%s Code parameter missing or empty", codemcp.CodeModeLogPrefix)
return nil, fmt.Errorf("code parameter is required and must be a non-empty string")
}
s.logger.Debug("%s Starting code execution", codemcp.CodeModeLogPrefix)
result := s.executeCode(ctx, code)
s.logger.Debug("%s Code execution completed. Success: %v, Has errors: %v, Log count: %d", codemcp.CodeModeLogPrefix, result.Errors == nil, result.Errors != nil, len(result.Logs))
// Format response text
var responseText string
var executionSuccess bool = true
if result.Errors != nil {
s.logger.Debug("%s Formatting error response. Error kind: %s, Message length: %d, Hints count: %d", codemcp.CodeModeLogPrefix, result.Errors.Kind, len(result.Errors.Message), len(result.Errors.Hints))
logsText := ""
if len(result.Logs) > 0 {
logsText = fmt.Sprintf("\n\nPrint Output:\n%s\n", strings.Join(result.Logs, "\n"))
}
responseText = fmt.Sprintf(
"Execution %s error:\n\n%s\n\nHints:\n%s%s\n\nEnvironment:\n Available server keys: %s",
result.Errors.Kind,
result.Errors.Message,
strings.Join(result.Errors.Hints, "\n"),
logsText,
strings.Join(result.Environment.ServerKeys, ", "),
)
s.logger.Debug("%s Error response formatted. Response length: %d chars", codemcp.CodeModeLogPrefix, len(responseText))
} else {
hasLogs := len(result.Logs) > 0
hasResult := result.Result != nil
s.logger.Debug("%s Formatting success response. Has logs: %v, Has result: %v", codemcp.CodeModeLogPrefix, hasLogs, hasResult)
if !hasLogs && !hasResult {
executionSuccess = false
s.logger.Debug("%s Execution completed with no data (no logs, no result), marking as failure", codemcp.CodeModeLogPrefix)
hints := []string{
"Add print() statements throughout your code to debug and see what's happening at each step",
"Assign the final value to 'result' variable if you want to return it: result = computed_value",
"Check that your tool calls are actually executing and returning data",
}
responseText = fmt.Sprintf(
"Execution completed but produced no data:\n\n"+
"The code executed without errors but returned no output (no print output and no result variable).\n\n"+
"Hints:\n%s\n\n"+
"Environment:\n Available server keys: %s",
strings.Join(hints, "\n"),
strings.Join(result.Environment.ServerKeys, ", "),
)
s.logger.Debug("%s No-data failure response formatted. Response length: %d chars", codemcp.CodeModeLogPrefix, len(responseText))
} else {
if hasLogs {
responseText = fmt.Sprintf("Print output:\n%s\n\nExecution completed successfully.",
strings.Join(result.Logs, "\n"))
} else {
responseText = "Execution completed successfully."
}
if hasResult {
resultJSON, err := schemas.MarshalSortedIndent(result.Result, "", " ")
if err == nil {
responseText += fmt.Sprintf("\nReturn value: %s", string(resultJSON))
s.logger.Debug("%s Added return value to response (JSON length: %d chars)", codemcp.CodeModeLogPrefix, len(resultJSON))
} else {
s.logger.Debug("%s Failed to marshal result to JSON: %v", codemcp.CodeModeLogPrefix, err)
}
}
responseText += fmt.Sprintf("\n\nEnvironment:\n Available server keys: %s",
strings.Join(result.Environment.ServerKeys, ", "))
responseText += "\nNote: This is a Starlark (Python subset) environment. Use MCP tools for external interactions."
s.logger.Debug("%s Success response formatted. Response length: %d chars, Server keys: %v", codemcp.CodeModeLogPrefix, len(responseText), result.Environment.ServerKeys)
}
}
s.logger.Debug("%s Returning tool response message. Execution success: %v", codemcp.CodeModeLogPrefix, executionSuccess)
return createToolResponseMessage(toolCall, responseText), nil
}
// executeCode executes Python (Starlark) code in a sandboxed interpreter with MCP tool bindings.
func (s *StarlarkCodeMode) executeCode(ctx *schemas.BifrostContext, code string) ExecutionResult {
logs := []string{}
s.logger.Debug("%s Starting Starlark code execution", codemcp.CodeModeLogPrefix)
// Step 1: Handle empty code
trimmedCode := strings.TrimSpace(code)
if trimmedCode == "" {
return ExecutionResult{
Result: nil,
Logs: logs,
Errors: nil,
Environment: ExecutionEnvironment{
ServerKeys: []string{},
},
}
}
// Step 2: Build tool bindings for all connected servers
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
serverKeys := make([]string, 0, len(availableToolsPerClient))
predeclared := starlark.StringDict{}
// Thread-safe log appender
appendLog := func(msg string) {
s.logMu.Lock()
defer s.logMu.Unlock()
logs = append(logs, msg)
}
s.logger.Debug("%s GetToolPerClient returned %d clients", codemcp.CodeModeLogPrefix, len(availableToolsPerClient))
for clientName, tools := range availableToolsPerClient {
client := s.clientManager.GetClientByName(clientName)
if client == nil {
s.logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName)
continue
}
s.logger.Debug("%s [%s] Client found. IsCodeModeClient: %v, ToolCount: %d", codemcp.CodeModeLogPrefix, clientName, client.ExecutionConfig.IsCodeModeClient, len(tools))
if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 {
s.logger.Debug("%s [%s] Skipped: IsCodeModeClient=%v, HasTools=%v", codemcp.CodeModeLogPrefix, clientName, client.ExecutionConfig.IsCodeModeClient, len(tools) > 0)
continue
}
serverKeys = append(serverKeys, clientName)
// Build struct with tool methods
structMembers := starlark.StringDict{}
for _, tool := range tools {
if tool.Function == nil || tool.Function.Name == "" {
continue
}
originalToolName := tool.Function.Name
parsedToolName := getCanonicalToolName(clientName, originalToolName)
compatibilityAlias := getCompatibilityToolAlias(clientName, originalToolName)
s.logger.Debug("%s [%s] Binding tool: %s -> %s", codemcp.CodeModeLogPrefix, clientName, originalToolName, parsedToolName)
// Capture variables for closure
capturedToolName := originalToolName
capturedClientName := clientName
// Create a Starlark builtin function for this tool
toolFunc := starlark.NewBuiltin(parsedToolName, func(thread *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
// Convert kwargs to Go map
goArgs := make(map[string]interface{})
for _, kwarg := range kwargs {
if len(kwarg) == 2 {
key := string(kwarg[0].(starlark.String))
value := starlarkToGo(kwarg[1])
goArgs[key] = value
}
}
// Also handle positional args if there's exactly one dict argument
if len(args) == 1 && len(kwargs) == 0 {
if dict, ok := args[0].(*starlark.Dict); ok {
for _, item := range dict.Items() {
if keyStr, ok := item[0].(starlark.String); ok {
goArgs[string(keyStr)] = starlarkToGo(item[1])
}
}
}
}
// Call the MCP tool
result, err := s.callMCPTool(ctx, capturedClientName, capturedToolName, goArgs, appendLog)
if err != nil {
return starlark.None, fmt.Errorf("tool call failed: %v", err)
}
// Convert result back to Starlark
return goToStarlark(result), nil
})
structMembers[parsedToolName] = toolFunc
if compatibilityAlias != parsedToolName && isValidStarlarkIdentifier(compatibilityAlias) {
if _, exists := structMembers[compatibilityAlias]; !exists {
structMembers[compatibilityAlias] = toolFunc
s.logger.Debug("%s [%s] Added compatibility alias: %s -> %s", codemcp.CodeModeLogPrefix, clientName, compatibilityAlias, parsedToolName)
}
}
}
// Create a struct for this server
serverStruct := starlarkstruct.FromStringDict(starlark.String(clientName), structMembers)
predeclared[clientName] = serverStruct
s.logger.Debug("%s [%s] Added server struct with %d tools", codemcp.CodeModeLogPrefix, clientName, len(structMembers))
}
if len(serverKeys) > 0 {
s.logger.Debug("%s Bound %d servers with tools: %v", codemcp.CodeModeLogPrefix, len(serverKeys), serverKeys)
} else {
s.logger.Debug("%s No servers available for code mode execution", codemcp.CodeModeLogPrefix)
}
// Step 3: Create Starlark thread with print function and timeout
toolExecutionTimeout := s.getToolExecutionTimeout()
timeoutCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout)
defer cancel()
thread := &starlark.Thread{
Name: "codemode",
Print: func(_ *starlark.Thread, msg string) {
appendLog(msg)
},
}
// Set up cancellation check — watch the context and cancel the Starlark
// thread so that infinite loops and other long-running scripts are interrupted
// when the execution timeout fires.
thread.SetLocal("context", timeoutCtx)
go func() {
<-timeoutCtx.Done()
thread.Cancel(timeoutCtx.Err().Error())
}()
// Step 4: Configure Starlark dialect options for a Python-like experience
starlarkOpts := &syntax.FileOptions{
TopLevelControl: true, // allow if/for/while at top level (not just inside functions)
While: true, // enable while loops
Set: true, // enable set() builtin
GlobalReassign: true, // allow reassignment to top-level names
Recursion: true, // allow recursive functions
}
// Step 5: Execute the code
globals, err := starlark.ExecFileOptions(starlarkOpts, thread, "code.star", trimmedCode, predeclared)
if err != nil {
errorMessage := err.Error()
hints := generatePythonErrorHints(errorMessage, serverKeys)
s.logger.Debug("%s Execution failed: %s", codemcp.CodeModeLogPrefix, errorMessage)
errorKind := ExecutionErrorTypeRuntime
if strings.Contains(errorMessage, "syntax error") {
errorKind = ExecutionErrorTypeSyntax
}
return ExecutionResult{
Result: nil,
Logs: logs,
Errors: &ExecutionError{
Kind: errorKind,
Message: errorMessage,
Hints: hints,
},
Environment: ExecutionEnvironment{
ServerKeys: serverKeys,
},
}
}
// Step 6: Extract result from globals
var result interface{}
if resultVal, ok := globals["result"]; ok && resultVal != starlark.None {
result = starlarkToGo(resultVal)
}
s.logger.Debug("%s Execution completed successfully", codemcp.CodeModeLogPrefix)
return ExecutionResult{
Result: result,
Logs: logs,
Errors: nil,
Environment: ExecutionEnvironment{
ServerKeys: serverKeys,
},
}
}
// callMCPTool calls an MCP tool and returns the result.
func (s *StarlarkCodeMode) callMCPTool(ctx *schemas.BifrostContext, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) {
// Get available tools per client
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
// Find the client by name
tools, exists := availableToolsPerClient[clientName]
if !exists || len(tools) == 0 {
return nil, fmt.Errorf("client not found for server name: %s", clientName)
}
// Get client using a tool from this client
var client *schemas.MCPClientState
for _, tool := range tools {
if tool.Function != nil && tool.Function.Name != "" {
client = s.clientManager.GetClientForTool(tool.Function.Name)
if client != nil {
break
}
}
}
if client == nil {
return nil, fmt.Errorf("client not found for server name: %s", clientName)
}
// Strip the client name prefix from tool name before calling MCP server
originalToolName := stripClientPrefix(toolName, clientName)
originalRequestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string)
if !ok {
originalRequestID = ""
}
// Generate new request ID for this nested tool call
var newRequestID string
if s.fetchNewRequestIDFunc != nil {
newRequestID = s.fetchNewRequestIDFunc(ctx)
} else {
newRequestID = fmt.Sprintf("exec_%d_%s", time.Now().UnixNano(), toolName)
}
// Create new child context
deadline, hasDeadline := ctx.Deadline()
if !hasDeadline {
deadline = schemas.NoDeadline
}
nestedCtx := schemas.NewBifrostContext(ctx, deadline)
nestedCtx.SetValue(schemas.BifrostContextKeyRequestID, newRequestID)
if originalRequestID != "" {
nestedCtx.SetValue(schemas.BifrostContextKeyParentMCPRequestID, originalRequestID)
}
// Marshal arguments to JSON for the tool call
argsJSON, err := schemas.MarshalSorted(args)
if err != nil {
return nil, fmt.Errorf("failed to marshal tool arguments: %v", err)
}
// Build tool call for MCP request
toolCallReq := schemas.ChatAssistantMessageToolCall{
ID: schemas.Ptr(newRequestID),
Function: schemas.ChatAssistantMessageToolCallFunction{
Name: schemas.Ptr(toolName),
Arguments: string(argsJSON),
},
}
// Create BifrostMCPRequest
mcpRequest := &schemas.BifrostMCPRequest{
RequestType: schemas.MCPRequestTypeChatToolCall,
ChatAssistantMessageToolCall: &toolCallReq,
}
// Check if plugin pipeline is available
if s.pluginPipelineProvider == nil {
// Should never happen, but just in case
s.logger.Warn("%s Plugin pipeline provider is nil", codemcp.CodeModeLogPrefix)
return nil, fmt.Errorf("plugin pipeline provider is nil")
}
// Get plugin pipeline and run hooks
pipeline := s.pluginPipelineProvider()
if pipeline == nil {
// Should never happen, but just in case
s.logger.Warn("%s Plugin pipeline is nil", codemcp.CodeModeLogPrefix)
return nil, fmt.Errorf("plugin pipeline is nil")
}
defer s.releasePluginPipeline(pipeline)
// Run PreMCPHooks
preReq, shortCircuit, preCount := pipeline.RunMCPPreHooks(nestedCtx, mcpRequest)
// Handle short-circuit cases
if shortCircuit != nil {
if shortCircuit.Response != nil {
finalResp, _ := pipeline.RunMCPPostHooks(nestedCtx, shortCircuit.Response, nil, preCount)
if finalResp != nil {
if finalResp.ChatMessage != nil {
return extractResultFromChatMessage(finalResp.ChatMessage), nil
}
if finalResp.ResponsesMessage != nil {
result, err := extractResultFromResponsesMessage(finalResp.ResponsesMessage)
if err != nil {
return nil, err
}
if result != nil {
return result, nil
}
}
}
return nil, fmt.Errorf("plugin short-circuit returned invalid response")
}
if shortCircuit.Error != nil {
pipeline.RunMCPPostHooks(nestedCtx, nil, shortCircuit.Error, preCount)
if shortCircuit.Error.Error != nil {
return nil, fmt.Errorf("%s", shortCircuit.Error.Error.Message)
}
return nil, fmt.Errorf("plugin short-circuit error")
}
}
// If pre-hooks modified the request, extract updated args
if preReq != nil && preReq.ChatAssistantMessageToolCall != nil {
toolCallReq = *preReq.ChatAssistantMessageToolCall
if toolCallReq.Function.Arguments != "" {
if err := sonic.Unmarshal([]byte(toolCallReq.Function.Arguments), &args); err != nil {
s.logger.Warn("%s Failed to parse modified tool arguments, using original: %v", codemcp.CodeModeLogPrefix, err)
}
}
}
// Execute tool
startTime := time.Now()
toolNameToCall := originalToolName
callRequest := mcp.CallToolRequest{
Request: mcp.Request{
Method: string(mcp.MethodToolsCall),
},
Params: mcp.CallToolParams{
Name: toolNameToCall,
Arguments: args,
},
Header: utils.GetHeadersForToolExecution(nestedCtx, client),
}
toolExecutionTimeout := s.getToolExecutionTimeout()
toolCtx, cancel := context.WithTimeout(nestedCtx, toolExecutionTimeout)
defer cancel()
var toolResponse *mcp.CallToolResult
var callErr error
if client.ExecutionConfig.AuthType == schemas.MCPAuthTypePerUserOauth {
accessToken, err := utils.ResolvePerUserOAuthToken(nestedCtx, client, s.oauth2Provider)
if err != nil {
return nil, err
}
if client.Conn == nil {
// Per-user OAuth with no persistent connection — use a temporary connection.
// Assign to outer toolResponse/callErr so the shared logging + post-hooks path runs.
toolResponse, callErr = codemcp.ExecuteToolWithUserToken(toolCtx, client.ExecutionConfig, toolNameToCall, args, accessToken, s.logger)
if callErr != nil && toolCtx.Err() == context.DeadlineExceeded {
callErr = fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName)
}
} else {
callRequest.Header = utils.BuildPerUserOAuthHeaders(callRequest.Header, accessToken)
toolResponse, callErr = client.Conn.CallTool(toolCtx, callRequest)
}
} else {
toolResponse, callErr = client.Conn.CallTool(toolCtx, callRequest)
}
latency := time.Since(startTime).Milliseconds()
var mcpResp *schemas.BifrostMCPResponse
var bifrostErr *schemas.BifrostError
if callErr != nil {
s.logger.Debug("%s Tool call failed: %s.%s - %v", codemcp.CodeModeLogPrefix, clientName, toolName, callErr)
appendLog(fmt.Sprintf("[TOOL] %s.%s error: %v", clientName, toolName, callErr))
bifrostErr = &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: fmt.Sprintf("tool call failed for %s.%s: %v", clientName, toolName, callErr),
},
}
} else {
rawResult := extractTextFromMCPResponse(toolResponse, toolName)
if after, ok := strings.CutPrefix(rawResult, "Error: "); ok {
errorMsg := after
s.logger.Debug("%s Tool returned error result: %s.%s - %s", codemcp.CodeModeLogPrefix, clientName, toolName, errorMsg)
appendLog(fmt.Sprintf("[TOOL] %s.%s error result: %s", clientName, toolName, errorMsg))
bifrostErr = &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: errorMsg,
},
}
} else {
mcpResp = &schemas.BifrostMCPResponse{
ChatMessage: createToolResponseMessage(toolCallReq, rawResult),
ExtraFields: schemas.BifrostMCPResponseExtraFields{
ClientName: clientName,
ToolName: originalToolName,
Latency: latency,
},
}
resultStr := formatResultForLog(rawResult)
logToolName := stripClientPrefix(toolName, clientName)
logToolName = strings.ReplaceAll(logToolName, "-", "_")
appendLog(fmt.Sprintf("[TOOL] %s.%s raw response: %s", clientName, logToolName, resultStr))
}
}
// Run post-hooks
finalResp, finalErr := pipeline.RunMCPPostHooks(nestedCtx, mcpResp, bifrostErr, preCount)
if finalErr != nil {
if finalErr.Error != nil {
return nil, fmt.Errorf("%s", finalErr.Error.Message)
}
return nil, fmt.Errorf("tool execution failed")
}
if finalResp == nil {
return nil, fmt.Errorf("plugin post-hooks returned invalid response")
}
if finalResp.ChatMessage != nil {
return extractResultFromChatMessage(finalResp.ChatMessage), nil
}
if finalResp.ResponsesMessage != nil {
result, err := extractResultFromResponsesMessage(finalResp.ResponsesMessage)
if err != nil {
return nil, err
}
if result != nil {
return result, nil
}
}
return nil, fmt.Errorf("plugin post-hooks returned invalid response")
}

View File

@@ -0,0 +1,301 @@
//go:build !tinygo && !wasm
package starlark
import (
"context"
"encoding/json"
"fmt"
"strings"
codemcp "github.com/maximhq/bifrost/core/mcp"
"github.com/maximhq/bifrost/core/schemas"
)
// createGetToolDocsTool creates the getToolDocs tool definition for code mode.
// This tool provides detailed documentation for a specific tool when the compact
// signatures from readToolFile are not sufficient to understand how to use it.
func (s *StarlarkCodeMode) createGetToolDocsTool() schemas.ChatTool {
getToolDocsProps := schemas.NewOrderedMapFromPairs(
schemas.KV("server", map[string]interface{}{
"type": "string",
"description": "The server name (e.g., 'calculator'). Use listToolFiles to see available servers.",
}),
schemas.KV("tool", map[string]interface{}{
"type": "string",
"description": "The tool name (e.g., 'add'). Use readToolFile to see available tools for a server.",
}),
)
return schemas.ChatTool{
Type: schemas.ChatToolTypeFunction,
Function: &schemas.ChatToolFunction{
Name: codemcp.ToolTypeGetToolDocs,
Description: schemas.Ptr(
"Get detailed documentation for a specific tool including full parameter descriptions, " +
"types, and usage examples. Use this when the compact signature from readToolFile " +
"is not sufficient to understand how to use a tool. " +
"Requires both server name and tool name as parameters.",
),
Parameters: &schemas.ToolFunctionParameters{
Type: "object",
Properties: getToolDocsProps,
Required: []string{"server", "tool"},
},
},
}
}
// handleGetToolDocs handles the getToolDocs tool call.
func (s *StarlarkCodeMode) handleGetToolDocs(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
// Parse tool arguments
var arguments map[string]interface{}
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
return nil, fmt.Errorf("failed to parse tool arguments: %v", err)
}
serverName, ok := arguments["server"].(string)
if !ok || serverName == "" {
return nil, fmt.Errorf("server parameter is required and must be a string")
}
toolName, ok := arguments["tool"].(string)
if !ok || toolName == "" {
return nil, fmt.Errorf("tool parameter is required and must be a string")
}
// Get available tools per client
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
// Find matching client
var matchedClientName string
var matchedTool *schemas.ChatTool
serverNameLower := strings.ToLower(serverName)
for clientName, tools := range availableToolsPerClient {
client := s.clientManager.GetClientByName(clientName)
if client == nil {
s.logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName)
continue
}
if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 {
continue
}
clientNameLower := strings.ToLower(clientName)
if clientNameLower == serverNameLower {
matchedClientName = clientName
// Find the specific tool
for i, tool := range tools {
if tool.Function != nil {
if matchesToolReference(toolName, clientName, tool.Function.Name) {
matchedTool = &tools[i]
break
}
}
}
break
}
}
// Handle server not found
if matchedClientName == "" {
var availableServers []string
for name := range availableToolsPerClient {
client := s.clientManager.GetClientByName(name)
if client != nil && client.ExecutionConfig.IsCodeModeClient {
availableServers = append(availableServers, name)
}
}
errorMsg := fmt.Sprintf("Server '%s' not found. Available servers are:\n", serverName)
for _, sn := range availableServers {
errorMsg += fmt.Sprintf(" - %s\n", sn)
}
return createToolResponseMessage(toolCall, errorMsg), nil
}
// Handle tool not found
if matchedTool == nil {
tools := availableToolsPerClient[matchedClientName]
var availableTools []string
for _, tool := range tools {
if tool.Function != nil {
availableTools = append(availableTools, getCanonicalToolName(matchedClientName, tool.Function.Name))
}
}
errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. Available tools are:\n", toolName, matchedClientName)
for _, t := range availableTools {
errorMsg += fmt.Sprintf(" - %s\n", t)
}
return createToolResponseMessage(toolCall, errorMsg), nil
}
// Generate detailed documentation using generateTypeDefinitions
docContent := generateTypeDefinitions(matchedClientName, []schemas.ChatTool{*matchedTool}, true)
return createToolResponseMessage(toolCall, docContent), nil
}
// generateTypeDefinitions generates Python documentation with docstrings from ChatTool schemas.
func generateTypeDefinitions(clientName string, tools []schemas.ChatTool, isToolLevel bool) string {
var sb strings.Builder
// Write comprehensive header
sb.WriteString("# ============================================================================\n")
if isToolLevel && len(tools) == 1 && tools[0].Function != nil {
sb.WriteString(fmt.Sprintf("# Documentation for %s.%s tool\n", clientName, getCanonicalToolName(clientName, tools[0].Function.Name)))
} else {
sb.WriteString(fmt.Sprintf("# Documentation for %s MCP server\n", clientName))
}
sb.WriteString("# ============================================================================\n")
sb.WriteString("#\n")
if isToolLevel && len(tools) == 1 {
sb.WriteString("# This file contains Python documentation for a specific tool on this MCP server.\n")
} else {
sb.WriteString("# This file contains Python documentation for all tools available on this MCP server.\n")
}
sb.WriteString("#\n")
sb.WriteString("# USAGE INSTRUCTIONS:\n")
sb.WriteString(fmt.Sprintf("# Call tools using: result = %s.tool_name(param=value)\n", clientName))
sb.WriteString("# No async/await needed - calls are synchronous.\n")
sb.WriteString("#\n")
sb.WriteString("# STARLARK DIFFERENCE FROM PYTHON:\n")
sb.WriteString("# for/if/while at top level MUST be inside a function.\n")
sb.WriteString("# Wrap loops: def main(): for x in items: ... then result = main()\n")
sb.WriteString("#\n")
sb.WriteString("# CRITICAL - HANDLING RESPONSES:\n")
sb.WriteString("# Tool responses are dicts. To avoid runtime errors:\n")
sb.WriteString("# 1. Use print(result) to inspect the response structure first\n")
sb.WriteString("# 2. Access dict values with brackets: result[\"key\"] NOT result.key\n")
sb.WriteString("# 3. Use .get() for safe access: result.get(\"key\", default)\n")
sb.WriteString("#\n")
sb.WriteString("# Common error: \"key not found\" or \"has no attribute\"\n")
sb.WriteString("# Fix: Use print() to see actual structure, then use result[\"key\"] or .get()\n")
sb.WriteString("# ============================================================================\n\n")
// Generate function definitions for each tool
for _, tool := range tools {
if tool.Function == nil || tool.Function.Name == "" {
continue
}
originalToolName := tool.Function.Name
toolName := getCanonicalToolName(clientName, originalToolName)
description := ""
if tool.Function.Description != nil {
description = *tool.Function.Description
}
// Generate function signature
params := formatPythonParams(tool.Function.Parameters)
sb.WriteString(fmt.Sprintf("def %s(%s) -> dict:\n", toolName, params))
// Generate docstring
sb.WriteString(" \"\"\"\n")
if description != "" {
sb.WriteString(fmt.Sprintf(" %s\n", description))
sb.WriteString("\n")
}
// Args section
if tool.Function.Parameters != nil && tool.Function.Parameters.Properties != nil {
props := tool.Function.Parameters.Properties
required := make(map[string]bool)
if tool.Function.Parameters.Required != nil {
for _, req := range tool.Function.Parameters.Required {
required[req] = true
}
}
if props.Len() > 0 {
sb.WriteString(" Args:\n")
// Sort properties for consistent output
propNames := make([]string, 0, props.Len())
props.Range(func(name string, _ interface{}) bool {
propNames = append(propNames, name)
return true
})
for i := 0; i < len(propNames)-1; i++ {
for j := i + 1; j < len(propNames); j++ {
if propNames[i] > propNames[j] {
propNames[i], propNames[j] = propNames[j], propNames[i]
}
}
}
for _, propName := range propNames {
prop, _ := props.Get(propName)
propMap, ok := prop.(map[string]interface{})
if !ok {
continue
}
pyType := jsonSchemaToPython(propMap)
propDesc := ""
if desc, ok := propMap["description"].(string); ok && desc != "" {
propDesc = desc
} else {
propDesc = fmt.Sprintf("%s parameter", propName)
}
requiredNote := ""
if required[propName] {
requiredNote = " (required)"
} else {
requiredNote = " (optional)"
}
sb.WriteString(fmt.Sprintf(" %s (%s): %s%s\n", propName, pyType, propDesc, requiredNote))
}
sb.WriteString("\n")
}
}
// Returns section
sb.WriteString(" Returns:\n")
sb.WriteString(" dict: Response from the tool. Structure varies by tool.\n")
sb.WriteString(" Use print(result) to inspect the actual structure.\n")
sb.WriteString("\n")
// Example section
sb.WriteString(" Example:\n")
sb.WriteString(fmt.Sprintf(" result = %s.%s(%s)\n", clientName, toolName, getExampleParams(tool.Function.Parameters)))
sb.WriteString(" print(result) # Always inspect response first!\n")
sb.WriteString(" value = result.get(\"key\", default) # Safe access\n")
sb.WriteString(" \"\"\"\n")
sb.WriteString(" ...\n\n")
}
return sb.String()
}
// getExampleParams generates example parameter usage for a function.
func getExampleParams(params *schemas.ToolFunctionParameters) string {
if params == nil || params.Properties == nil || params.Properties.Len() == 0 {
return ""
}
required := make(map[string]bool)
if params.Required != nil {
for _, req := range params.Required {
required[req] = true
}
}
keys := params.Properties.Keys()
// Get first required param as example
for _, name := range keys {
if required[name] {
return fmt.Sprintf("%s=\"...\"", name)
}
}
// If no required, get first param
if len(keys) > 0 {
return fmt.Sprintf("%s=\"...\"", keys[0])
}
return ""
}

View File

@@ -0,0 +1,23 @@
//go:build !tinygo && !wasm
package starlark
import "github.com/maximhq/bifrost/core/schemas"
// noopLogger is a no-op implementation of schemas.Logger used as a fallback
// when no logger is provided.
type noopLogger struct{}
func (noopLogger) Debug(string, ...any) {}
func (noopLogger) Info(string, ...any) {}
func (noopLogger) Warn(string, ...any) {}
func (noopLogger) Error(string, ...any) {}
func (noopLogger) Fatal(string, ...any) {}
func (noopLogger) SetLevel(schemas.LogLevel) {}
func (noopLogger) SetOutputType(schemas.LoggerOutputType) {}
func (noopLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
return schemas.NoopLogEvent
}
// defaultLogger is used when nil is passed to NewStarlarkCodeMode.
var defaultLogger schemas.Logger = noopLogger{}

View File

@@ -0,0 +1,231 @@
//go:build !tinygo && !wasm
package starlark
import (
"context"
"fmt"
"strings"
codemcp "github.com/maximhq/bifrost/core/mcp"
"github.com/maximhq/bifrost/core/schemas"
)
// createListToolFilesTool creates the listToolFiles tool definition for code mode.
// This tool allows listing all available virtual .pyi stub files for connected MCP servers.
// The description is dynamically generated based on the configured CodeModeBindingLevel.
func (s *StarlarkCodeMode) createListToolFilesTool() schemas.ChatTool {
bindingLevel := s.GetBindingLevel()
var description string
if bindingLevel == schemas.CodeModeBindingLevelServer {
description = "Returns a tree structure listing all virtual .pyi stub files available for connected MCP servers. " +
"Each server has a corresponding file (e.g., servers/<serverName>.pyi) that contains compact Python signatures for all tools in that server. " +
"Safe workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
"Use readToolFile before executeToolCode to read a specific server file and confirm exact callable tool names and parameters. " +
"Use getToolDocs if you need detailed documentation for a specific tool. " +
"In code, access tools via: server_name.tool_name(param=value). " +
"The server names used in code correspond to the human-readable names shown in this listing. " +
"This tool is generic and works with any set of servers connected at runtime. " +
"Always check this tool whenever you are unsure about what tools you have available or if you want to verify available servers and their tools."
} else {
description = "Returns a tree structure listing all virtual .pyi stub files available for connected MCP servers, organized by individual tool. " +
"Each tool has a corresponding file (e.g., servers/<serverName>/<toolName>.pyi) that contains compact Python signatures for that specific tool. " +
"The <toolName> shown in each filename is the exact canonical identifier exposed in executeToolCode. " +
"Safe workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
"Use readToolFile before executeToolCode to confirm the exact signature and parameters for the tool you want to call. " +
"Use getToolDocs if you need detailed documentation for a specific tool. " +
"In code, access tools via: server_name.tool_name(param=value). " +
"The server names used in code correspond to the human-readable names shown in this listing. " +
"This tool is generic and works with any set of servers connected at runtime. " +
"Always check this tool whenever you are unsure about what tools you have available or if you want to verify available servers and their tools."
}
return schemas.ChatTool{
Type: schemas.ChatToolTypeFunction,
Function: &schemas.ChatToolFunction{
Name: codemcp.ToolTypeListToolFiles,
Description: schemas.Ptr(description),
Parameters: &schemas.ToolFunctionParameters{
Type: "object",
Properties: schemas.NewOrderedMap(),
Required: []string{},
},
},
}
}
// handleListToolFiles handles the listToolFiles tool call.
// It builds a tree structure listing all virtual .pyi files available for code mode clients.
func (s *StarlarkCodeMode) handleListToolFiles(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
if len(availableToolsPerClient) == 0 {
responseText := "No servers are currently connected. There are no virtual .pyi files available. " +
"Please ensure servers are connected before using this tool."
return createToolResponseMessage(toolCall, responseText), nil
}
// Get the code mode binding level
bindingLevel := s.GetBindingLevel()
// Build file list based on binding level
var files []string
codeModeServerCount := 0
for clientName, tools := range availableToolsPerClient {
client := s.clientManager.GetClientByName(clientName)
if client == nil {
s.logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName)
continue
}
if !client.ExecutionConfig.IsCodeModeClient {
continue
}
codeModeServerCount++
if bindingLevel == schemas.CodeModeBindingLevelServer {
// Server-level: one file per server
files = append(files, fmt.Sprintf("servers/%s.pyi", clientName))
} else {
// Tool-level: one file per tool
for _, tool := range tools {
if tool.Function != nil && tool.Function.Name != "" {
toolName := getCanonicalToolName(clientName, tool.Function.Name)
if err := validateNormalizedToolName(toolName); err != nil {
s.logger.Warn("%s Skipping tool '%s' from client '%s': %v", codemcp.CodeModeLogPrefix, tool.Function.Name, clientName, err)
continue
}
toolFileName := fmt.Sprintf("servers/%s/%s.pyi", clientName, toolName)
files = append(files, toolFileName)
}
}
}
}
if codeModeServerCount == 0 {
responseText := "Servers are connected but none are configured for code mode. " +
"There are no virtual .pyi files available."
return createToolResponseMessage(toolCall, responseText), nil
}
// Build tree structure from file list
responseText := buildListToolFilesResponse(files, bindingLevel)
return createToolResponseMessage(toolCall, responseText), nil
}
func buildListToolFilesResponse(files []string, bindingLevel schemas.CodeModeBindingLevel) string {
tree := buildVFSTree(files)
if tree == "" {
return ""
}
header := []string{
"# Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode",
}
if bindingLevel == schemas.CodeModeBindingLevelServer {
header = append(header, "# Read the server .pyi file before executeToolCode to confirm exact tool names and parameters.")
} else {
header = append(header,
"# Filenames below use the exact canonical tool identifiers available in executeToolCode.",
"# Still call readToolFile before executeToolCode to confirm parameters and return shape.",
)
}
return strings.Join(append(header, "", tree), "\n")
}
// VFS tree node structure for building hierarchical file structure
type treeNode struct {
isDirectory bool
children map[string]*treeNode
}
// buildVFSTree creates a hierarchical tree structure from a flat list of file paths.
func buildVFSTree(files []string) string {
if len(files) == 0 {
return ""
}
root := &treeNode{
isDirectory: true,
children: make(map[string]*treeNode),
}
// Parse all files and build tree structure
for _, file := range files {
parts := strings.Split(file, "/")
current := root
// Create all intermediate directories and final file
for i, part := range parts {
if _, exists := current.children[part]; !exists {
current.children[part] = &treeNode{
isDirectory: i < len(parts)-1, // Last part is file, not directory
children: make(map[string]*treeNode),
}
}
current = current.children[part]
}
}
// Render tree structure with proper indentation
var lines []string
renderTreeNode(root, "", &lines, true)
return strings.Join(lines, "\n")
}
// renderTreeNode recursively renders a tree node and its children with proper indentation.
func renderTreeNode(node *treeNode, indent string, lines *[]string, isRoot bool) {
// Get sorted keys for consistent output
var keys []string
for key := range node.children {
keys = append(keys, key)
}
// Simple bubble sort for small lists (good enough for this use case)
for i := 0; i < len(keys); i++ {
for j := i + 1; j < len(keys); j++ {
if keys[j] < keys[i] {
keys[i], keys[j] = keys[j], keys[i]
}
}
}
for _, key := range keys {
child := node.children[key]
// Format the line
var line string
if isRoot {
// Root level - no indentation
if child.isDirectory {
line = key + "/"
} else {
line = key
}
} else {
// Non-root levels - add indentation
if child.isDirectory {
line = indent + key + "/"
} else {
line = indent + key
}
}
*lines = append(*lines, line)
// Recurse into children
if child.isDirectory && len(child.children) > 0 {
var nextIndent string
if isRoot {
nextIndent = " "
} else {
nextIndent = indent + " "
}
renderTreeNode(child, nextIndent, lines, false)
}
}
}

View File

@@ -0,0 +1,443 @@
//go:build !tinygo && !wasm
package starlark
import (
"context"
"encoding/json"
"fmt"
"strings"
codemcp "github.com/maximhq/bifrost/core/mcp"
"github.com/maximhq/bifrost/core/schemas"
)
// createReadToolFileTool creates the readToolFile tool definition for code mode.
// This tool allows reading virtual .pyi stub files for specific MCP servers/tools,
// generating Python type stubs from the server's tool schemas.
func (s *StarlarkCodeMode) createReadToolFileTool() schemas.ChatTool {
bindingLevel := s.GetBindingLevel()
var fileNameDescription, toolDescription string
if bindingLevel == schemas.CodeModeBindingLevelServer {
fileNameDescription = "The virtual filename from listToolFiles in format: servers/<serverName>.pyi (e.g., 'servers/calculator.pyi')"
toolDescription = "Reads a virtual .pyi stub file for a specific MCP server, returning compact Python function signatures " +
"for all tools available on that server. The fileName should be in format servers/<serverName>.pyi as listed by listToolFiles. " +
"The function performs case-insensitive matching and removes the .pyi extension. " +
"This is the authoritative source for the exact callable tool names and parameters to use in executeToolCode. " +
"Each tool can be accessed in code via: serverName.tool_name(param=value). " +
"If the compact signature is not enough to understand a tool, use getToolDocs for detailed documentation. " +
"Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
"IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " +
"do NOT call this tool again with startLine/endLine - you already have the complete file."
} else {
fileNameDescription = "The virtual filename from listToolFiles in format: servers/<serverName>/<toolName>.pyi (e.g., 'servers/calculator/add.pyi')"
toolDescription = "Reads a virtual .pyi stub file for a specific tool, returning its compact Python function signature. " +
"The fileName should be in format servers/<serverName>/<toolName>.pyi as listed by listToolFiles. " +
"The function performs case-insensitive matching and removes the .pyi extension. " +
"This is the authoritative source for the exact callable tool name and arguments to use in executeToolCode. " +
"The tool can be accessed in code via: serverName.tool_name(param=value) using the def name shown in the file. " +
"If the compact signature is not enough to understand the tool, use getToolDocs for detailed documentation. " +
"Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
"IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " +
"do NOT call this tool again with startLine/endLine - you already have the complete file."
}
readToolFileProps := schemas.NewOrderedMapFromPairs(
schemas.KV("fileName", map[string]interface{}{
"type": "string",
"description": fileNameDescription,
}),
schemas.KV("startLine", map[string]interface{}{
"type": "number",
"description": "Optional 1-based starting line number for partial file read. Usually not needed - omit to read the entire file. Files are typically small (under 50 lines).",
}),
schemas.KV("endLine", map[string]interface{}{
"type": "number",
"description": "Optional 1-based ending line number for partial file read. Usually not needed - omit to read the entire file. Will be clamped to actual file size if too large.",
}),
)
return schemas.ChatTool{
Type: schemas.ChatToolTypeFunction,
Function: &schemas.ChatToolFunction{
Name: codemcp.ToolTypeReadToolFile,
Description: schemas.Ptr(toolDescription),
Parameters: &schemas.ToolFunctionParameters{
Type: "object",
Properties: readToolFileProps,
Required: []string{"fileName"},
},
},
}
}
// handleReadToolFile handles the readToolFile tool call.
func (s *StarlarkCodeMode) handleReadToolFile(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
// Parse tool arguments
var arguments map[string]interface{}
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
return nil, fmt.Errorf("failed to parse tool arguments: %v", err)
}
fileName, ok := arguments["fileName"].(string)
if !ok || fileName == "" {
return nil, fmt.Errorf("fileName parameter is required and must be a string")
}
// Parse the file path to extract server name and optional tool name
serverName, toolName, isToolLevel := parseVFSFilePath(fileName)
// Get available tools per client
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
// Find matching client
var matchedClientName string
var matchedTools []schemas.ChatTool
matchCount := 0
for clientName, tools := range availableToolsPerClient {
client := s.clientManager.GetClientByName(clientName)
if client == nil {
s.logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName)
continue
}
if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 {
continue
}
clientNameLower := strings.ToLower(clientName)
serverNameLower := strings.ToLower(serverName)
if clientNameLower == serverNameLower {
matchCount++
if matchCount > 1 {
// Multiple matches found
errorMsg := fmt.Sprintf("Multiple servers match filename '%s':\n", fileName)
for name := range availableToolsPerClient {
if strings.ToLower(name) == serverNameLower {
errorMsg += fmt.Sprintf(" - %s\n", name)
}
}
errorMsg += "\nPlease use a more specific filename. Use the exact display name from listToolFiles to avoid ambiguity."
return createToolResponseMessage(toolCall, errorMsg), nil
}
matchedClientName = clientName
if isToolLevel {
// Tool-level: filter to specific tool
var foundTool *schemas.ChatTool
for i, tool := range tools {
if tool.Function != nil {
if matchesToolReference(toolName, clientName, tool.Function.Name) {
foundTool = &tools[i]
break
}
}
}
if foundTool == nil {
availableTools := make([]string, 0)
for _, tool := range tools {
if tool.Function != nil {
availableTools = append(availableTools, getCanonicalToolName(clientName, tool.Function.Name))
}
}
errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. Available tools in this server are:\n", toolName, clientName)
for _, t := range availableTools {
errorMsg += fmt.Sprintf(" - servers/%s/%s.pyi\n", clientName, t)
}
return createToolResponseMessage(toolCall, errorMsg), nil
}
matchedTools = []schemas.ChatTool{*foundTool}
} else {
// Server-level: use all tools
matchedTools = tools
}
}
}
if matchedClientName == "" {
// Build helpful error message with available files
bindingLevel := s.GetBindingLevel()
var availableFiles []string
for name := range availableToolsPerClient {
if bindingLevel == schemas.CodeModeBindingLevelServer {
availableFiles = append(availableFiles, fmt.Sprintf("servers/%s.pyi", name))
} else {
client := s.clientManager.GetClientByName(name)
if client != nil && client.ExecutionConfig.IsCodeModeClient {
if tools, ok := availableToolsPerClient[name]; ok {
for _, tool := range tools {
if tool.Function != nil {
availableFiles = append(availableFiles, fmt.Sprintf("servers/%s/%s.pyi", name, getCanonicalToolName(name, tool.Function.Name)))
}
}
}
}
}
}
errorMsg := fmt.Sprintf("No server found matching '%s'. Available virtual files are:\n", serverName)
for _, f := range availableFiles {
errorMsg += fmt.Sprintf(" - %s\n", f)
}
return createToolResponseMessage(toolCall, errorMsg), nil
}
// Generate compact Python signatures
fileContent := generateCompactSignatures(matchedClientName, matchedTools, isToolLevel)
lines := strings.Split(fileContent, "\n")
totalLines := len(lines)
// Prepend total lines info so LLM knows the file size upfront
fileContent = fmt.Sprintf("# Total lines: %d (this is the complete file, no need to paginate)\n%s", totalLines+1, fileContent)
// Recalculate lines after prepending
lines = strings.Split(fileContent, "\n")
totalLines = len(lines)
// Handle line slicing if provided
var startLine, endLine *int
if sl, ok := arguments["startLine"].(float64); ok {
slInt := int(sl)
startLine = &slInt
}
if el, ok := arguments["endLine"].(float64); ok {
elInt := int(el)
endLine = &elInt
}
if startLine != nil || endLine != nil {
start := 1
if startLine != nil {
start = *startLine
}
end := totalLines
if endLine != nil {
end = *endLine
}
// Clamp values to valid range instead of erroring
// This handles cases where LLM requests more lines than exist
if start < 1 {
start = 1
}
if start > totalLines {
start = totalLines
}
if end < 1 {
end = 1
}
if end > totalLines {
end = totalLines
}
if start > end {
// If start > end after clamping, just return the start line
end = start
}
// Slice lines (convert to 0-based indexing)
selectedLines := lines[start-1 : end]
fileContent = strings.Join(selectedLines, "\n")
}
return createToolResponseMessage(toolCall, fileContent), nil
}
// parseVFSFilePath parses a VFS file path and extracts the server name and optional tool name.
func parseVFSFilePath(fileName string) (serverName, toolName string, isToolLevel bool) {
// Remove .pyi extension
basePath := strings.TrimSuffix(fileName, ".pyi")
// Remove "servers/" prefix if present
basePath = strings.TrimPrefix(basePath, "servers/")
// Defensive validation: reject paths with path traversal attempts
if strings.Contains(basePath, "..") {
// Return empty to indicate invalid path
return "", "", false
}
// Check for path separator
parts := strings.Split(basePath, "/")
if len(parts) == 2 {
// Tool-level: "serverName/toolName"
// Validate that tool name doesn't contain additional path separators or traversal
if parts[1] == "" || strings.Contains(parts[1], "/") || strings.Contains(parts[1], "..") {
// Invalid tool name, treat as server-level
return parts[0], "", false
}
return parts[0], parts[1], true
}
// Server-level: "serverName"
// Validate server name doesn't contain path separators or traversal
if strings.Contains(basePath, "/") || strings.Contains(basePath, "..") {
// Invalid path
return "", "", false
}
return basePath, "", false
}
// generateCompactSignatures generates compact Python function signatures for tools.
func generateCompactSignatures(clientName string, tools []schemas.ChatTool, isToolLevel bool) string {
var sb strings.Builder
// Minimal header
if isToolLevel && len(tools) == 1 && tools[0].Function != nil {
toolName := getCanonicalToolName(clientName, tools[0].Function.Name)
sb.WriteString(fmt.Sprintf("# %s.%s tool\n", clientName, toolName))
} else {
sb.WriteString(fmt.Sprintf("# %s server tools\n", clientName))
}
sb.WriteString(fmt.Sprintf("# Usage: %s.tool_name(param=value)\n", clientName))
sb.WriteString("# The def names below are the exact callable names to use in executeToolCode.\n")
sb.WriteString("# Read this file before executeToolCode to confirm parameters and return shape.\n")
sb.WriteString(fmt.Sprintf("# For detailed docs: use getToolDocs(server=\"%s\", tool=\"tool_name\")\n", clientName))
sb.WriteString("# Note: Descriptions may be truncated. Use getToolDocs for full details.\n\n")
for _, tool := range tools {
if tool.Function == nil || tool.Function.Name == "" {
continue
}
toolName := getCanonicalToolName(clientName, tool.Function.Name)
// Format inline parameters in Python style
params := formatPythonParams(tool.Function.Parameters)
// Get description (truncate if too long)
desc := ""
if tool.Function.Description != nil && *tool.Function.Description != "" {
desc = *tool.Function.Description
// Truncate long descriptions to first sentence or 80 chars
if idx := strings.Index(desc, ". "); idx > 0 && idx < 80 {
desc = desc[:idx+1]
} else if len(desc) > 80 {
desc = desc[:77] + "..."
}
}
// Write Python signature: def tool_name(param: type, param: type = None) -> dict: # description
if desc != "" {
sb.WriteString(fmt.Sprintf("def %s(%s) -> dict: # %s\n", toolName, params, desc))
} else {
sb.WriteString(fmt.Sprintf("def %s(%s) -> dict\n", toolName, params))
}
}
return sb.String()
}
// formatPythonParams formats tool parameters as Python function parameters.
func formatPythonParams(params *schemas.ToolFunctionParameters) string {
if params == nil || params.Properties == nil || params.Properties.Len() == 0 {
return ""
}
props := params.Properties
required := make(map[string]bool)
if params.Required != nil {
for _, req := range params.Required {
required[req] = true
}
}
// Sort properties: required first, then optional, alphabetically within each group
requiredNames := make([]string, 0)
optionalNames := make([]string, 0)
props.Range(func(name string, _ interface{}) bool {
if required[name] {
requiredNames = append(requiredNames, name)
} else {
optionalNames = append(optionalNames, name)
}
return true
})
// Simple alphabetical sort for each group
for i := 0; i < len(requiredNames)-1; i++ {
for j := i + 1; j < len(requiredNames); j++ {
if requiredNames[i] > requiredNames[j] {
requiredNames[i], requiredNames[j] = requiredNames[j], requiredNames[i]
}
}
}
for i := 0; i < len(optionalNames)-1; i++ {
for j := i + 1; j < len(optionalNames); j++ {
if optionalNames[i] > optionalNames[j] {
optionalNames[i], optionalNames[j] = optionalNames[j], optionalNames[i]
}
}
}
parts := make([]string, 0, props.Len())
// Add required params first
for _, propName := range requiredNames {
prop, _ := props.Get(propName)
propMap, ok := prop.(map[string]interface{})
if !ok {
continue
}
pyType := jsonSchemaToPython(propMap)
parts = append(parts, fmt.Sprintf("%s: %s", propName, pyType))
}
// Add optional params with default None
for _, propName := range optionalNames {
prop, _ := props.Get(propName)
propMap, ok := prop.(map[string]interface{})
if !ok {
continue
}
pyType := jsonSchemaToPython(propMap)
parts = append(parts, fmt.Sprintf("%s: %s = None", propName, pyType))
}
return strings.Join(parts, ", ")
}
// jsonSchemaToPython converts a JSON Schema type definition to a Python type string.
func jsonSchemaToPython(prop map[string]interface{}) string {
// Check for enum first - takes precedence over type to show allowed values
if enum, ok := prop["enum"].([]interface{}); ok && len(enum) > 0 {
enumStrs := make([]string, 0, len(enum))
for _, e := range enum {
enumStrs = append(enumStrs, fmt.Sprintf("%q", e))
}
return "Literal[" + strings.Join(enumStrs, ", ") + "]"
}
// Check for const (single fixed value)
if constVal, ok := prop["const"]; ok {
return fmt.Sprintf("Literal[%q]", constVal)
}
// Fall back to type-based conversion
if typeVal, ok := prop["type"].(string); ok {
switch typeVal {
case "string":
return "str"
case "number":
return "float"
case "integer":
return "int"
case "boolean":
return "bool"
case "array":
itemsType := "Any"
if items, ok := prop["items"].(map[string]interface{}); ok {
itemsType = jsonSchemaToPython(items)
}
return fmt.Sprintf("list[%s]", itemsType)
case "object":
return "dict"
case "null":
return "None"
}
}
return "Any"
}

View File

@@ -0,0 +1,175 @@
//go:build !tinygo && !wasm
// Package starlark provides a Starlark-based implementation of the CodeMode interface.
// Starlark is a Python-like language designed for configuration and embedded scripting.
// See https://github.com/google/starlark-go for more information.
package starlark
import (
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/maximhq/bifrost/core/mcp"
"github.com/maximhq/bifrost/core/schemas"
)
// StarlarkCodeMode implements the CodeMode interface using a Starlark interpreter.
// It provides a sandboxed Python-like execution environment with access to MCP tools.
type StarlarkCodeMode struct {
// Configuration (atomic for thread-safe updates)
bindingLevel atomic.Value // schemas.CodeModeBindingLevel
toolExecutionTimeout atomic.Value // time.Duration
// Dependencies
clientManager mcp.ClientManager
pluginPipelineProvider func() mcp.PluginPipeline
releasePluginPipeline func(pipeline mcp.PluginPipeline)
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string
oauth2Provider schemas.OAuth2Provider
// Logger for this instance
logger schemas.Logger
// Mutex for protecting logs during concurrent execution
logMu sync.Mutex
}
// NewStarlarkCodeMode creates a new Starlark-based CodeMode implementation.
//
// Parameters:
// - config: Configuration for the code mode (binding level, timeouts). Can be nil for defaults.
// - logger: Logger instance for this code mode. Can be nil.
//
// Returns:
// - *StarlarkCodeMode: A new Starlark code mode instance
//
// Note: Dependencies must be set via SetDependencies before the CodeMode can execute tools.
// This allows the CodeMode to be created before the MCPManager, avoiding circular dependencies.
func NewStarlarkCodeMode(config *mcp.CodeModeConfig, logger schemas.Logger) *StarlarkCodeMode {
if config == nil {
config = mcp.DefaultCodeModeConfig()
}
if config.BindingLevel == "" {
config.BindingLevel = schemas.CodeModeBindingLevelServer
}
if config.ToolExecutionTimeout <= 0 {
config.ToolExecutionTimeout = schemas.DefaultToolExecutionTimeout
}
if logger == nil {
logger = defaultLogger
}
s := &StarlarkCodeMode{
logger: logger,
}
// Initialize atomic values
s.bindingLevel.Store(config.BindingLevel)
s.toolExecutionTimeout.Store(config.ToolExecutionTimeout)
s.logger.Info("%s Starlark code mode initialized with binding level: %s, timeout: %v",
mcp.CodeModeLogPrefix, config.BindingLevel, config.ToolExecutionTimeout)
return s
}
// SetDependencies sets the dependencies required for code execution.
// This must be called after the MCPManager is created, as the dependencies
// include the ClientManager (which is the MCPManager itself).
func (s *StarlarkCodeMode) SetDependencies(deps *mcp.CodeModeDependencies) {
if deps != nil {
s.clientManager = deps.ClientManager
s.pluginPipelineProvider = deps.PluginPipelineProvider
s.releasePluginPipeline = deps.ReleasePluginPipeline
s.fetchNewRequestIDFunc = deps.FetchNewRequestIDFunc
s.oauth2Provider = deps.OAuth2Provider
}
}
// GetTools returns the code mode meta-tools for Starlark execution.
// These tools allow LLMs to discover, read, and execute code against MCP servers.
func (s *StarlarkCodeMode) GetTools() []schemas.ChatTool {
return []schemas.ChatTool{
s.createListToolFilesTool(),
s.createReadToolFileTool(),
s.createGetToolDocsTool(),
s.createExecuteToolCodeTool(),
}
}
// ExecuteTool handles a code mode tool call.
// It dispatches to the appropriate handler based on the tool name.
//
// Parameters:
// - ctx: Context for tool execution
// - toolCall: The tool call to execute
//
// Returns:
// - *schemas.ChatMessage: The tool response message
// - error: Any error that occurred during execution
func (s *StarlarkCodeMode) ExecuteTool(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
if toolCall.Function.Name == nil {
return nil, fmt.Errorf("tool call missing function name")
}
toolName := *toolCall.Function.Name
switch toolName {
case mcp.ToolTypeListToolFiles:
return s.handleListToolFiles(ctx, toolCall)
case mcp.ToolTypeReadToolFile:
return s.handleReadToolFile(ctx, toolCall)
case mcp.ToolTypeGetToolDocs:
return s.handleGetToolDocs(ctx, toolCall)
case mcp.ToolTypeExecuteToolCode:
return s.handleExecuteToolCode(ctx, toolCall)
default:
return nil, fmt.Errorf("unknown code mode tool: %s", toolName)
}
}
// IsCodeModeTool returns true if the given tool name is a code mode tool.
func (s *StarlarkCodeMode) IsCodeModeTool(toolName string) bool {
return mcp.IsCodeModeTool(toolName)
}
// GetBindingLevel returns the current code mode binding level.
func (s *StarlarkCodeMode) GetBindingLevel() schemas.CodeModeBindingLevel {
val := s.bindingLevel.Load()
if val == nil {
return schemas.CodeModeBindingLevelServer
}
return val.(schemas.CodeModeBindingLevel)
}
// UpdateConfig updates the code mode configuration atomically.
func (s *StarlarkCodeMode) UpdateConfig(config *mcp.CodeModeConfig) {
if config == nil {
return
}
if config.BindingLevel != "" {
s.bindingLevel.Store(config.BindingLevel)
}
if config.ToolExecutionTimeout > 0 {
s.toolExecutionTimeout.Store(config.ToolExecutionTimeout)
}
s.logger.Info("%s Starlark code mode configuration updated: binding level=%s, timeout=%v",
mcp.CodeModeLogPrefix, config.BindingLevel, config.ToolExecutionTimeout)
}
// getToolExecutionTimeout returns the current tool execution timeout.
func (s *StarlarkCodeMode) getToolExecutionTimeout() time.Duration {
val := s.toolExecutionTimeout.Load()
if val == nil {
return schemas.DefaultToolExecutionTimeout
}
return val.(time.Duration)
}

View File

@@ -0,0 +1,999 @@
//go:build !tinygo && !wasm
package starlark
import (
"context"
"strings"
"testing"
"time"
"github.com/bytedance/sonic"
codemcp "github.com/maximhq/bifrost/core/mcp"
"github.com/maximhq/bifrost/core/schemas"
"go.starlark.net/starlark"
"go.starlark.net/syntax"
)
type testClientManager struct {
clients map[string]*schemas.MCPClientState
tools map[string][]schemas.ChatTool
}
func (m *testClientManager) GetClientForTool(toolName string) *schemas.MCPClientState {
for clientName, tools := range m.tools {
for _, tool := range tools {
if tool.Function != nil && tool.Function.Name == toolName {
return m.clients[clientName]
}
}
}
return nil
}
func (m *testClientManager) GetClientByName(clientName string) *schemas.MCPClientState {
return m.clients[clientName]
}
func (m *testClientManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool {
return m.tools
}
func TestStarlarkToGo(t *testing.T) {
t.Run("Convert None", func(t *testing.T) {
result := starlarkToGo(starlark.None)
if result != nil {
t.Errorf("Expected nil, got %v", result)
}
})
t.Run("Convert Bool", func(t *testing.T) {
result := starlarkToGo(starlark.Bool(true))
if result != true {
t.Errorf("Expected true, got %v", result)
}
})
t.Run("Convert Int", func(t *testing.T) {
result := starlarkToGo(starlark.MakeInt(42))
if result != int64(42) {
t.Errorf("Expected 42, got %v", result)
}
})
t.Run("Convert Float", func(t *testing.T) {
result := starlarkToGo(starlark.Float(3.14))
if result != 3.14 {
t.Errorf("Expected 3.14, got %v", result)
}
})
t.Run("Convert String", func(t *testing.T) {
result := starlarkToGo(starlark.String("hello"))
if result != "hello" {
t.Errorf("Expected 'hello', got %v", result)
}
})
t.Run("Convert List", func(t *testing.T) {
list := starlark.NewList([]starlark.Value{
starlark.MakeInt(1),
starlark.MakeInt(2),
starlark.MakeInt(3),
})
result := starlarkToGo(list)
arr, ok := result.([]interface{})
if !ok {
t.Errorf("Expected []interface{}, got %T", result)
}
if len(arr) != 3 {
t.Errorf("Expected length 3, got %d", len(arr))
}
if arr[0] != int64(1) {
t.Errorf("Expected first element 1, got %v", arr[0])
}
})
t.Run("Convert Dict", func(t *testing.T) {
dict := starlark.NewDict(2)
dict.SetKey(starlark.String("key1"), starlark.String("value1"))
dict.SetKey(starlark.String("key2"), starlark.MakeInt(42))
result := starlarkToGo(dict)
m, ok := result.(map[string]interface{})
if !ok {
t.Errorf("Expected map[string]interface{}, got %T", result)
}
if m["key1"] != "value1" {
t.Errorf("Expected key1='value1', got %v", m["key1"])
}
if m["key2"] != int64(42) {
t.Errorf("Expected key2=42, got %v", m["key2"])
}
})
}
func TestGoToStarlark(t *testing.T) {
t.Run("Convert nil", func(t *testing.T) {
result := goToStarlark(nil)
if result != starlark.None {
t.Errorf("Expected None, got %v", result)
}
})
t.Run("Convert bool", func(t *testing.T) {
result := goToStarlark(true)
if result != starlark.Bool(true) {
t.Errorf("Expected True, got %v", result)
}
})
t.Run("Convert int", func(t *testing.T) {
result := goToStarlark(42)
expected := starlark.MakeInt(42)
if result.String() != expected.String() {
t.Errorf("Expected %v, got %v", expected, result)
}
})
t.Run("Convert float64", func(t *testing.T) {
result := goToStarlark(3.14)
if result != starlark.Float(3.14) {
t.Errorf("Expected 3.14, got %v", result)
}
})
t.Run("Convert string", func(t *testing.T) {
result := goToStarlark("hello")
if result != starlark.String("hello") {
t.Errorf("Expected 'hello', got %v", result)
}
})
t.Run("Convert slice", func(t *testing.T) {
result := goToStarlark([]interface{}{1, "two", 3.0})
list, ok := result.(*starlark.List)
if !ok {
t.Errorf("Expected *starlark.List, got %T", result)
}
if list.Len() != 3 {
t.Errorf("Expected length 3, got %d", list.Len())
}
})
t.Run("Convert map", func(t *testing.T) {
result := goToStarlark(map[string]interface{}{
"key1": "value1",
"key2": 42,
})
dict, ok := result.(*starlark.Dict)
if !ok {
t.Errorf("Expected *starlark.Dict, got %T", result)
}
val, found, _ := dict.Get(starlark.String("key1"))
if !found {
t.Errorf("Expected key1 to exist")
}
if val != starlark.String("value1") {
t.Errorf("Expected value1, got %v", val)
}
})
}
func TestGetCanonicalToolName(t *testing.T) {
if got := getCanonicalToolName("github", "github-SEARCH_REPOS"); got != "search_repos" {
t.Fatalf("expected canonical tool name search_repos, got %q", got)
}
if got := getCanonicalToolName("math", "math-123Add!"); got != "_123add" {
t.Fatalf("expected canonical tool name _123add, got %q", got)
}
}
func TestMatchesToolReferenceSupportsCanonicalAndLegacyNames(t *testing.T) {
clientName := "github"
originalToolName := "github-SEARCH_REPOS"
testCases := []string{
"search_repos",
"SEARCH_REPOS",
}
for _, toolRef := range testCases {
if !matchesToolReference(toolRef, clientName, originalToolName) {
t.Fatalf("expected %q to match %q", toolRef, originalToolName)
}
}
}
func TestHandleListToolFilesUsesCanonicalToolIdentifiers(t *testing.T) {
mode := NewStarlarkCodeMode(&codemcp.CodeModeConfig{
BindingLevel: schemas.CodeModeBindingLevelTool,
ToolExecutionTimeout: time.Second,
}, nil)
clientName := "github"
mode.clientManager = &testClientManager{
clients: map[string]*schemas.MCPClientState{
clientName: {
Name: clientName,
ExecutionConfig: &schemas.MCPClientConfig{
Name: clientName,
IsCodeModeClient: true,
},
},
},
tools: map[string][]schemas.ChatTool{
clientName: {
{
Function: &schemas.ChatToolFunction{
Name: "github-SEARCH_REPOS",
},
},
},
},
}
msg, err := mode.handleListToolFiles(context.Background(), schemas.ChatAssistantMessageToolCall{
ID: schemas.Ptr("tool-call-1"),
})
if err != nil {
t.Fatalf("handleListToolFiles returned error: %v", err)
}
if msg == nil || msg.Content == nil || msg.Content.ContentStr == nil {
t.Fatal("expected tool response content")
}
content := *msg.Content.ContentStr
if !strings.Contains(content, "search_repos.pyi") {
t.Fatalf("expected canonical tool file path in response, got:\n%s", content)
}
if strings.Contains(content, "SEARCH_REPOS.pyi") {
t.Fatalf("did not expect raw uppercase tool file path in response, got:\n%s", content)
}
if !strings.Contains(content, "readToolFile before executeToolCode") {
t.Fatalf("expected workflow guidance in response, got:\n%s", content)
}
}
func TestGeneratePythonErrorHints(t *testing.T) {
serverKeys := []string{"calculator", "weather"}
t.Run("Undefined variable hint", func(t *testing.T) {
hints := generatePythonErrorHints("name 'foo' is not defined", serverKeys)
if len(hints) == 0 {
t.Error("Expected hints, got none")
}
found := false
for _, hint := range hints {
if strings.Contains(hint, "Variable 'foo' is not defined.") {
found = true
break
}
}
if !found {
t.Errorf("Expected exact undefined variable hint for foo, got: %v", hints)
}
})
t.Run("Syntax error hint", func(t *testing.T) {
hints := generatePythonErrorHints("syntax error at line 5", serverKeys)
if len(hints) == 0 {
t.Error("Expected hints, got none")
}
found := false
for _, hint := range hints {
if containsAny(hint, "syntax", "indentation", "colon") {
found = true
break
}
}
if !found {
t.Error("Expected hint about syntax error")
}
})
t.Run("Attribute error hint", func(t *testing.T) {
hints := generatePythonErrorHints("'dict' object has no attribute 'foo'", serverKeys)
if len(hints) == 0 {
t.Error("Expected hints, got none")
}
found := false
for _, hint := range hints {
if containsAny(hint, "attribute", "brackets", "key") {
found = true
break
}
}
if !found {
t.Error("Expected hint about attribute access")
}
})
}
func containsAny(s string, substrs ...string) bool {
for _, sub := range substrs {
if containsIgnoreCase(s, sub) {
return true
}
}
return false
}
func containsIgnoreCase(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && (containsIgnoreCase(s[1:], substr) || (len(s) >= len(substr) && equalFold(s[:len(substr)], substr))))
}
func equalFold(a, b string) bool {
if len(a) != len(b) {
return false
}
for i := 0; i < len(a); i++ {
ca, cb := a[i], b[i]
if ca >= 'A' && ca <= 'Z' {
ca += 'a' - 'A'
}
if cb >= 'A' && cb <= 'Z' {
cb += 'a' - 'A'
}
if ca != cb {
return false
}
}
return true
}
func TestExtractResultFromResponsesMessage(t *testing.T) {
t.Run("Extract error from ResponsesMessage", func(t *testing.T) {
errorMsg := "Tool is not allowed by security policy: dangerous_tool"
msg := &schemas.ResponsesMessage{
ResponsesToolMessage: &schemas.ResponsesToolMessage{
Error: &errorMsg,
},
}
result, err := extractResultFromResponsesMessage(msg)
if err == nil {
t.Errorf("Expected error, got nil")
}
if err.Error() != errorMsg {
t.Errorf("Expected error message '%s', got '%s'", errorMsg, err.Error())
}
if result != nil {
t.Errorf("Expected nil result when error is present, got %v", result)
}
})
t.Run("Extract string output from ResponsesMessage", func(t *testing.T) {
outputStr := "success result"
msg := &schemas.ResponsesMessage{
ResponsesToolMessage: &schemas.ResponsesToolMessage{
Output: &schemas.ResponsesToolMessageOutputStruct{
ResponsesToolCallOutputStr: &outputStr,
},
},
}
result, err := extractResultFromResponsesMessage(msg)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if result != outputStr {
t.Errorf("Expected result '%s', got '%v'", outputStr, result)
}
})
t.Run("Extract JSON output from ResponsesMessage", func(t *testing.T) {
outputStr := `{"status": "success", "data": "test"}`
msg := &schemas.ResponsesMessage{
ResponsesToolMessage: &schemas.ResponsesToolMessage{
Output: &schemas.ResponsesToolMessageOutputStruct{
ResponsesToolCallOutputStr: &outputStr,
},
},
}
result, err := extractResultFromResponsesMessage(msg)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
resultMap, ok := result.(map[string]interface{})
if !ok {
t.Errorf("Expected map, got %T", result)
}
if resultMap["status"] != "success" {
t.Errorf("Expected status 'success', got '%v'", resultMap["status"])
}
})
t.Run("Extract from ResponsesFunctionToolCallOutputBlocks", func(t *testing.T) {
text1 := "First block"
text2 := "Second block"
msg := &schemas.ResponsesMessage{
ResponsesToolMessage: &schemas.ResponsesToolMessage{
Output: &schemas.ResponsesToolMessageOutputStruct{
ResponsesFunctionToolCallOutputBlocks: []schemas.ResponsesMessageContentBlock{
{Text: &text1},
{Text: &text2},
},
},
},
}
result, err := extractResultFromResponsesMessage(msg)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
expectedResult := "First block\nSecond block"
if result != expectedResult {
t.Errorf("Expected result '%s', got '%v'", expectedResult, result)
}
})
t.Run("Extract JSON from ResponsesFunctionToolCallOutputBlocks", func(t *testing.T) {
jsonText := `{"key": "value"}`
msg := &schemas.ResponsesMessage{
ResponsesToolMessage: &schemas.ResponsesToolMessage{
Output: &schemas.ResponsesToolMessageOutputStruct{
ResponsesFunctionToolCallOutputBlocks: []schemas.ResponsesMessageContentBlock{
{Text: &jsonText},
},
},
},
}
result, err := extractResultFromResponsesMessage(msg)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
resultMap, ok := result.(map[string]interface{})
if !ok {
t.Errorf("Expected map, got %T", result)
}
if resultMap["key"] != "value" {
t.Errorf("Expected key 'value', got '%v'", resultMap["key"])
}
})
t.Run("Handle nil message", func(t *testing.T) {
result, err := extractResultFromResponsesMessage(nil)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if result != nil {
t.Errorf("Expected nil result for nil message, got %v", result)
}
})
t.Run("Handle message without ResponsesToolMessage", func(t *testing.T) {
msg := &schemas.ResponsesMessage{}
result, err := extractResultFromResponsesMessage(msg)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if result != nil {
t.Errorf("Expected nil result for message without tool message, got %v", result)
}
})
t.Run("Handle empty error string (should not error)", func(t *testing.T) {
emptyError := ""
msg := &schemas.ResponsesMessage{
ResponsesToolMessage: &schemas.ResponsesToolMessage{
Error: &emptyError,
},
}
result, err := extractResultFromResponsesMessage(msg)
if err != nil {
t.Errorf("Expected no error for empty error string, got: %v", err)
}
if result != nil {
t.Errorf("Expected nil result for empty error string, got %v", result)
}
})
}
func TestExtractResultFromChatMessage(t *testing.T) {
t.Run("Extract string from ChatMessage", func(t *testing.T) {
content := "test result"
msg := &schemas.ChatMessage{
Content: &schemas.ChatMessageContent{
ContentStr: &content,
},
}
result := extractResultFromChatMessage(msg)
if result != content {
t.Errorf("Expected result '%s', got '%v'", content, result)
}
})
t.Run("Extract JSON from ChatMessage", func(t *testing.T) {
content := `{"status": "ok"}`
msg := &schemas.ChatMessage{
Content: &schemas.ChatMessageContent{
ContentStr: &content,
},
}
result := extractResultFromChatMessage(msg)
resultMap, ok := result.(map[string]interface{})
if !ok {
t.Errorf("Expected map, got %T", result)
}
if resultMap["status"] != "ok" {
t.Errorf("Expected status 'ok', got '%v'", resultMap["status"])
}
})
t.Run("Handle nil ChatMessage", func(t *testing.T) {
result := extractResultFromChatMessage(nil)
if result != nil {
t.Errorf("Expected nil result for nil message, got %v", result)
}
})
t.Run("Handle ChatMessage without Content", func(t *testing.T) {
msg := &schemas.ChatMessage{}
result := extractResultFromChatMessage(msg)
if result != nil {
t.Errorf("Expected nil result for message without content, got %v", result)
}
})
}
func TestFormatResultForLog(t *testing.T) {
t.Run("Format nil result", func(t *testing.T) {
result := formatResultForLog(nil)
if result != "null" {
t.Errorf("Expected 'null', got '%s'", result)
}
})
t.Run("Format string result", func(t *testing.T) {
result := formatResultForLog("test string")
if result != `"test string"` {
t.Errorf("Expected '\"test string\"', got '%s'", result)
}
})
t.Run("Format map result", func(t *testing.T) {
input := map[string]interface{}{"key": "value"}
result := formatResultForLog(input)
// Parse it back to verify it's valid JSON
var parsed map[string]interface{}
err := sonic.Unmarshal([]byte(result), &parsed)
if err != nil {
t.Errorf("Result is not valid JSON: %v", err)
}
if parsed["key"] != "value" {
t.Errorf("Expected key 'value', got '%v'", parsed["key"])
}
})
t.Run("Truncate long result", func(t *testing.T) {
longString := ""
for i := 0; i < 300; i++ {
longString += "a"
}
result := formatResultForLog(longString)
if len(result) > 200 {
// Should be truncated to around 200 chars (plus quotes and ellipsis)
t.Logf("Result length: %d (truncated as expected)", len(result))
}
})
}
// starlarkOpts returns the FileOptions used by the code mode executor.
// Kept in sync with executecode.go to test the same dialect configuration.
func starlarkOpts() *syntax.FileOptions {
return &syntax.FileOptions{
TopLevelControl: true,
While: true,
Set: true,
GlobalReassign: true,
Recursion: true,
}
}
// execStarlark is a test helper that executes Starlark code with our dialect options
// and returns the globals and any error.
func execStarlark(code string) (starlark.StringDict, error) {
thread := &starlark.Thread{Name: "test"}
return starlark.ExecFileOptions(starlarkOpts(), thread, "test.star", code, nil)
}
func TestStarlarkDialectOptions(t *testing.T) {
t.Run("Top-level for loop", func(t *testing.T) {
code := `
items = []
for i in range(3):
items.append(i)
result = items
`
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("Top-level for loop should work, got error: %v", err)
}
resultVal := globals["result"]
list, ok := resultVal.(*starlark.List)
if !ok {
t.Fatalf("Expected list, got %T", resultVal)
}
if list.Len() != 3 {
t.Errorf("Expected 3 items, got %d", list.Len())
}
})
t.Run("Top-level if statement", func(t *testing.T) {
code := `
x = 10
if x > 5:
result = "big"
else:
result = "small"
`
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("Top-level if should work, got error: %v", err)
}
if globals["result"] != starlark.String("big") {
t.Errorf("Expected 'big', got %v", globals["result"])
}
})
t.Run("Top-level while loop", func(t *testing.T) {
code := `
count = 0
while count < 5:
count += 1
result = count
`
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("Top-level while loop should work, got error: %v", err)
}
resultVal := globals["result"]
if resultVal.String() != "5" {
t.Errorf("Expected 5, got %v", resultVal)
}
})
t.Run("While loop inside function", func(t *testing.T) {
code := `
def countdown(n):
items = []
while n > 0:
items.append(n)
n -= 1
return items
result = countdown(3)
`
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("While in function should work, got error: %v", err)
}
list := globals["result"].(*starlark.List)
if list.Len() != 3 {
t.Errorf("Expected 3 items, got %d", list.Len())
}
})
t.Run("set() builtin", func(t *testing.T) {
code := `
s = set([1, 2, 3, 2, 1])
result = len(s)
`
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("set() should work, got error: %v", err)
}
if globals["result"].String() != "3" {
t.Errorf("Expected 3 unique items, got %v", globals["result"])
}
})
t.Run("Global variable reassignment", func(t *testing.T) {
code := `
x = 1
x = x + 1
x = x * 3
result = x
`
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("Global reassignment should work, got error: %v", err)
}
if globals["result"].String() != "6" {
t.Errorf("Expected 6, got %v", globals["result"])
}
})
t.Run("Recursive function", func(t *testing.T) {
code := `
def factorial(n):
if n <= 1:
return 1
return n * factorial(n - 1)
result = factorial(5)
`
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("Recursion should work, got error: %v", err)
}
if globals["result"].String() != "120" {
t.Errorf("Expected 120, got %v", globals["result"])
}
})
}
func TestStarlarkStringEscapePreservation(t *testing.T) {
t.Run("Backslash-n in string literal preserved", func(t *testing.T) {
// Simulate what happens after JSON deserialization:
// Model writes: {"code": "msg = \"hello\\nworld\""}
// sonic.Unmarshal produces: msg = "hello\nworld" (where \n is two chars: \ + n)
// Starlark should interpret \n as newline escape inside the string
code := "msg = \"hello\\nworld\"\nresult = msg"
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("String with \\n escape should work, got error: %v", err)
}
resultStr := string(globals["result"].(starlark.String))
if resultStr != "hello\nworld" {
t.Errorf("Expected 'hello<newline>world', got %q", resultStr)
}
})
t.Run("Multiple escape sequences in strings", func(t *testing.T) {
code := "msg = \"col1\\tcol2\\nrow1\\trow2\"\nresult = msg"
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("String with multiple escapes should work, got error: %v", err)
}
resultStr := string(globals["result"].(starlark.String))
if resultStr != "col1\tcol2\nrow1\trow2" {
t.Errorf("Expected tab/newline escapes, got %q", resultStr)
}
})
t.Run("Newline join pattern", func(t *testing.T) {
// This is the exact pattern that failed 7 times in benchmarks
code := `
def main():
lines = ["line1", "line2", "line3"]
content = "\n".join(lines)
return content
result = main()
`
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("Newline join pattern should work, got error: %v", err)
}
resultStr := string(globals["result"].(starlark.String))
if resultStr != "line1\nline2\nline3" {
t.Errorf("Expected joined lines, got %q", resultStr)
}
})
t.Run("chr() for newline", func(t *testing.T) {
code := `
nl = chr(10)
result = "hello" + nl + "world"
`
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("chr(10) should work, got error: %v", err)
}
resultStr := string(globals["result"].(starlark.String))
if resultStr != "hello\nworld" {
t.Errorf("Expected 'hello<newline>world', got %q", resultStr)
}
})
t.Run("Triple-quoted strings", func(t *testing.T) {
code := "result = \"\"\"line1\nline2\nline3\"\"\""
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("Triple-quoted string should work, got error: %v", err)
}
resultStr := string(globals["result"].(starlark.String))
if resultStr != "line1\nline2\nline3" {
t.Errorf("Expected multiline string, got %q", resultStr)
}
})
t.Run("Raw string preserves backslash", func(t *testing.T) {
code := "result = r\"hello\\nworld\""
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("Raw string should work, got error: %v", err)
}
resultStr := string(globals["result"].(starlark.String))
// Raw string: \n stays as two characters \ and n
if resultStr != "hello\\nworld" {
t.Errorf("Expected literal backslash-n, got %q", resultStr)
}
})
t.Run("JSON deserialization then Starlark execution", func(t *testing.T) {
// End-to-end: simulate the exact flow from model JSON → sonic.Unmarshal → Starlark
jsonArgs := `{"code": "lines = [\"a\", \"b\", \"c\"]\nresult = \"\\n\".join(lines)"}`
var arguments map[string]interface{}
err := sonic.Unmarshal([]byte(jsonArgs), &arguments)
if err != nil {
t.Fatalf("JSON unmarshal failed: %v", err)
}
code := arguments["code"].(string)
globals, starlarkErr := execStarlark(code)
if starlarkErr != nil {
t.Fatalf("Starlark execution failed: %v", starlarkErr)
}
resultStr := string(globals["result"].(starlark.String))
if resultStr != "a\nb\nc" {
t.Errorf("Expected 'a<newline>b<newline>c', got %q", resultStr)
}
})
}
func TestStarlarkUnsupportedFeatures(t *testing.T) {
t.Run("try/except rejected", func(t *testing.T) {
code := `
def main():
try:
x = 1
except:
x = 0
result = main()
`
_, err := execStarlark(code)
if err == nil {
t.Fatal("try/except should be rejected by Starlark")
}
if !strings.Contains(err.Error(), "got try") {
t.Errorf("Expected 'got try' in error, got: %v", err)
}
})
t.Run("raise rejected", func(t *testing.T) {
code := `raise ValueError("test")`
_, err := execStarlark(code)
if err == nil {
t.Fatal("raise should be rejected by Starlark")
}
})
t.Run("class rejected", func(t *testing.T) {
code := `
class Foo:
pass
`
_, err := execStarlark(code)
if err == nil {
t.Fatal("class should be rejected by Starlark")
}
})
t.Run("import rejected", func(t *testing.T) {
code := `import json`
_, err := execStarlark(code)
if err == nil {
t.Fatal("import should be rejected by Starlark")
}
})
}
func TestGeneratePythonErrorHintsNewCases(t *testing.T) {
serverKeys := []string{"Github", "SqLite"}
t.Run("try/except hint", func(t *testing.T) {
hints := generatePythonErrorHints("code.star:3:9: got try, want primary expression", serverKeys)
if len(hints) == 0 {
t.Fatal("Expected hints for try/except error")
}
found := false
for _, hint := range hints {
if containsAny(hint, "try/except", "exception handling") {
found = true
break
}
}
if !found {
t.Errorf("Expected hint about try/except not being supported, got: %v", hints)
}
})
t.Run("except hint", func(t *testing.T) {
hints := generatePythonErrorHints("code.star:5:9: got except, want primary expression", serverKeys)
if len(hints) == 0 {
t.Fatal("Expected hints for except error")
}
found := false
for _, hint := range hints {
if containsAny(hint, "try/except", "exception handling") {
found = true
break
}
}
if !found {
t.Errorf("Expected hint about exception handling, got: %v", hints)
}
})
t.Run("finally hint", func(t *testing.T) {
hints := generatePythonErrorHints("code.star:7:9: got finally, want primary expression", serverKeys)
if len(hints) == 0 {
t.Fatal("Expected hints for finally error")
}
found := false
for _, hint := range hints {
if containsAny(hint, "try/except", "exception handling") {
found = true
break
}
}
if !found {
t.Errorf("Expected hint about exception handling, got: %v", hints)
}
})
t.Run("raise hint", func(t *testing.T) {
hints := generatePythonErrorHints("code.star:2:1: got raise, want primary expression", serverKeys)
if len(hints) == 0 {
t.Fatal("Expected hints for raise error")
}
found := false
for _, hint := range hints {
if containsAny(hint, "try/except", "exception handling") {
found = true
break
}
}
if !found {
t.Errorf("Expected hint about exception handling, got: %v", hints)
}
})
t.Run("Undefined variable includes scope hint", func(t *testing.T) {
hints := generatePythonErrorHints("code.star:3:17: undefined: commits_n8n", serverKeys)
if len(hints) == 0 {
t.Fatal("Expected hints for undefined variable")
}
foundVar := false
foundScope := false
for _, hint := range hints {
if strings.Contains(hint, "Variable 'commits_n8n' is not defined.") {
foundVar = true
}
if containsAny(hint, "fresh scope", "persist") {
foundScope = true
}
}
if !foundVar {
t.Errorf("Expected exact undefined variable hint for commits_n8n, got: %v", hints)
}
if !foundScope {
t.Errorf("Expected scope persistence hint, got: %v", hints)
}
})
}

View File

@@ -0,0 +1,443 @@
//go:build !tinygo && !wasm
package starlark
import (
"encoding/json"
"fmt"
"regexp"
"strings"
"unicode"
"github.com/bytedance/sonic"
"github.com/mark3labs/mcp-go/mcp"
"github.com/maximhq/bifrost/core/schemas"
"go.starlark.net/starlark"
"go.starlark.net/starlarkstruct"
)
// starlarkToGo converts a Starlark value to a Go value
func starlarkToGo(v starlark.Value) interface{} {
switch val := v.(type) {
case starlark.NoneType:
return nil
case starlark.Bool:
return bool(val)
case starlark.Int:
if i, ok := val.Int64(); ok {
return i
}
if i, ok := val.Uint64(); ok {
return i
}
return val.String()
case starlark.Float:
return float64(val)
case starlark.String:
return string(val)
case *starlark.List:
result := make([]interface{}, val.Len())
for i := 0; i < val.Len(); i++ {
result[i] = starlarkToGo(val.Index(i))
}
return result
case starlark.Tuple:
result := make([]interface{}, len(val))
for i, item := range val {
result[i] = starlarkToGo(item)
}
return result
case *starlark.Dict:
result := make(map[string]interface{})
for _, item := range val.Items() {
if keyStr, ok := item[0].(starlark.String); ok {
result[string(keyStr)] = starlarkToGo(item[1])
} else {
// Use string representation for non-string keys
result[item[0].String()] = starlarkToGo(item[1])
}
}
return result
case *starlarkstruct.Struct:
result := make(map[string]interface{})
for _, name := range val.AttrNames() {
if attrVal, err := val.Attr(name); err == nil {
result[name] = starlarkToGo(attrVal)
}
}
return result
default:
return val.String()
}
}
// goToStarlark converts a Go value to a Starlark value
func goToStarlark(v interface{}) starlark.Value {
if v == nil {
return starlark.None
}
switch val := v.(type) {
case bool:
return starlark.Bool(val)
case int:
return starlark.MakeInt(val)
case int64:
return starlark.MakeInt64(val)
case uint64:
return starlark.MakeUint64(val)
case float64:
return starlark.Float(val)
case string:
return starlark.String(val)
case []interface{}:
items := make([]starlark.Value, len(val))
for i, item := range val {
items[i] = goToStarlark(item)
}
return starlark.NewList(items)
case map[string]interface{}:
dict := starlark.NewDict(len(val))
for k, v := range val {
dict.SetKey(starlark.String(k), goToStarlark(v))
}
return dict
default:
// Try to marshal to JSON and parse as a generic structure
if jsonBytes, err := schemas.MarshalSorted(val); err == nil {
var generic interface{}
if schemas.Unmarshal(jsonBytes, &generic) == nil {
return goToStarlark(generic)
}
}
return starlark.String(fmt.Sprintf("%v", val))
}
}
// extractResultFromChatMessage extracts the result from a chat message and parses it as JSON if possible.
func extractResultFromChatMessage(msg *schemas.ChatMessage) interface{} {
if msg == nil || msg.Content == nil || msg.Content.ContentStr == nil {
return nil
}
rawResult := *msg.Content.ContentStr
var finalResult interface{}
if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil {
return rawResult
}
return finalResult
}
// extractResultFromResponsesMessage extracts the result or error from a ResponsesMessage.
func extractResultFromResponsesMessage(msg *schemas.ResponsesMessage) (interface{}, error) {
if msg == nil {
return nil, nil
}
if msg.ResponsesToolMessage != nil {
if msg.ResponsesToolMessage.Error != nil && *msg.ResponsesToolMessage.Error != "" {
return nil, fmt.Errorf("%s", *msg.ResponsesToolMessage.Error)
}
if msg.ResponsesToolMessage.Output != nil {
if msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil {
rawResult := *msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr
var finalResult interface{}
if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil {
return rawResult, nil
}
return finalResult, nil
}
if len(msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks) > 0 {
var textParts []string
for _, block := range msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks {
if block.Text != nil {
textParts = append(textParts, *block.Text)
}
}
if len(textParts) > 0 {
result := strings.Join(textParts, "\n")
var finalResult interface{}
if err := sonic.Unmarshal([]byte(result), &finalResult); err != nil {
return result, nil
}
return finalResult, nil
}
}
}
}
return nil, nil
}
// formatResultForLog formats a result value for logging purposes.
func formatResultForLog(result interface{}) string {
var resultStr string
if result == nil {
resultStr = "null"
} else if resultBytes, err := schemas.MarshalSorted(result); err == nil {
resultStr = string(resultBytes)
} else {
resultStr = fmt.Sprintf("%v", result)
}
return resultStr
}
// generatePythonErrorHints generates helpful hints for Python/Starlark errors.
func generatePythonErrorHints(errorMessage string, serverKeys []string) []string {
hints := []string{}
if strings.Contains(errorMessage, "got try") || strings.Contains(errorMessage, "got except") ||
strings.Contains(errorMessage, "got finally") || strings.Contains(errorMessage, "got raise") {
hints = append(hints, "Starlark does NOT support try/except/finally/raise — there is no exception handling.")
hints = append(hints, "Instead, check return values for errors:")
hints = append(hints, " result = server.tool(param=\"value\")")
hints = append(hints, " if result == None or (type(result) == \"dict\" and \"error\" in result):")
hints = append(hints, " print(\"Error:\", result)")
} else if strings.Contains(errorMessage, "undefined") || strings.Contains(errorMessage, "not defined") {
var undefinedVar string
if match := regexp.MustCompile(`name ['"]([^'"]+)['"] is not defined`).FindStringSubmatch(errorMessage); len(match) > 1 {
undefinedVar = match[1]
} else if match := regexp.MustCompile(`undefined:\s*([A-Za-z_][A-Za-z0-9_]*)`).FindStringSubmatch(errorMessage); len(match) > 1 {
undefinedVar = match[1]
} else if match := regexp.MustCompile(`([A-Za-z_][A-Za-z0-9_]*)[^A-Za-z0-9_]+(?:undefined|not defined)`).FindStringSubmatch(errorMessage); len(match) > 1 {
undefinedVar = match[1]
}
if undefinedVar != "" {
hints = append(hints, fmt.Sprintf("Variable '%s' is not defined.", undefinedVar))
hints = append(hints, "Note: Each executeToolCode call runs in a fresh scope — no variables persist between calls.")
if len(serverKeys) > 0 {
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
hints = append(hints, "Access tools using: server_name.tool_name(param=\"value\")")
}
}
} else if strings.Contains(errorMessage, "not within a function") {
hints = append(hints, "Starlark requires for/if/while statements to be inside functions at the top level.")
hints = append(hints, "Wrap your code in a function, then call it:")
hints = append(hints, " def fetch_all():")
hints = append(hints, " results = []")
hints = append(hints, " for id in ids:")
hints = append(hints, " results.append(server.get(id=id))")
hints = append(hints, " return results")
hints = append(hints, " result = fetch_all()")
} else if strings.Contains(errorMessage, "syntax error") {
hints = append(hints, "Python syntax error detected.")
hints = append(hints, "Check for proper indentation (use spaces, not tabs).")
hints = append(hints, "Ensure colons after if/for/def statements.")
hints = append(hints, "Check for matching parentheses and brackets.")
} else if strings.Contains(errorMessage, "has no") && strings.Contains(errorMessage, "attribute") {
hints = append(hints, "You're trying to access an attribute that doesn't exist.")
hints = append(hints, "Use dict access syntax: result[\"key\"] instead of result.key")
hints = append(hints, "Use print(result) to see the actual structure.")
if len(serverKeys) > 0 {
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
}
} else if strings.Contains(errorMessage, "not callable") {
hints = append(hints, "You're trying to call something that is not a function.")
hints = append(hints, "Ensure you're using the correct tool name.")
if len(serverKeys) > 0 {
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
}
hints = append(hints, "Use readToolFile to see available tools for a server.")
} else if strings.Contains(errorMessage, "key") && strings.Contains(errorMessage, "not found") {
hints = append(hints, "Dictionary key not found.")
hints = append(hints, "Use print() to inspect the dict structure before accessing keys.")
hints = append(hints, "Use .get(\"key\", default) for safe access.")
} else {
hints = append(hints, "Check the error message above for details.")
if len(serverKeys) > 0 {
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
}
hints = append(hints, "Use: result = server_name.tool_name(param=\"value\")")
hints = append(hints, "Access dict values with brackets: result[\"key\"]")
}
return hints
}
// extractTextFromMCPResponse extracts text content from an MCP tool response.
func extractTextFromMCPResponse(toolResponse *mcp.CallToolResult, toolName string) string {
if toolResponse == nil {
return fmt.Sprintf("MCP tool '%s' executed successfully", toolName)
}
var result strings.Builder
for _, contentBlock := range toolResponse.Content {
// Handle typed content
switch content := contentBlock.(type) {
case mcp.TextContent:
result.WriteString(content.Text)
case mcp.ImageContent:
result.WriteString(fmt.Sprintf("[Image Response: %s, MIME: %s]\n", content.Data, content.MIMEType))
case mcp.AudioContent:
result.WriteString(fmt.Sprintf("[Audio Response: %s, MIME: %s]\n", content.Data, content.MIMEType))
case mcp.EmbeddedResource:
result.WriteString(fmt.Sprintf("[Embedded Resource Response: %s]\n", content.Type))
default:
// Fallback: try to extract from map structure
if jsonBytes, err := schemas.MarshalSorted(contentBlock); err == nil {
var contentMap map[string]interface{}
if json.Unmarshal(jsonBytes, &contentMap) == nil {
if text, ok := contentMap["text"].(string); ok {
result.WriteString(fmt.Sprintf("[Text Response: %s]\n", text))
continue
}
}
// Final fallback: serialize as JSON
result.WriteString(string(jsonBytes))
}
}
}
if result.Len() > 0 {
return strings.TrimSpace(result.String())
}
return fmt.Sprintf("MCP tool '%s' executed successfully", toolName)
}
// createToolResponseMessage creates a tool response message with the execution result.
func createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, responseText string) *schemas.ChatMessage {
return &schemas.ChatMessage{
Role: schemas.ChatMessageRoleTool,
Content: &schemas.ChatMessageContent{
ContentStr: &responseText,
},
ChatToolMessage: &schemas.ChatToolMessage{
ToolCallID: toolCall.ID,
},
}
}
// parseToolName normalizes a raw tool name into a Starlark-compatible identifier.
func parseToolName(toolName string) string {
if toolName == "" {
return ""
}
var result strings.Builder
runes := []rune(toolName)
// Process first character - must be letter, underscore, or dollar sign
if len(runes) > 0 {
first := runes[0]
if unicode.IsLetter(first) || first == '_' || first == '$' {
result.WriteRune(unicode.ToLower(first))
} else {
// If first char is invalid, prefix with underscore
result.WriteRune('_')
if unicode.IsDigit(first) {
result.WriteRune(first)
}
}
}
// Process remaining characters
for i := 1; i < len(runes); i++ {
r := runes[i]
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' {
result.WriteRune(unicode.ToLower(r))
} else if unicode.IsSpace(r) || r == '-' {
// Replace spaces and hyphens with single underscore
// Avoid consecutive underscores
if result.Len() > 0 && result.String()[result.Len()-1] != '_' {
result.WriteRune('_')
}
}
// Skip other invalid characters
}
parsed := result.String()
// Remove trailing underscores
parsed = strings.TrimRight(parsed, "_")
// Ensure we have at least one character
if parsed == "" {
return "tool"
}
return parsed
}
// getCanonicalToolName returns the exact callable tool identifier exposed in Starlark.
func getCanonicalToolName(clientName, originalToolName string) string {
return parseToolName(stripClientPrefix(originalToolName, clientName))
}
// getCompatibilityToolAlias returns the case-preserving alias derived from the raw tool name.
// This is used as a compatibility alias when the raw name is still a valid Starlark identifier.
func getCompatibilityToolAlias(clientName, originalToolName string) string {
return strings.ReplaceAll(stripClientPrefix(originalToolName, clientName), "-", "_")
}
// matchesToolReference reports whether the requested tool name matches any supported identifier form.
// We accept the canonical callable name plus legacy display forms for backward compatibility.
func matchesToolReference(requestedToolName, clientName, originalToolName string) bool {
requested := strings.ToLower(requestedToolName)
if requested == "" {
return false
}
candidates := []string{
getCanonicalToolName(clientName, originalToolName),
getCompatibilityToolAlias(clientName, originalToolName),
stripClientPrefix(originalToolName, clientName),
}
for _, candidate := range candidates {
if candidate != "" && requested == strings.ToLower(candidate) {
return true
}
}
return false
}
// isValidStarlarkIdentifier reports whether name can be used directly in Starlark code.
func isValidStarlarkIdentifier(name string) bool {
if name == "" {
return false
}
runes := []rune(name)
first := runes[0]
if !unicode.IsLetter(first) && first != '_' && first != '$' {
return false
}
for _, r := range runes[1:] {
if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' && r != '$' {
return false
}
}
return true
}
// validateNormalizedToolName validates a normalized tool name to prevent path traversal.
func validateNormalizedToolName(normalizedName string) error {
if normalizedName == "" {
return fmt.Errorf("tool name cannot be empty after normalization")
}
if strings.Contains(normalizedName, "/") {
return fmt.Errorf("tool name cannot contain '/' (path separator) after normalization: %s", normalizedName)
}
if strings.Contains(normalizedName, "..") {
return fmt.Errorf("tool name cannot contain '..' (path traversal) after normalization: %s", normalizedName)
}
return nil
}
// stripClientPrefix removes the client name prefix from a tool name.
func stripClientPrefix(prefixedToolName, clientName string) string {
prefix := clientName + "-"
if strings.HasPrefix(prefixedToolName, prefix) {
return strings.TrimPrefix(prefixedToolName, prefix)
}
// If prefix doesn't match, return as-is (shouldn't happen, but be safe)
return prefixedToolName
}