444 lines
15 KiB
Go
444 lines
15 KiB
Go
//go:build !tinygo && !wasm
|
|
|
|
package starlark
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"regexp"
|
|
"strings"
|
|
"unicode"
|
|
|
|
"github.com/bytedance/sonic"
|
|
"github.com/mark3labs/mcp-go/mcp"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
"go.starlark.net/starlark"
|
|
"go.starlark.net/starlarkstruct"
|
|
)
|
|
|
|
// starlarkToGo converts a Starlark value to a Go value
|
|
func starlarkToGo(v starlark.Value) interface{} {
|
|
switch val := v.(type) {
|
|
case starlark.NoneType:
|
|
return nil
|
|
case starlark.Bool:
|
|
return bool(val)
|
|
case starlark.Int:
|
|
if i, ok := val.Int64(); ok {
|
|
return i
|
|
}
|
|
if i, ok := val.Uint64(); ok {
|
|
return i
|
|
}
|
|
return val.String()
|
|
case starlark.Float:
|
|
return float64(val)
|
|
case starlark.String:
|
|
return string(val)
|
|
case *starlark.List:
|
|
result := make([]interface{}, val.Len())
|
|
for i := 0; i < val.Len(); i++ {
|
|
result[i] = starlarkToGo(val.Index(i))
|
|
}
|
|
return result
|
|
case starlark.Tuple:
|
|
result := make([]interface{}, len(val))
|
|
for i, item := range val {
|
|
result[i] = starlarkToGo(item)
|
|
}
|
|
return result
|
|
case *starlark.Dict:
|
|
result := make(map[string]interface{})
|
|
for _, item := range val.Items() {
|
|
if keyStr, ok := item[0].(starlark.String); ok {
|
|
result[string(keyStr)] = starlarkToGo(item[1])
|
|
} else {
|
|
// Use string representation for non-string keys
|
|
result[item[0].String()] = starlarkToGo(item[1])
|
|
}
|
|
}
|
|
return result
|
|
case *starlarkstruct.Struct:
|
|
result := make(map[string]interface{})
|
|
for _, name := range val.AttrNames() {
|
|
if attrVal, err := val.Attr(name); err == nil {
|
|
result[name] = starlarkToGo(attrVal)
|
|
}
|
|
}
|
|
return result
|
|
default:
|
|
return val.String()
|
|
}
|
|
}
|
|
|
|
// goToStarlark converts a Go value to a Starlark value
|
|
func goToStarlark(v interface{}) starlark.Value {
|
|
if v == nil {
|
|
return starlark.None
|
|
}
|
|
|
|
switch val := v.(type) {
|
|
case bool:
|
|
return starlark.Bool(val)
|
|
case int:
|
|
return starlark.MakeInt(val)
|
|
case int64:
|
|
return starlark.MakeInt64(val)
|
|
case uint64:
|
|
return starlark.MakeUint64(val)
|
|
case float64:
|
|
return starlark.Float(val)
|
|
case string:
|
|
return starlark.String(val)
|
|
case []interface{}:
|
|
items := make([]starlark.Value, len(val))
|
|
for i, item := range val {
|
|
items[i] = goToStarlark(item)
|
|
}
|
|
return starlark.NewList(items)
|
|
case map[string]interface{}:
|
|
dict := starlark.NewDict(len(val))
|
|
for k, v := range val {
|
|
dict.SetKey(starlark.String(k), goToStarlark(v))
|
|
}
|
|
return dict
|
|
default:
|
|
// Try to marshal to JSON and parse as a generic structure
|
|
if jsonBytes, err := schemas.MarshalSorted(val); err == nil {
|
|
var generic interface{}
|
|
if schemas.Unmarshal(jsonBytes, &generic) == nil {
|
|
return goToStarlark(generic)
|
|
}
|
|
}
|
|
return starlark.String(fmt.Sprintf("%v", val))
|
|
}
|
|
}
|
|
|
|
// extractResultFromChatMessage extracts the result from a chat message and parses it as JSON if possible.
|
|
func extractResultFromChatMessage(msg *schemas.ChatMessage) interface{} {
|
|
if msg == nil || msg.Content == nil || msg.Content.ContentStr == nil {
|
|
return nil
|
|
}
|
|
|
|
rawResult := *msg.Content.ContentStr
|
|
|
|
var finalResult interface{}
|
|
if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil {
|
|
return rawResult
|
|
}
|
|
|
|
return finalResult
|
|
}
|
|
|
|
// extractResultFromResponsesMessage extracts the result or error from a ResponsesMessage.
|
|
func extractResultFromResponsesMessage(msg *schemas.ResponsesMessage) (interface{}, error) {
|
|
if msg == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
if msg.ResponsesToolMessage != nil {
|
|
if msg.ResponsesToolMessage.Error != nil && *msg.ResponsesToolMessage.Error != "" {
|
|
return nil, fmt.Errorf("%s", *msg.ResponsesToolMessage.Error)
|
|
}
|
|
|
|
if msg.ResponsesToolMessage.Output != nil {
|
|
if msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil {
|
|
rawResult := *msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr
|
|
|
|
var finalResult interface{}
|
|
if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil {
|
|
return rawResult, nil
|
|
}
|
|
return finalResult, nil
|
|
}
|
|
|
|
if len(msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks) > 0 {
|
|
var textParts []string
|
|
for _, block := range msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks {
|
|
if block.Text != nil {
|
|
textParts = append(textParts, *block.Text)
|
|
}
|
|
}
|
|
if len(textParts) > 0 {
|
|
result := strings.Join(textParts, "\n")
|
|
var finalResult interface{}
|
|
if err := sonic.Unmarshal([]byte(result), &finalResult); err != nil {
|
|
return result, nil
|
|
}
|
|
return finalResult, nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
// formatResultForLog formats a result value for logging purposes.
|
|
func formatResultForLog(result interface{}) string {
|
|
var resultStr string
|
|
if result == nil {
|
|
resultStr = "null"
|
|
} else if resultBytes, err := schemas.MarshalSorted(result); err == nil {
|
|
resultStr = string(resultBytes)
|
|
} else {
|
|
resultStr = fmt.Sprintf("%v", result)
|
|
}
|
|
return resultStr
|
|
}
|
|
|
|
// generatePythonErrorHints generates helpful hints for Python/Starlark errors.
|
|
func generatePythonErrorHints(errorMessage string, serverKeys []string) []string {
|
|
hints := []string{}
|
|
|
|
if strings.Contains(errorMessage, "got try") || strings.Contains(errorMessage, "got except") ||
|
|
strings.Contains(errorMessage, "got finally") || strings.Contains(errorMessage, "got raise") {
|
|
hints = append(hints, "Starlark does NOT support try/except/finally/raise — there is no exception handling.")
|
|
hints = append(hints, "Instead, check return values for errors:")
|
|
hints = append(hints, " result = server.tool(param=\"value\")")
|
|
hints = append(hints, " if result == None or (type(result) == \"dict\" and \"error\" in result):")
|
|
hints = append(hints, " print(\"Error:\", result)")
|
|
} else if strings.Contains(errorMessage, "undefined") || strings.Contains(errorMessage, "not defined") {
|
|
var undefinedVar string
|
|
if match := regexp.MustCompile(`name ['"]([^'"]+)['"] is not defined`).FindStringSubmatch(errorMessage); len(match) > 1 {
|
|
undefinedVar = match[1]
|
|
} else if match := regexp.MustCompile(`undefined:\s*([A-Za-z_][A-Za-z0-9_]*)`).FindStringSubmatch(errorMessage); len(match) > 1 {
|
|
undefinedVar = match[1]
|
|
} else if match := regexp.MustCompile(`([A-Za-z_][A-Za-z0-9_]*)[^A-Za-z0-9_]+(?:undefined|not defined)`).FindStringSubmatch(errorMessage); len(match) > 1 {
|
|
undefinedVar = match[1]
|
|
}
|
|
if undefinedVar != "" {
|
|
hints = append(hints, fmt.Sprintf("Variable '%s' is not defined.", undefinedVar))
|
|
hints = append(hints, "Note: Each executeToolCode call runs in a fresh scope — no variables persist between calls.")
|
|
if len(serverKeys) > 0 {
|
|
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
|
|
hints = append(hints, "Access tools using: server_name.tool_name(param=\"value\")")
|
|
}
|
|
}
|
|
} else if strings.Contains(errorMessage, "not within a function") {
|
|
hints = append(hints, "Starlark requires for/if/while statements to be inside functions at the top level.")
|
|
hints = append(hints, "Wrap your code in a function, then call it:")
|
|
hints = append(hints, " def fetch_all():")
|
|
hints = append(hints, " results = []")
|
|
hints = append(hints, " for id in ids:")
|
|
hints = append(hints, " results.append(server.get(id=id))")
|
|
hints = append(hints, " return results")
|
|
hints = append(hints, " result = fetch_all()")
|
|
} else if strings.Contains(errorMessage, "syntax error") {
|
|
hints = append(hints, "Python syntax error detected.")
|
|
hints = append(hints, "Check for proper indentation (use spaces, not tabs).")
|
|
hints = append(hints, "Ensure colons after if/for/def statements.")
|
|
hints = append(hints, "Check for matching parentheses and brackets.")
|
|
} else if strings.Contains(errorMessage, "has no") && strings.Contains(errorMessage, "attribute") {
|
|
hints = append(hints, "You're trying to access an attribute that doesn't exist.")
|
|
hints = append(hints, "Use dict access syntax: result[\"key\"] instead of result.key")
|
|
hints = append(hints, "Use print(result) to see the actual structure.")
|
|
if len(serverKeys) > 0 {
|
|
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
|
|
}
|
|
} else if strings.Contains(errorMessage, "not callable") {
|
|
hints = append(hints, "You're trying to call something that is not a function.")
|
|
hints = append(hints, "Ensure you're using the correct tool name.")
|
|
if len(serverKeys) > 0 {
|
|
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
|
|
}
|
|
hints = append(hints, "Use readToolFile to see available tools for a server.")
|
|
} else if strings.Contains(errorMessage, "key") && strings.Contains(errorMessage, "not found") {
|
|
hints = append(hints, "Dictionary key not found.")
|
|
hints = append(hints, "Use print() to inspect the dict structure before accessing keys.")
|
|
hints = append(hints, "Use .get(\"key\", default) for safe access.")
|
|
} else {
|
|
hints = append(hints, "Check the error message above for details.")
|
|
if len(serverKeys) > 0 {
|
|
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
|
|
}
|
|
hints = append(hints, "Use: result = server_name.tool_name(param=\"value\")")
|
|
hints = append(hints, "Access dict values with brackets: result[\"key\"]")
|
|
}
|
|
|
|
return hints
|
|
}
|
|
|
|
// extractTextFromMCPResponse extracts text content from an MCP tool response.
|
|
func extractTextFromMCPResponse(toolResponse *mcp.CallToolResult, toolName string) string {
|
|
if toolResponse == nil {
|
|
return fmt.Sprintf("MCP tool '%s' executed successfully", toolName)
|
|
}
|
|
|
|
var result strings.Builder
|
|
for _, contentBlock := range toolResponse.Content {
|
|
// Handle typed content
|
|
switch content := contentBlock.(type) {
|
|
case mcp.TextContent:
|
|
result.WriteString(content.Text)
|
|
case mcp.ImageContent:
|
|
result.WriteString(fmt.Sprintf("[Image Response: %s, MIME: %s]\n", content.Data, content.MIMEType))
|
|
case mcp.AudioContent:
|
|
result.WriteString(fmt.Sprintf("[Audio Response: %s, MIME: %s]\n", content.Data, content.MIMEType))
|
|
case mcp.EmbeddedResource:
|
|
result.WriteString(fmt.Sprintf("[Embedded Resource Response: %s]\n", content.Type))
|
|
default:
|
|
// Fallback: try to extract from map structure
|
|
if jsonBytes, err := schemas.MarshalSorted(contentBlock); err == nil {
|
|
var contentMap map[string]interface{}
|
|
if json.Unmarshal(jsonBytes, &contentMap) == nil {
|
|
if text, ok := contentMap["text"].(string); ok {
|
|
result.WriteString(fmt.Sprintf("[Text Response: %s]\n", text))
|
|
continue
|
|
}
|
|
}
|
|
// Final fallback: serialize as JSON
|
|
result.WriteString(string(jsonBytes))
|
|
}
|
|
}
|
|
}
|
|
|
|
if result.Len() > 0 {
|
|
return strings.TrimSpace(result.String())
|
|
}
|
|
return fmt.Sprintf("MCP tool '%s' executed successfully", toolName)
|
|
}
|
|
|
|
// createToolResponseMessage creates a tool response message with the execution result.
|
|
func createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, responseText string) *schemas.ChatMessage {
|
|
return &schemas.ChatMessage{
|
|
Role: schemas.ChatMessageRoleTool,
|
|
Content: &schemas.ChatMessageContent{
|
|
ContentStr: &responseText,
|
|
},
|
|
ChatToolMessage: &schemas.ChatToolMessage{
|
|
ToolCallID: toolCall.ID,
|
|
},
|
|
}
|
|
}
|
|
|
|
// parseToolName normalizes a raw tool name into a Starlark-compatible identifier.
|
|
func parseToolName(toolName string) string {
|
|
if toolName == "" {
|
|
return ""
|
|
}
|
|
|
|
var result strings.Builder
|
|
runes := []rune(toolName)
|
|
|
|
// Process first character - must be letter, underscore, or dollar sign
|
|
if len(runes) > 0 {
|
|
first := runes[0]
|
|
if unicode.IsLetter(first) || first == '_' || first == '$' {
|
|
result.WriteRune(unicode.ToLower(first))
|
|
} else {
|
|
// If first char is invalid, prefix with underscore
|
|
result.WriteRune('_')
|
|
if unicode.IsDigit(first) {
|
|
result.WriteRune(first)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Process remaining characters
|
|
for i := 1; i < len(runes); i++ {
|
|
r := runes[i]
|
|
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' {
|
|
result.WriteRune(unicode.ToLower(r))
|
|
} else if unicode.IsSpace(r) || r == '-' {
|
|
// Replace spaces and hyphens with single underscore
|
|
// Avoid consecutive underscores
|
|
if result.Len() > 0 && result.String()[result.Len()-1] != '_' {
|
|
result.WriteRune('_')
|
|
}
|
|
}
|
|
// Skip other invalid characters
|
|
}
|
|
|
|
parsed := result.String()
|
|
|
|
// Remove trailing underscores
|
|
parsed = strings.TrimRight(parsed, "_")
|
|
|
|
// Ensure we have at least one character
|
|
if parsed == "" {
|
|
return "tool"
|
|
}
|
|
|
|
return parsed
|
|
}
|
|
|
|
// getCanonicalToolName returns the exact callable tool identifier exposed in Starlark.
|
|
func getCanonicalToolName(clientName, originalToolName string) string {
|
|
return parseToolName(stripClientPrefix(originalToolName, clientName))
|
|
}
|
|
|
|
// getCompatibilityToolAlias returns the case-preserving alias derived from the raw tool name.
|
|
// This is used as a compatibility alias when the raw name is still a valid Starlark identifier.
|
|
func getCompatibilityToolAlias(clientName, originalToolName string) string {
|
|
return strings.ReplaceAll(stripClientPrefix(originalToolName, clientName), "-", "_")
|
|
}
|
|
|
|
// matchesToolReference reports whether the requested tool name matches any supported identifier form.
|
|
// We accept the canonical callable name plus legacy display forms for backward compatibility.
|
|
func matchesToolReference(requestedToolName, clientName, originalToolName string) bool {
|
|
requested := strings.ToLower(requestedToolName)
|
|
if requested == "" {
|
|
return false
|
|
}
|
|
|
|
candidates := []string{
|
|
getCanonicalToolName(clientName, originalToolName),
|
|
getCompatibilityToolAlias(clientName, originalToolName),
|
|
stripClientPrefix(originalToolName, clientName),
|
|
}
|
|
|
|
for _, candidate := range candidates {
|
|
if candidate != "" && requested == strings.ToLower(candidate) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// isValidStarlarkIdentifier reports whether name can be used directly in Starlark code.
|
|
func isValidStarlarkIdentifier(name string) bool {
|
|
if name == "" {
|
|
return false
|
|
}
|
|
|
|
runes := []rune(name)
|
|
first := runes[0]
|
|
if !unicode.IsLetter(first) && first != '_' && first != '$' {
|
|
return false
|
|
}
|
|
|
|
for _, r := range runes[1:] {
|
|
if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' && r != '$' {
|
|
return false
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// validateNormalizedToolName validates a normalized tool name to prevent path traversal.
|
|
func validateNormalizedToolName(normalizedName string) error {
|
|
if normalizedName == "" {
|
|
return fmt.Errorf("tool name cannot be empty after normalization")
|
|
}
|
|
if strings.Contains(normalizedName, "/") {
|
|
return fmt.Errorf("tool name cannot contain '/' (path separator) after normalization: %s", normalizedName)
|
|
}
|
|
if strings.Contains(normalizedName, "..") {
|
|
return fmt.Errorf("tool name cannot contain '..' (path traversal) after normalization: %s", normalizedName)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// stripClientPrefix removes the client name prefix from a tool name.
|
|
func stripClientPrefix(prefixedToolName, clientName string) string {
|
|
prefix := clientName + "-"
|
|
if strings.HasPrefix(prefixedToolName, prefix) {
|
|
return strings.TrimPrefix(prefixedToolName, prefix)
|
|
}
|
|
// If prefix doesn't match, return as-is (shouldn't happen, but be safe)
|
|
return prefixedToolName
|
|
}
|