first commit
This commit is contained in:
443
core/mcp/codemode/starlark/utils.go
Normal file
443
core/mcp/codemode/starlark/utils.go
Normal file
@@ -0,0 +1,443 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package starlark
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"go.starlark.net/starlark"
|
||||
"go.starlark.net/starlarkstruct"
|
||||
)
|
||||
|
||||
// starlarkToGo converts a Starlark value to a Go value
|
||||
func starlarkToGo(v starlark.Value) interface{} {
|
||||
switch val := v.(type) {
|
||||
case starlark.NoneType:
|
||||
return nil
|
||||
case starlark.Bool:
|
||||
return bool(val)
|
||||
case starlark.Int:
|
||||
if i, ok := val.Int64(); ok {
|
||||
return i
|
||||
}
|
||||
if i, ok := val.Uint64(); ok {
|
||||
return i
|
||||
}
|
||||
return val.String()
|
||||
case starlark.Float:
|
||||
return float64(val)
|
||||
case starlark.String:
|
||||
return string(val)
|
||||
case *starlark.List:
|
||||
result := make([]interface{}, val.Len())
|
||||
for i := 0; i < val.Len(); i++ {
|
||||
result[i] = starlarkToGo(val.Index(i))
|
||||
}
|
||||
return result
|
||||
case starlark.Tuple:
|
||||
result := make([]interface{}, len(val))
|
||||
for i, item := range val {
|
||||
result[i] = starlarkToGo(item)
|
||||
}
|
||||
return result
|
||||
case *starlark.Dict:
|
||||
result := make(map[string]interface{})
|
||||
for _, item := range val.Items() {
|
||||
if keyStr, ok := item[0].(starlark.String); ok {
|
||||
result[string(keyStr)] = starlarkToGo(item[1])
|
||||
} else {
|
||||
// Use string representation for non-string keys
|
||||
result[item[0].String()] = starlarkToGo(item[1])
|
||||
}
|
||||
}
|
||||
return result
|
||||
case *starlarkstruct.Struct:
|
||||
result := make(map[string]interface{})
|
||||
for _, name := range val.AttrNames() {
|
||||
if attrVal, err := val.Attr(name); err == nil {
|
||||
result[name] = starlarkToGo(attrVal)
|
||||
}
|
||||
}
|
||||
return result
|
||||
default:
|
||||
return val.String()
|
||||
}
|
||||
}
|
||||
|
||||
// goToStarlark converts a Go value to a Starlark value
|
||||
func goToStarlark(v interface{}) starlark.Value {
|
||||
if v == nil {
|
||||
return starlark.None
|
||||
}
|
||||
|
||||
switch val := v.(type) {
|
||||
case bool:
|
||||
return starlark.Bool(val)
|
||||
case int:
|
||||
return starlark.MakeInt(val)
|
||||
case int64:
|
||||
return starlark.MakeInt64(val)
|
||||
case uint64:
|
||||
return starlark.MakeUint64(val)
|
||||
case float64:
|
||||
return starlark.Float(val)
|
||||
case string:
|
||||
return starlark.String(val)
|
||||
case []interface{}:
|
||||
items := make([]starlark.Value, len(val))
|
||||
for i, item := range val {
|
||||
items[i] = goToStarlark(item)
|
||||
}
|
||||
return starlark.NewList(items)
|
||||
case map[string]interface{}:
|
||||
dict := starlark.NewDict(len(val))
|
||||
for k, v := range val {
|
||||
dict.SetKey(starlark.String(k), goToStarlark(v))
|
||||
}
|
||||
return dict
|
||||
default:
|
||||
// Try to marshal to JSON and parse as a generic structure
|
||||
if jsonBytes, err := schemas.MarshalSorted(val); err == nil {
|
||||
var generic interface{}
|
||||
if schemas.Unmarshal(jsonBytes, &generic) == nil {
|
||||
return goToStarlark(generic)
|
||||
}
|
||||
}
|
||||
return starlark.String(fmt.Sprintf("%v", val))
|
||||
}
|
||||
}
|
||||
|
||||
// extractResultFromChatMessage extracts the result from a chat message and parses it as JSON if possible.
|
||||
func extractResultFromChatMessage(msg *schemas.ChatMessage) interface{} {
|
||||
if msg == nil || msg.Content == nil || msg.Content.ContentStr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
rawResult := *msg.Content.ContentStr
|
||||
|
||||
var finalResult interface{}
|
||||
if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil {
|
||||
return rawResult
|
||||
}
|
||||
|
||||
return finalResult
|
||||
}
|
||||
|
||||
// extractResultFromResponsesMessage extracts the result or error from a ResponsesMessage.
|
||||
func extractResultFromResponsesMessage(msg *schemas.ResponsesMessage) (interface{}, error) {
|
||||
if msg == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if msg.ResponsesToolMessage != nil {
|
||||
if msg.ResponsesToolMessage.Error != nil && *msg.ResponsesToolMessage.Error != "" {
|
||||
return nil, fmt.Errorf("%s", *msg.ResponsesToolMessage.Error)
|
||||
}
|
||||
|
||||
if msg.ResponsesToolMessage.Output != nil {
|
||||
if msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil {
|
||||
rawResult := *msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr
|
||||
|
||||
var finalResult interface{}
|
||||
if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil {
|
||||
return rawResult, nil
|
||||
}
|
||||
return finalResult, nil
|
||||
}
|
||||
|
||||
if len(msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks) > 0 {
|
||||
var textParts []string
|
||||
for _, block := range msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks {
|
||||
if block.Text != nil {
|
||||
textParts = append(textParts, *block.Text)
|
||||
}
|
||||
}
|
||||
if len(textParts) > 0 {
|
||||
result := strings.Join(textParts, "\n")
|
||||
var finalResult interface{}
|
||||
if err := sonic.Unmarshal([]byte(result), &finalResult); err != nil {
|
||||
return result, nil
|
||||
}
|
||||
return finalResult, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// formatResultForLog formats a result value for logging purposes.
|
||||
func formatResultForLog(result interface{}) string {
|
||||
var resultStr string
|
||||
if result == nil {
|
||||
resultStr = "null"
|
||||
} else if resultBytes, err := schemas.MarshalSorted(result); err == nil {
|
||||
resultStr = string(resultBytes)
|
||||
} else {
|
||||
resultStr = fmt.Sprintf("%v", result)
|
||||
}
|
||||
return resultStr
|
||||
}
|
||||
|
||||
// generatePythonErrorHints generates helpful hints for Python/Starlark errors.
|
||||
func generatePythonErrorHints(errorMessage string, serverKeys []string) []string {
|
||||
hints := []string{}
|
||||
|
||||
if strings.Contains(errorMessage, "got try") || strings.Contains(errorMessage, "got except") ||
|
||||
strings.Contains(errorMessage, "got finally") || strings.Contains(errorMessage, "got raise") {
|
||||
hints = append(hints, "Starlark does NOT support try/except/finally/raise — there is no exception handling.")
|
||||
hints = append(hints, "Instead, check return values for errors:")
|
||||
hints = append(hints, " result = server.tool(param=\"value\")")
|
||||
hints = append(hints, " if result == None or (type(result) == \"dict\" and \"error\" in result):")
|
||||
hints = append(hints, " print(\"Error:\", result)")
|
||||
} else if strings.Contains(errorMessage, "undefined") || strings.Contains(errorMessage, "not defined") {
|
||||
var undefinedVar string
|
||||
if match := regexp.MustCompile(`name ['"]([^'"]+)['"] is not defined`).FindStringSubmatch(errorMessage); len(match) > 1 {
|
||||
undefinedVar = match[1]
|
||||
} else if match := regexp.MustCompile(`undefined:\s*([A-Za-z_][A-Za-z0-9_]*)`).FindStringSubmatch(errorMessage); len(match) > 1 {
|
||||
undefinedVar = match[1]
|
||||
} else if match := regexp.MustCompile(`([A-Za-z_][A-Za-z0-9_]*)[^A-Za-z0-9_]+(?:undefined|not defined)`).FindStringSubmatch(errorMessage); len(match) > 1 {
|
||||
undefinedVar = match[1]
|
||||
}
|
||||
if undefinedVar != "" {
|
||||
hints = append(hints, fmt.Sprintf("Variable '%s' is not defined.", undefinedVar))
|
||||
hints = append(hints, "Note: Each executeToolCode call runs in a fresh scope — no variables persist between calls.")
|
||||
if len(serverKeys) > 0 {
|
||||
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
|
||||
hints = append(hints, "Access tools using: server_name.tool_name(param=\"value\")")
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(errorMessage, "not within a function") {
|
||||
hints = append(hints, "Starlark requires for/if/while statements to be inside functions at the top level.")
|
||||
hints = append(hints, "Wrap your code in a function, then call it:")
|
||||
hints = append(hints, " def fetch_all():")
|
||||
hints = append(hints, " results = []")
|
||||
hints = append(hints, " for id in ids:")
|
||||
hints = append(hints, " results.append(server.get(id=id))")
|
||||
hints = append(hints, " return results")
|
||||
hints = append(hints, " result = fetch_all()")
|
||||
} else if strings.Contains(errorMessage, "syntax error") {
|
||||
hints = append(hints, "Python syntax error detected.")
|
||||
hints = append(hints, "Check for proper indentation (use spaces, not tabs).")
|
||||
hints = append(hints, "Ensure colons after if/for/def statements.")
|
||||
hints = append(hints, "Check for matching parentheses and brackets.")
|
||||
} else if strings.Contains(errorMessage, "has no") && strings.Contains(errorMessage, "attribute") {
|
||||
hints = append(hints, "You're trying to access an attribute that doesn't exist.")
|
||||
hints = append(hints, "Use dict access syntax: result[\"key\"] instead of result.key")
|
||||
hints = append(hints, "Use print(result) to see the actual structure.")
|
||||
if len(serverKeys) > 0 {
|
||||
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
|
||||
}
|
||||
} else if strings.Contains(errorMessage, "not callable") {
|
||||
hints = append(hints, "You're trying to call something that is not a function.")
|
||||
hints = append(hints, "Ensure you're using the correct tool name.")
|
||||
if len(serverKeys) > 0 {
|
||||
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
|
||||
}
|
||||
hints = append(hints, "Use readToolFile to see available tools for a server.")
|
||||
} else if strings.Contains(errorMessage, "key") && strings.Contains(errorMessage, "not found") {
|
||||
hints = append(hints, "Dictionary key not found.")
|
||||
hints = append(hints, "Use print() to inspect the dict structure before accessing keys.")
|
||||
hints = append(hints, "Use .get(\"key\", default) for safe access.")
|
||||
} else {
|
||||
hints = append(hints, "Check the error message above for details.")
|
||||
if len(serverKeys) > 0 {
|
||||
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
|
||||
}
|
||||
hints = append(hints, "Use: result = server_name.tool_name(param=\"value\")")
|
||||
hints = append(hints, "Access dict values with brackets: result[\"key\"]")
|
||||
}
|
||||
|
||||
return hints
|
||||
}
|
||||
|
||||
// extractTextFromMCPResponse extracts text content from an MCP tool response.
|
||||
func extractTextFromMCPResponse(toolResponse *mcp.CallToolResult, toolName string) string {
|
||||
if toolResponse == nil {
|
||||
return fmt.Sprintf("MCP tool '%s' executed successfully", toolName)
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
for _, contentBlock := range toolResponse.Content {
|
||||
// Handle typed content
|
||||
switch content := contentBlock.(type) {
|
||||
case mcp.TextContent:
|
||||
result.WriteString(content.Text)
|
||||
case mcp.ImageContent:
|
||||
result.WriteString(fmt.Sprintf("[Image Response: %s, MIME: %s]\n", content.Data, content.MIMEType))
|
||||
case mcp.AudioContent:
|
||||
result.WriteString(fmt.Sprintf("[Audio Response: %s, MIME: %s]\n", content.Data, content.MIMEType))
|
||||
case mcp.EmbeddedResource:
|
||||
result.WriteString(fmt.Sprintf("[Embedded Resource Response: %s]\n", content.Type))
|
||||
default:
|
||||
// Fallback: try to extract from map structure
|
||||
if jsonBytes, err := schemas.MarshalSorted(contentBlock); err == nil {
|
||||
var contentMap map[string]interface{}
|
||||
if json.Unmarshal(jsonBytes, &contentMap) == nil {
|
||||
if text, ok := contentMap["text"].(string); ok {
|
||||
result.WriteString(fmt.Sprintf("[Text Response: %s]\n", text))
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Final fallback: serialize as JSON
|
||||
result.WriteString(string(jsonBytes))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if result.Len() > 0 {
|
||||
return strings.TrimSpace(result.String())
|
||||
}
|
||||
return fmt.Sprintf("MCP tool '%s' executed successfully", toolName)
|
||||
}
|
||||
|
||||
// createToolResponseMessage creates a tool response message with the execution result.
|
||||
func createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, responseText string) *schemas.ChatMessage {
|
||||
return &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleTool,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: &responseText,
|
||||
},
|
||||
ChatToolMessage: &schemas.ChatToolMessage{
|
||||
ToolCallID: toolCall.ID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// parseToolName normalizes a raw tool name into a Starlark-compatible identifier.
|
||||
func parseToolName(toolName string) string {
|
||||
if toolName == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
runes := []rune(toolName)
|
||||
|
||||
// Process first character - must be letter, underscore, or dollar sign
|
||||
if len(runes) > 0 {
|
||||
first := runes[0]
|
||||
if unicode.IsLetter(first) || first == '_' || first == '$' {
|
||||
result.WriteRune(unicode.ToLower(first))
|
||||
} else {
|
||||
// If first char is invalid, prefix with underscore
|
||||
result.WriteRune('_')
|
||||
if unicode.IsDigit(first) {
|
||||
result.WriteRune(first)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process remaining characters
|
||||
for i := 1; i < len(runes); i++ {
|
||||
r := runes[i]
|
||||
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' {
|
||||
result.WriteRune(unicode.ToLower(r))
|
||||
} else if unicode.IsSpace(r) || r == '-' {
|
||||
// Replace spaces and hyphens with single underscore
|
||||
// Avoid consecutive underscores
|
||||
if result.Len() > 0 && result.String()[result.Len()-1] != '_' {
|
||||
result.WriteRune('_')
|
||||
}
|
||||
}
|
||||
// Skip other invalid characters
|
||||
}
|
||||
|
||||
parsed := result.String()
|
||||
|
||||
// Remove trailing underscores
|
||||
parsed = strings.TrimRight(parsed, "_")
|
||||
|
||||
// Ensure we have at least one character
|
||||
if parsed == "" {
|
||||
return "tool"
|
||||
}
|
||||
|
||||
return parsed
|
||||
}
|
||||
|
||||
// getCanonicalToolName returns the exact callable tool identifier exposed in Starlark.
|
||||
func getCanonicalToolName(clientName, originalToolName string) string {
|
||||
return parseToolName(stripClientPrefix(originalToolName, clientName))
|
||||
}
|
||||
|
||||
// getCompatibilityToolAlias returns the case-preserving alias derived from the raw tool name.
|
||||
// This is used as a compatibility alias when the raw name is still a valid Starlark identifier.
|
||||
func getCompatibilityToolAlias(clientName, originalToolName string) string {
|
||||
return strings.ReplaceAll(stripClientPrefix(originalToolName, clientName), "-", "_")
|
||||
}
|
||||
|
||||
// matchesToolReference reports whether the requested tool name matches any supported identifier form.
|
||||
// We accept the canonical callable name plus legacy display forms for backward compatibility.
|
||||
func matchesToolReference(requestedToolName, clientName, originalToolName string) bool {
|
||||
requested := strings.ToLower(requestedToolName)
|
||||
if requested == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
candidates := []string{
|
||||
getCanonicalToolName(clientName, originalToolName),
|
||||
getCompatibilityToolAlias(clientName, originalToolName),
|
||||
stripClientPrefix(originalToolName, clientName),
|
||||
}
|
||||
|
||||
for _, candidate := range candidates {
|
||||
if candidate != "" && requested == strings.ToLower(candidate) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isValidStarlarkIdentifier reports whether name can be used directly in Starlark code.
|
||||
func isValidStarlarkIdentifier(name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
runes := []rune(name)
|
||||
first := runes[0]
|
||||
if !unicode.IsLetter(first) && first != '_' && first != '$' {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, r := range runes[1:] {
|
||||
if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' && r != '$' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// validateNormalizedToolName validates a normalized tool name to prevent path traversal.
|
||||
func validateNormalizedToolName(normalizedName string) error {
|
||||
if normalizedName == "" {
|
||||
return fmt.Errorf("tool name cannot be empty after normalization")
|
||||
}
|
||||
if strings.Contains(normalizedName, "/") {
|
||||
return fmt.Errorf("tool name cannot contain '/' (path separator) after normalization: %s", normalizedName)
|
||||
}
|
||||
if strings.Contains(normalizedName, "..") {
|
||||
return fmt.Errorf("tool name cannot contain '..' (path traversal) after normalization: %s", normalizedName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// stripClientPrefix removes the client name prefix from a tool name.
|
||||
func stripClientPrefix(prefixedToolName, clientName string) string {
|
||||
prefix := clientName + "-"
|
||||
if strings.HasPrefix(prefixedToolName, prefix) {
|
||||
return strings.TrimPrefix(prefixedToolName, prefix)
|
||||
}
|
||||
// If prefix doesn't match, return as-is (shouldn't happen, but be safe)
|
||||
return prefixedToolName
|
||||
}
|
||||
Reference in New Issue
Block a user