444 lines
15 KiB
Go
444 lines
15 KiB
Go
//go:build !tinygo && !wasm
|
|
|
|
package starlark
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
|
|
codemcp "github.com/maximhq/bifrost/core/mcp"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
)
|
|
|
|
// createReadToolFileTool creates the readToolFile tool definition for code mode.
|
|
// This tool allows reading virtual .pyi stub files for specific MCP servers/tools,
|
|
// generating Python type stubs from the server's tool schemas.
|
|
func (s *StarlarkCodeMode) createReadToolFileTool() schemas.ChatTool {
|
|
bindingLevel := s.GetBindingLevel()
|
|
|
|
var fileNameDescription, toolDescription string
|
|
|
|
if bindingLevel == schemas.CodeModeBindingLevelServer {
|
|
fileNameDescription = "The virtual filename from listToolFiles in format: servers/<serverName>.pyi (e.g., 'servers/calculator.pyi')"
|
|
toolDescription = "Reads a virtual .pyi stub file for a specific MCP server, returning compact Python function signatures " +
|
|
"for all tools available on that server. The fileName should be in format servers/<serverName>.pyi as listed by listToolFiles. " +
|
|
"The function performs case-insensitive matching and removes the .pyi extension. " +
|
|
"This is the authoritative source for the exact callable tool names and parameters to use in executeToolCode. " +
|
|
"Each tool can be accessed in code via: serverName.tool_name(param=value). " +
|
|
"If the compact signature is not enough to understand a tool, use getToolDocs for detailed documentation. " +
|
|
"Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
|
|
"IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " +
|
|
"do NOT call this tool again with startLine/endLine - you already have the complete file."
|
|
} else {
|
|
fileNameDescription = "The virtual filename from listToolFiles in format: servers/<serverName>/<toolName>.pyi (e.g., 'servers/calculator/add.pyi')"
|
|
toolDescription = "Reads a virtual .pyi stub file for a specific tool, returning its compact Python function signature. " +
|
|
"The fileName should be in format servers/<serverName>/<toolName>.pyi as listed by listToolFiles. " +
|
|
"The function performs case-insensitive matching and removes the .pyi extension. " +
|
|
"This is the authoritative source for the exact callable tool name and arguments to use in executeToolCode. " +
|
|
"The tool can be accessed in code via: serverName.tool_name(param=value) using the def name shown in the file. " +
|
|
"If the compact signature is not enough to understand the tool, use getToolDocs for detailed documentation. " +
|
|
"Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
|
|
"IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " +
|
|
"do NOT call this tool again with startLine/endLine - you already have the complete file."
|
|
}
|
|
|
|
readToolFileProps := schemas.NewOrderedMapFromPairs(
|
|
schemas.KV("fileName", map[string]interface{}{
|
|
"type": "string",
|
|
"description": fileNameDescription,
|
|
}),
|
|
schemas.KV("startLine", map[string]interface{}{
|
|
"type": "number",
|
|
"description": "Optional 1-based starting line number for partial file read. Usually not needed - omit to read the entire file. Files are typically small (under 50 lines).",
|
|
}),
|
|
schemas.KV("endLine", map[string]interface{}{
|
|
"type": "number",
|
|
"description": "Optional 1-based ending line number for partial file read. Usually not needed - omit to read the entire file. Will be clamped to actual file size if too large.",
|
|
}),
|
|
)
|
|
return schemas.ChatTool{
|
|
Type: schemas.ChatToolTypeFunction,
|
|
Function: &schemas.ChatToolFunction{
|
|
Name: codemcp.ToolTypeReadToolFile,
|
|
Description: schemas.Ptr(toolDescription),
|
|
Parameters: &schemas.ToolFunctionParameters{
|
|
Type: "object",
|
|
Properties: readToolFileProps,
|
|
Required: []string{"fileName"},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
// handleReadToolFile handles the readToolFile tool call.
|
|
func (s *StarlarkCodeMode) handleReadToolFile(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
|
|
// Parse tool arguments
|
|
var arguments map[string]interface{}
|
|
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
|
|
return nil, fmt.Errorf("failed to parse tool arguments: %v", err)
|
|
}
|
|
|
|
fileName, ok := arguments["fileName"].(string)
|
|
if !ok || fileName == "" {
|
|
return nil, fmt.Errorf("fileName parameter is required and must be a string")
|
|
}
|
|
|
|
// Parse the file path to extract server name and optional tool name
|
|
serverName, toolName, isToolLevel := parseVFSFilePath(fileName)
|
|
|
|
// Get available tools per client
|
|
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
|
|
|
|
// Find matching client
|
|
var matchedClientName string
|
|
var matchedTools []schemas.ChatTool
|
|
matchCount := 0
|
|
|
|
for clientName, tools := range availableToolsPerClient {
|
|
client := s.clientManager.GetClientByName(clientName)
|
|
if client == nil {
|
|
s.logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName)
|
|
continue
|
|
}
|
|
if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 {
|
|
continue
|
|
}
|
|
|
|
clientNameLower := strings.ToLower(clientName)
|
|
serverNameLower := strings.ToLower(serverName)
|
|
|
|
if clientNameLower == serverNameLower {
|
|
matchCount++
|
|
if matchCount > 1 {
|
|
// Multiple matches found
|
|
errorMsg := fmt.Sprintf("Multiple servers match filename '%s':\n", fileName)
|
|
for name := range availableToolsPerClient {
|
|
if strings.ToLower(name) == serverNameLower {
|
|
errorMsg += fmt.Sprintf(" - %s\n", name)
|
|
}
|
|
}
|
|
errorMsg += "\nPlease use a more specific filename. Use the exact display name from listToolFiles to avoid ambiguity."
|
|
return createToolResponseMessage(toolCall, errorMsg), nil
|
|
}
|
|
|
|
matchedClientName = clientName
|
|
|
|
if isToolLevel {
|
|
// Tool-level: filter to specific tool
|
|
var foundTool *schemas.ChatTool
|
|
for i, tool := range tools {
|
|
if tool.Function != nil {
|
|
if matchesToolReference(toolName, clientName, tool.Function.Name) {
|
|
foundTool = &tools[i]
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if foundTool == nil {
|
|
availableTools := make([]string, 0)
|
|
for _, tool := range tools {
|
|
if tool.Function != nil {
|
|
availableTools = append(availableTools, getCanonicalToolName(clientName, tool.Function.Name))
|
|
}
|
|
}
|
|
errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. Available tools in this server are:\n", toolName, clientName)
|
|
for _, t := range availableTools {
|
|
errorMsg += fmt.Sprintf(" - servers/%s/%s.pyi\n", clientName, t)
|
|
}
|
|
return createToolResponseMessage(toolCall, errorMsg), nil
|
|
}
|
|
|
|
matchedTools = []schemas.ChatTool{*foundTool}
|
|
} else {
|
|
// Server-level: use all tools
|
|
matchedTools = tools
|
|
}
|
|
}
|
|
}
|
|
|
|
if matchedClientName == "" {
|
|
// Build helpful error message with available files
|
|
bindingLevel := s.GetBindingLevel()
|
|
var availableFiles []string
|
|
|
|
for name := range availableToolsPerClient {
|
|
if bindingLevel == schemas.CodeModeBindingLevelServer {
|
|
availableFiles = append(availableFiles, fmt.Sprintf("servers/%s.pyi", name))
|
|
} else {
|
|
client := s.clientManager.GetClientByName(name)
|
|
if client != nil && client.ExecutionConfig.IsCodeModeClient {
|
|
if tools, ok := availableToolsPerClient[name]; ok {
|
|
for _, tool := range tools {
|
|
if tool.Function != nil {
|
|
availableFiles = append(availableFiles, fmt.Sprintf("servers/%s/%s.pyi", name, getCanonicalToolName(name, tool.Function.Name)))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
errorMsg := fmt.Sprintf("No server found matching '%s'. Available virtual files are:\n", serverName)
|
|
for _, f := range availableFiles {
|
|
errorMsg += fmt.Sprintf(" - %s\n", f)
|
|
}
|
|
return createToolResponseMessage(toolCall, errorMsg), nil
|
|
}
|
|
|
|
// Generate compact Python signatures
|
|
fileContent := generateCompactSignatures(matchedClientName, matchedTools, isToolLevel)
|
|
lines := strings.Split(fileContent, "\n")
|
|
totalLines := len(lines)
|
|
|
|
// Prepend total lines info so LLM knows the file size upfront
|
|
fileContent = fmt.Sprintf("# Total lines: %d (this is the complete file, no need to paginate)\n%s", totalLines+1, fileContent)
|
|
// Recalculate lines after prepending
|
|
lines = strings.Split(fileContent, "\n")
|
|
totalLines = len(lines)
|
|
|
|
// Handle line slicing if provided
|
|
var startLine, endLine *int
|
|
if sl, ok := arguments["startLine"].(float64); ok {
|
|
slInt := int(sl)
|
|
startLine = &slInt
|
|
}
|
|
if el, ok := arguments["endLine"].(float64); ok {
|
|
elInt := int(el)
|
|
endLine = &elInt
|
|
}
|
|
|
|
if startLine != nil || endLine != nil {
|
|
start := 1
|
|
if startLine != nil {
|
|
start = *startLine
|
|
}
|
|
end := totalLines
|
|
if endLine != nil {
|
|
end = *endLine
|
|
}
|
|
|
|
// Clamp values to valid range instead of erroring
|
|
// This handles cases where LLM requests more lines than exist
|
|
if start < 1 {
|
|
start = 1
|
|
}
|
|
if start > totalLines {
|
|
start = totalLines
|
|
}
|
|
if end < 1 {
|
|
end = 1
|
|
}
|
|
if end > totalLines {
|
|
end = totalLines
|
|
}
|
|
if start > end {
|
|
// If start > end after clamping, just return the start line
|
|
end = start
|
|
}
|
|
|
|
// Slice lines (convert to 0-based indexing)
|
|
selectedLines := lines[start-1 : end]
|
|
fileContent = strings.Join(selectedLines, "\n")
|
|
}
|
|
|
|
return createToolResponseMessage(toolCall, fileContent), nil
|
|
}
|
|
|
|
// parseVFSFilePath parses a VFS file path and extracts the server name and optional tool name.
|
|
func parseVFSFilePath(fileName string) (serverName, toolName string, isToolLevel bool) {
|
|
// Remove .pyi extension
|
|
basePath := strings.TrimSuffix(fileName, ".pyi")
|
|
|
|
// Remove "servers/" prefix if present
|
|
basePath = strings.TrimPrefix(basePath, "servers/")
|
|
|
|
// Defensive validation: reject paths with path traversal attempts
|
|
if strings.Contains(basePath, "..") {
|
|
// Return empty to indicate invalid path
|
|
return "", "", false
|
|
}
|
|
|
|
// Check for path separator
|
|
parts := strings.Split(basePath, "/")
|
|
if len(parts) == 2 {
|
|
// Tool-level: "serverName/toolName"
|
|
// Validate that tool name doesn't contain additional path separators or traversal
|
|
if parts[1] == "" || strings.Contains(parts[1], "/") || strings.Contains(parts[1], "..") {
|
|
// Invalid tool name, treat as server-level
|
|
return parts[0], "", false
|
|
}
|
|
return parts[0], parts[1], true
|
|
}
|
|
// Server-level: "serverName"
|
|
// Validate server name doesn't contain path separators or traversal
|
|
if strings.Contains(basePath, "/") || strings.Contains(basePath, "..") {
|
|
// Invalid path
|
|
return "", "", false
|
|
}
|
|
return basePath, "", false
|
|
}
|
|
|
|
// generateCompactSignatures generates compact Python function signatures for tools.
|
|
func generateCompactSignatures(clientName string, tools []schemas.ChatTool, isToolLevel bool) string {
|
|
var sb strings.Builder
|
|
|
|
// Minimal header
|
|
if isToolLevel && len(tools) == 1 && tools[0].Function != nil {
|
|
toolName := getCanonicalToolName(clientName, tools[0].Function.Name)
|
|
sb.WriteString(fmt.Sprintf("# %s.%s tool\n", clientName, toolName))
|
|
} else {
|
|
sb.WriteString(fmt.Sprintf("# %s server tools\n", clientName))
|
|
}
|
|
sb.WriteString(fmt.Sprintf("# Usage: %s.tool_name(param=value)\n", clientName))
|
|
sb.WriteString("# The def names below are the exact callable names to use in executeToolCode.\n")
|
|
sb.WriteString("# Read this file before executeToolCode to confirm parameters and return shape.\n")
|
|
sb.WriteString(fmt.Sprintf("# For detailed docs: use getToolDocs(server=\"%s\", tool=\"tool_name\")\n", clientName))
|
|
sb.WriteString("# Note: Descriptions may be truncated. Use getToolDocs for full details.\n\n")
|
|
|
|
for _, tool := range tools {
|
|
if tool.Function == nil || tool.Function.Name == "" {
|
|
continue
|
|
}
|
|
|
|
toolName := getCanonicalToolName(clientName, tool.Function.Name)
|
|
|
|
// Format inline parameters in Python style
|
|
params := formatPythonParams(tool.Function.Parameters)
|
|
|
|
// Get description (truncate if too long)
|
|
desc := ""
|
|
if tool.Function.Description != nil && *tool.Function.Description != "" {
|
|
desc = *tool.Function.Description
|
|
// Truncate long descriptions to first sentence or 80 chars
|
|
if idx := strings.Index(desc, ". "); idx > 0 && idx < 80 {
|
|
desc = desc[:idx+1]
|
|
} else if len(desc) > 80 {
|
|
desc = desc[:77] + "..."
|
|
}
|
|
}
|
|
|
|
// Write Python signature: def tool_name(param: type, param: type = None) -> dict: # description
|
|
if desc != "" {
|
|
sb.WriteString(fmt.Sprintf("def %s(%s) -> dict: # %s\n", toolName, params, desc))
|
|
} else {
|
|
sb.WriteString(fmt.Sprintf("def %s(%s) -> dict\n", toolName, params))
|
|
}
|
|
}
|
|
|
|
return sb.String()
|
|
}
|
|
|
|
// formatPythonParams formats tool parameters as Python function parameters.
|
|
func formatPythonParams(params *schemas.ToolFunctionParameters) string {
|
|
if params == nil || params.Properties == nil || params.Properties.Len() == 0 {
|
|
return ""
|
|
}
|
|
|
|
props := params.Properties
|
|
required := make(map[string]bool)
|
|
if params.Required != nil {
|
|
for _, req := range params.Required {
|
|
required[req] = true
|
|
}
|
|
}
|
|
|
|
// Sort properties: required first, then optional, alphabetically within each group
|
|
requiredNames := make([]string, 0)
|
|
optionalNames := make([]string, 0)
|
|
props.Range(func(name string, _ interface{}) bool {
|
|
if required[name] {
|
|
requiredNames = append(requiredNames, name)
|
|
} else {
|
|
optionalNames = append(optionalNames, name)
|
|
}
|
|
return true
|
|
})
|
|
// Simple alphabetical sort for each group
|
|
for i := 0; i < len(requiredNames)-1; i++ {
|
|
for j := i + 1; j < len(requiredNames); j++ {
|
|
if requiredNames[i] > requiredNames[j] {
|
|
requiredNames[i], requiredNames[j] = requiredNames[j], requiredNames[i]
|
|
}
|
|
}
|
|
}
|
|
for i := 0; i < len(optionalNames)-1; i++ {
|
|
for j := i + 1; j < len(optionalNames); j++ {
|
|
if optionalNames[i] > optionalNames[j] {
|
|
optionalNames[i], optionalNames[j] = optionalNames[j], optionalNames[i]
|
|
}
|
|
}
|
|
}
|
|
|
|
parts := make([]string, 0, props.Len())
|
|
|
|
// Add required params first
|
|
for _, propName := range requiredNames {
|
|
prop, _ := props.Get(propName)
|
|
propMap, ok := prop.(map[string]interface{})
|
|
if !ok {
|
|
continue
|
|
}
|
|
pyType := jsonSchemaToPython(propMap)
|
|
parts = append(parts, fmt.Sprintf("%s: %s", propName, pyType))
|
|
}
|
|
|
|
// Add optional params with default None
|
|
for _, propName := range optionalNames {
|
|
prop, _ := props.Get(propName)
|
|
propMap, ok := prop.(map[string]interface{})
|
|
if !ok {
|
|
continue
|
|
}
|
|
pyType := jsonSchemaToPython(propMap)
|
|
parts = append(parts, fmt.Sprintf("%s: %s = None", propName, pyType))
|
|
}
|
|
|
|
return strings.Join(parts, ", ")
|
|
}
|
|
|
|
// jsonSchemaToPython converts a JSON Schema type definition to a Python type string.
|
|
func jsonSchemaToPython(prop map[string]interface{}) string {
|
|
// Check for enum first - takes precedence over type to show allowed values
|
|
if enum, ok := prop["enum"].([]interface{}); ok && len(enum) > 0 {
|
|
enumStrs := make([]string, 0, len(enum))
|
|
for _, e := range enum {
|
|
enumStrs = append(enumStrs, fmt.Sprintf("%q", e))
|
|
}
|
|
return "Literal[" + strings.Join(enumStrs, ", ") + "]"
|
|
}
|
|
|
|
// Check for const (single fixed value)
|
|
if constVal, ok := prop["const"]; ok {
|
|
return fmt.Sprintf("Literal[%q]", constVal)
|
|
}
|
|
|
|
// Fall back to type-based conversion
|
|
if typeVal, ok := prop["type"].(string); ok {
|
|
switch typeVal {
|
|
case "string":
|
|
return "str"
|
|
case "number":
|
|
return "float"
|
|
case "integer":
|
|
return "int"
|
|
case "boolean":
|
|
return "bool"
|
|
case "array":
|
|
itemsType := "Any"
|
|
if items, ok := prop["items"].(map[string]interface{}); ok {
|
|
itemsType = jsonSchemaToPython(items)
|
|
}
|
|
return fmt.Sprintf("list[%s]", itemsType)
|
|
case "object":
|
|
return "dict"
|
|
case "null":
|
|
return "None"
|
|
}
|
|
}
|
|
|
|
return "Any"
|
|
}
|