Feat/claude cua support (#253)
* feat: add claude support * feat: add script for end-to-end evaluation with logging and task distribution * feat&fix: add tool result handling and update model default in evaluation script * chore: remove run_test_env.py script * feat&fix: implement action parsing for tool calls and update default action space * fix: update text formatting in action parsing and replace logger import * feat&fix: implement action parsing for tool calls and add screen size handling * feat: add setup instructions for Anthropic API integration * feat: add notice about image size limitations for Anthropic API * Delete test_env/logger.py * Delete test_env/utils.py
This commit is contained in:
14
mm_agents/anthropic/tools/__init__.py
Normal file
14
mm_agents/anthropic/tools/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from .base import CLIResult, ToolResult
|
||||
from .bash import BashTool
|
||||
from .collection import ToolCollection
|
||||
from .computer import ComputerTool
|
||||
from .edit import EditTool
|
||||
|
||||
__ALL__ = [
|
||||
BashTool,
|
||||
CLIResult,
|
||||
ComputerTool,
|
||||
EditTool,
|
||||
ToolCollection,
|
||||
ToolResult,
|
||||
]
|
||||
69
mm_agents/anthropic/tools/base.py
Normal file
69
mm_agents/anthropic/tools/base.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass, fields, replace
|
||||
from typing import Any, Optional
|
||||
|
||||
from anthropic.types.beta import BetaToolUnionParam
|
||||
|
||||
|
||||
class BaseAnthropicTool(metaclass=ABCMeta):
|
||||
"""Abstract base class for Anthropic-defined tools."""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, **kwargs) -> Any:
|
||||
"""Executes the tool with the given arguments."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def to_params(
|
||||
self,
|
||||
) -> BetaToolUnionParam:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass(frozen=True) #kw_only=True,
|
||||
class ToolResult:
|
||||
"""Represents the result of a tool execution."""
|
||||
|
||||
output: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
base64_image: Optional[str] = None
|
||||
system: Optional[str] = None
|
||||
|
||||
def __bool__(self):
|
||||
return any(getattr(self, field.name) for field in fields(self))
|
||||
|
||||
def __add__(self, other: "ToolResult"):
|
||||
def combine_fields(
|
||||
field: Optional[str], other_field: Optional[str], concatenate: bool = True
|
||||
):
|
||||
if field and other_field:
|
||||
if concatenate:
|
||||
return field + other_field
|
||||
raise ValueError("Cannot combine tool results")
|
||||
return field or other_field
|
||||
|
||||
return ToolResult(
|
||||
output=combine_fields(self.output, other.output),
|
||||
error=combine_fields(self.error, other.error),
|
||||
base64_image=combine_fields(self.base64_image, other.base64_image, False),
|
||||
system=combine_fields(self.system, other.system),
|
||||
)
|
||||
|
||||
def replace(self, **kwargs):
|
||||
"""Returns a new ToolResult with the given fields replaced."""
|
||||
return replace(self, **kwargs)
|
||||
|
||||
|
||||
class CLIResult(ToolResult):
|
||||
"""A ToolResult that can be rendered as a CLI output."""
|
||||
|
||||
|
||||
class ToolFailure(ToolResult):
|
||||
"""A ToolResult that represents a failure."""
|
||||
|
||||
|
||||
class ToolError(Exception):
|
||||
"""Raised when a tool encounters an error."""
|
||||
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
144
mm_agents/anthropic/tools/bash.py
Normal file
144
mm_agents/anthropic/tools/bash.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import ClassVar, Literal, Optional
|
||||
|
||||
from anthropic.types.beta import BetaToolBash20241022Param
|
||||
|
||||
from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
|
||||
|
||||
|
||||
class _BashSession:
|
||||
"""A session of a bash shell."""
|
||||
|
||||
_started: bool
|
||||
_process: asyncio.subprocess.Process
|
||||
|
||||
command: str = "/bin/bash"
|
||||
_output_delay: float = 0.2 # seconds
|
||||
_timeout: float = 120.0 # seconds
|
||||
_sentinel: str = "<<exit>>"
|
||||
|
||||
def __init__(self):
|
||||
self._started = False
|
||||
self._timed_out = False
|
||||
|
||||
async def start(self):
|
||||
if self._started:
|
||||
return
|
||||
|
||||
self._process = await asyncio.create_subprocess_shell(
|
||||
self.command,
|
||||
preexec_fn=os.setsid,
|
||||
shell=True,
|
||||
bufsize=0,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
self._started = True
|
||||
|
||||
def stop(self):
|
||||
"""Terminate the bash shell."""
|
||||
if not self._started:
|
||||
raise ToolError("Session has not started.")
|
||||
if self._process.returncode is not None:
|
||||
return
|
||||
self._process.terminate()
|
||||
|
||||
async def run(self, command: str):
|
||||
"""Execute a command in the bash shell."""
|
||||
if not self._started:
|
||||
raise ToolError("Session has not started.")
|
||||
if self._process.returncode is not None:
|
||||
return ToolResult(
|
||||
system="tool must be restarted",
|
||||
error=f"bash has exited with returncode {self._process.returncode}",
|
||||
)
|
||||
if self._timed_out:
|
||||
raise ToolError(
|
||||
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
|
||||
)
|
||||
|
||||
# we know these are not None because we created the process with PIPEs
|
||||
assert self._process.stdin
|
||||
assert self._process.stdout
|
||||
assert self._process.stderr
|
||||
|
||||
# send command to the process
|
||||
self._process.stdin.write(
|
||||
command.encode() + f"; echo '{self._sentinel}'\n".encode()
|
||||
)
|
||||
await self._process.stdin.drain()
|
||||
|
||||
# read output from the process, until the sentinel is found
|
||||
try:
|
||||
async with asyncio.timeout(self._timeout):
|
||||
while True:
|
||||
await asyncio.sleep(self._output_delay)
|
||||
# if we read directly from stdout/stderr, it will wait forever for
|
||||
# EOF. use the StreamReader buffer directly instead.
|
||||
output = self._process.stdout._buffer.decode() # pyright: ignore[reportAttributeAccessIssue]
|
||||
if self._sentinel in output:
|
||||
# strip the sentinel and break
|
||||
output = output[: output.index(self._sentinel)]
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
self._timed_out = True
|
||||
raise ToolError(
|
||||
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
|
||||
) from None
|
||||
|
||||
if output.endswith("\n"):
|
||||
output = output[:-1]
|
||||
|
||||
error = self._process.stderr._buffer.decode() # pyright: ignore[reportAttributeAccessIssue]
|
||||
if error.endswith("\n"):
|
||||
error = error[:-1]
|
||||
|
||||
# clear the buffers so that the next output can be read correctly
|
||||
self._process.stdout._buffer.clear() # pyright: ignore[reportAttributeAccessIssue]
|
||||
self._process.stderr._buffer.clear() # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
return CLIResult(output=output, error=error)
|
||||
|
||||
|
||||
class BashTool(BaseAnthropicTool):
|
||||
"""
|
||||
A tool that allows the agent to run bash commands.
|
||||
The tool parameters are defined by Anthropic and are not editable.
|
||||
"""
|
||||
|
||||
_session: Optional[_BashSession]
|
||||
name: ClassVar[Literal["bash"]] = "bash"
|
||||
api_type: ClassVar[Literal["bash_20241022"]] = "bash_20241022"
|
||||
|
||||
def __init__(self):
|
||||
self._session = None
|
||||
super().__init__()
|
||||
|
||||
async def __call__(
|
||||
self, command: Optional[str] = None, restart: bool = False, **kwargs
|
||||
):
|
||||
if restart:
|
||||
if self._session:
|
||||
self._session.stop()
|
||||
self._session = _BashSession()
|
||||
await self._session.start()
|
||||
|
||||
return ToolResult(system="tool has been restarted.")
|
||||
|
||||
if self._session is None:
|
||||
self._session = _BashSession()
|
||||
await self._session.start()
|
||||
|
||||
if command is not None:
|
||||
return await self._session.run(command)
|
||||
|
||||
raise ToolError("no command provided.")
|
||||
|
||||
def to_params(self) -> BetaToolBash20241022Param:
|
||||
return {
|
||||
"type": self.api_type,
|
||||
"name": self.name,
|
||||
}
|
||||
34
mm_agents/anthropic/tools/collection.py
Normal file
34
mm_agents/anthropic/tools/collection.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Collection classes for managing multiple tools."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from anthropic.types.beta import BetaToolUnionParam
|
||||
|
||||
from .base import (
|
||||
BaseAnthropicTool,
|
||||
ToolError,
|
||||
ToolFailure,
|
||||
ToolResult,
|
||||
)
|
||||
|
||||
|
||||
class ToolCollection:
|
||||
"""A collection of anthropic-defined tools."""
|
||||
|
||||
def __init__(self, *tools: BaseAnthropicTool):
|
||||
self.tools = tools
|
||||
self.tool_map = {tool.to_params()["name"]: tool for tool in tools}
|
||||
|
||||
def to_params(
|
||||
self,
|
||||
) -> list[BetaToolUnionParam]:
|
||||
return [tool.to_params() for tool in self.tools]
|
||||
|
||||
async def run(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult:
|
||||
tool = self.tool_map.get(name)
|
||||
if not tool:
|
||||
return ToolFailure(error=f"Tool {name} is invalid")
|
||||
try:
|
||||
return await tool(**tool_input)
|
||||
except ToolError as e:
|
||||
return ToolFailure(error=e.message)
|
||||
260
mm_agents/anthropic/tools/computer.py
Normal file
260
mm_agents/anthropic/tools/computer.py
Normal file
@@ -0,0 +1,260 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
import shlex
|
||||
import shutil
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Literal, TypedDict, Optional, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
from anthropic.types.beta import BetaToolComputerUse20241022Param
|
||||
|
||||
from .base import BaseAnthropicTool, ToolError, ToolResult
|
||||
from .run import run
|
||||
|
||||
OUTPUT_DIR = "/tmp/outputs"
|
||||
|
||||
TYPING_DELAY_MS = 12
|
||||
TYPING_GROUP_SIZE = 50
|
||||
|
||||
Action = Literal[
|
||||
"key",
|
||||
"type",
|
||||
"mouse_move",
|
||||
"left_click",
|
||||
"left_click_drag",
|
||||
"right_click",
|
||||
"middle_click",
|
||||
"double_click",
|
||||
"screenshot",
|
||||
"cursor_position",
|
||||
]
|
||||
|
||||
|
||||
class Resolution(TypedDict):
|
||||
width: int
|
||||
height: int
|
||||
|
||||
|
||||
# sizes above XGA/WXGA are not recommended (see README.md)
|
||||
# scale down to one of these targets if ComputerTool._scaling_enabled is set
|
||||
MAX_SCALING_TARGETS: dict[str, Resolution] = {
|
||||
"XGA": Resolution(width=1024, height=768), # 4:3
|
||||
"WXGA": Resolution(width=1280, height=800), # 16:10
|
||||
"FWXGA": Resolution(width=1366, height=768), # ~16:9
|
||||
}
|
||||
|
||||
|
||||
class ScalingSource(Enum):
|
||||
COMPUTER = "computer"
|
||||
API = "api"
|
||||
|
||||
|
||||
class ComputerToolOptions(TypedDict):
|
||||
display_height_px: int
|
||||
display_width_px: int
|
||||
display_number: Optional[int]
|
||||
|
||||
|
||||
def chunks(s: str, chunk_size: int) -> list[str]:
|
||||
return [s[i : i + chunk_size] for i in range(0, len(s), chunk_size)]
|
||||
|
||||
|
||||
class ComputerTool(BaseAnthropicTool):
|
||||
"""
|
||||
A tool that allows the agent to interact with the screen, keyboard, and mouse of the current computer.
|
||||
The tool parameters are defined by Anthropic and are not editable.
|
||||
"""
|
||||
|
||||
name: Literal["computer"] = "computer"
|
||||
api_type: Literal["computer_20241022"] = "computer_20241022"
|
||||
width: int
|
||||
height: int
|
||||
display_num: Optional[int]
|
||||
|
||||
_screenshot_delay = 2.0
|
||||
_scaling_enabled = True
|
||||
|
||||
@property
|
||||
def options(self) -> ComputerToolOptions:
|
||||
width, height = self.scale_coordinates(
|
||||
ScalingSource.COMPUTER, self.width, self.height
|
||||
)
|
||||
return {
|
||||
"display_width_px": width,
|
||||
"display_height_px": height,
|
||||
"display_number": self.display_num,
|
||||
}
|
||||
|
||||
def to_params(self) -> BetaToolComputerUse20241022Param:
|
||||
return {"name": self.name, "type": self.api_type, **self.options}
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.width = int(os.getenv("WIDTH") or 0)
|
||||
self.height = int(os.getenv("HEIGHT") or 0)
|
||||
assert self.width and self.height, "WIDTH, HEIGHT must be set"
|
||||
if (display_num := os.getenv("DISPLAY_NUM")) is not None:
|
||||
self.display_num = int(display_num)
|
||||
self._display_prefix = f"DISPLAY=:{self.display_num} "
|
||||
else:
|
||||
self.display_num = None
|
||||
self._display_prefix = ""
|
||||
|
||||
self.xdotool = f"{self._display_prefix}xdotool"
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
*,
|
||||
action: Action,
|
||||
text: Optional[str] = None,
|
||||
coordinate: Optional[Tuple[int, int]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if action in ("mouse_move", "left_click_drag"):
|
||||
if coordinate is None:
|
||||
raise ToolError(f"coordinate is required for {action}")
|
||||
if text is not None:
|
||||
raise ToolError(f"text is not accepted for {action}")
|
||||
if not isinstance(coordinate, list) or len(coordinate) != 2:
|
||||
raise ToolError(f"{coordinate} must be a tuple of length 2")
|
||||
if not all(isinstance(i, int) and i >= 0 for i in coordinate):
|
||||
raise ToolError(f"{coordinate} must be a tuple of non-negative ints")
|
||||
|
||||
x, y = self.scale_coordinates(
|
||||
ScalingSource.API, coordinate[0], coordinate[1]
|
||||
)
|
||||
|
||||
if action == "mouse_move":
|
||||
return await self.shell(f"{self.xdotool} mousemove --sync {x} {y}")
|
||||
elif action == "left_click_drag":
|
||||
return await self.shell(
|
||||
f"{self.xdotool} mousedown 1 mousemove --sync {x} {y} mouseup 1"
|
||||
)
|
||||
|
||||
if action in ("key", "type"):
|
||||
if text is None:
|
||||
raise ToolError(f"text is required for {action}")
|
||||
if coordinate is not None:
|
||||
raise ToolError(f"coordinate is not accepted for {action}")
|
||||
if not isinstance(text, str):
|
||||
raise ToolError(output=f"{text} must be a string")
|
||||
|
||||
if action == "key":
|
||||
return await self.shell(f"{self.xdotool} key -- {text}")
|
||||
elif action == "type":
|
||||
results: list[ToolResult] = []
|
||||
for chunk in chunks(text, TYPING_GROUP_SIZE):
|
||||
cmd = f"{self.xdotool} type --delay {TYPING_DELAY_MS} -- {shlex.quote(chunk)}"
|
||||
results.append(await self.shell(cmd, take_screenshot=False))
|
||||
screenshot_base64 = (await self.screenshot()).base64_image
|
||||
return ToolResult(
|
||||
output="".join(result.output or "" for result in results),
|
||||
error="".join(result.error or "" for result in results),
|
||||
base64_image=screenshot_base64,
|
||||
)
|
||||
|
||||
if action in (
|
||||
"left_click",
|
||||
"right_click",
|
||||
"double_click",
|
||||
"middle_click",
|
||||
"screenshot",
|
||||
"cursor_position",
|
||||
):
|
||||
if text is not None:
|
||||
raise ToolError(f"text is not accepted for {action}")
|
||||
if coordinate is not None:
|
||||
raise ToolError(f"coordinate is not accepted for {action}")
|
||||
|
||||
if action == "screenshot":
|
||||
return await self.screenshot()
|
||||
elif action == "cursor_position":
|
||||
result = await self.shell(
|
||||
f"{self.xdotool} getmouselocation --shell",
|
||||
take_screenshot=False,
|
||||
)
|
||||
output = result.output or ""
|
||||
x, y = self.scale_coordinates(
|
||||
ScalingSource.COMPUTER,
|
||||
int(output.split("X=")[1].split("\n")[0]),
|
||||
int(output.split("Y=")[1].split("\n")[0]),
|
||||
)
|
||||
return result.replace(output=f"X={x},Y={y}")
|
||||
else:
|
||||
click_arg = {
|
||||
"left_click": "1",
|
||||
"right_click": "3",
|
||||
"middle_click": "2",
|
||||
"double_click": "--repeat 2 --delay 500 1",
|
||||
}[action]
|
||||
return await self.shell(f"{self.xdotool} click {click_arg}")
|
||||
|
||||
raise ToolError(f"Invalid action: {action}")
|
||||
|
||||
async def screenshot(self):
|
||||
"""Take a screenshot of the current screen and return the base64 encoded image."""
|
||||
output_dir = Path(OUTPUT_DIR)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = output_dir / f"screenshot_{uuid4().hex}.png"
|
||||
|
||||
# Try gnome-screenshot first
|
||||
if shutil.which("gnome-screenshot"):
|
||||
screenshot_cmd = f"{self._display_prefix}gnome-screenshot -f {path} -p"
|
||||
else:
|
||||
# Fall back to scrot if gnome-screenshot isn't available
|
||||
screenshot_cmd = f"{self._display_prefix}scrot -p {path}"
|
||||
|
||||
result = await self.shell(screenshot_cmd, take_screenshot=False)
|
||||
if self._scaling_enabled:
|
||||
x, y = self.scale_coordinates(
|
||||
ScalingSource.COMPUTER, self.width, self.height
|
||||
)
|
||||
await self.shell(
|
||||
f"convert {path} -resize {x}x{y}! {path}", take_screenshot=False
|
||||
)
|
||||
|
||||
if path.exists():
|
||||
return result.replace(
|
||||
base64_image=base64.b64encode(path.read_bytes()).decode()
|
||||
)
|
||||
raise ToolError(f"Failed to take screenshot: {result.error}")
|
||||
|
||||
async def shell(self, command: str, take_screenshot=True) -> ToolResult:
|
||||
"""Run a shell command and return the output, error, and optionally a screenshot."""
|
||||
_, stdout, stderr = await run(command)
|
||||
base64_image = None
|
||||
|
||||
if take_screenshot:
|
||||
# delay to let things settle before taking a screenshot
|
||||
await asyncio.sleep(self._screenshot_delay)
|
||||
base64_image = (await self.screenshot()).base64_image
|
||||
|
||||
return ToolResult(output=stdout, error=stderr, base64_image=base64_image)
|
||||
|
||||
def scale_coordinates(self, source: ScalingSource, x: int, y: int):
|
||||
"""Scale coordinates to a target maximum resolution."""
|
||||
if not self._scaling_enabled:
|
||||
return x, y
|
||||
ratio = self.width / self.height
|
||||
target_dimension = None
|
||||
for dimension in MAX_SCALING_TARGETS.values():
|
||||
# allow some error in the aspect ratio - not ratios are exactly 16:9
|
||||
if abs(dimension["width"] / dimension["height"] - ratio) < 0.02:
|
||||
if dimension["width"] < self.width:
|
||||
target_dimension = dimension
|
||||
break
|
||||
if target_dimension is None:
|
||||
return x, y
|
||||
# should be less than 1
|
||||
x_scaling_factor = target_dimension["width"] / self.width
|
||||
y_scaling_factor = target_dimension["height"] / self.height
|
||||
if source == ScalingSource.API:
|
||||
if x > self.width or y > self.height:
|
||||
raise ToolError(f"Coordinates {x}, {y} are out of bounds")
|
||||
# scale up
|
||||
return round(x / x_scaling_factor), round(y / y_scaling_factor)
|
||||
# scale down
|
||||
return round(x * x_scaling_factor), round(y * y_scaling_factor)
|
||||
290
mm_agents/anthropic/tools/edit.py
Normal file
290
mm_agents/anthropic/tools/edit.py
Normal file
@@ -0,0 +1,290 @@
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Literal, get_args, Optional, List
|
||||
|
||||
from anthropic.types.beta import BetaToolTextEditor20241022Param
|
||||
|
||||
from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
|
||||
from .run import maybe_truncate, run
|
||||
|
||||
Command = Literal[
|
||||
"view",
|
||||
"create",
|
||||
"str_replace",
|
||||
"insert",
|
||||
"undo_edit",
|
||||
]
|
||||
SNIPPET_LINES: int = 4
|
||||
|
||||
|
||||
class EditTool(BaseAnthropicTool):
|
||||
"""
|
||||
An filesystem editor tool that allows the agent to view, create, and edit files.
|
||||
The tool parameters are defined by Anthropic and are not editable.
|
||||
"""
|
||||
|
||||
api_type: Literal["text_editor_20241022"] = "text_editor_20241022"
|
||||
name: Literal["str_replace_editor"] = "str_replace_editor"
|
||||
|
||||
_file_history: dict[Path, list[str]]
|
||||
|
||||
def __init__(self):
|
||||
self._file_history = defaultdict(list)
|
||||
super().__init__()
|
||||
|
||||
def to_params(self) -> BetaToolTextEditor20241022Param:
|
||||
return {
|
||||
"name": self.name,
|
||||
"type": self.api_type,
|
||||
}
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
*,
|
||||
command: Command,
|
||||
path: str,
|
||||
file_text: Optional[str] = None,
|
||||
view_range: Optional[list[int]] = None,
|
||||
old_str: Optional[str] = None,
|
||||
new_str: Optional[str] = None,
|
||||
insert_line: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
_path = Path(path)
|
||||
self.validate_path(command, _path)
|
||||
if command == "view":
|
||||
return await self.view(_path, view_range)
|
||||
elif command == "create":
|
||||
if file_text is None:
|
||||
raise ToolError("Parameter `file_text` is required for command: create")
|
||||
self.write_file(_path, file_text)
|
||||
self._file_history[_path].append(file_text)
|
||||
return ToolResult(output=f"File created successfully at: {_path}")
|
||||
elif command == "str_replace":
|
||||
if old_str is None:
|
||||
raise ToolError(
|
||||
"Parameter `old_str` is required for command: str_replace"
|
||||
)
|
||||
return self.str_replace(_path, old_str, new_str)
|
||||
elif command == "insert":
|
||||
if insert_line is None:
|
||||
raise ToolError(
|
||||
"Parameter `insert_line` is required for command: insert"
|
||||
)
|
||||
if new_str is None:
|
||||
raise ToolError("Parameter `new_str` is required for command: insert")
|
||||
return self.insert(_path, insert_line, new_str)
|
||||
elif command == "undo_edit":
|
||||
return self.undo_edit(_path)
|
||||
raise ToolError(
|
||||
f'Unrecognized command {command}. The allowed commands for the {self.name} tool are: {", ".join(get_args(Command))}'
|
||||
)
|
||||
|
||||
def validate_path(self, command: str, path: Path):
|
||||
"""
|
||||
Check that the path/command combination is valid.
|
||||
"""
|
||||
# Check if its an absolute path
|
||||
if not path.is_absolute():
|
||||
suggested_path = Path("") / path
|
||||
raise ToolError(
|
||||
f"The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?"
|
||||
)
|
||||
# Check if path exists
|
||||
if not path.exists() and command != "create":
|
||||
raise ToolError(
|
||||
f"The path {path} does not exist. Please provide a valid path."
|
||||
)
|
||||
if path.exists() and command == "create":
|
||||
raise ToolError(
|
||||
f"File already exists at: {path}. Cannot overwrite files using command `create`."
|
||||
)
|
||||
# Check if the path points to a directory
|
||||
if path.is_dir():
|
||||
if command != "view":
|
||||
raise ToolError(
|
||||
f"The path {path} is a directory and only the `view` command can be used on directories"
|
||||
)
|
||||
|
||||
async def view(self, path: Path, view_range: Optional[List[int]] = None):
|
||||
"""Implement the view command"""
|
||||
if path.is_dir():
|
||||
if view_range:
|
||||
raise ToolError(
|
||||
"The `view_range` parameter is not allowed when `path` points to a directory."
|
||||
)
|
||||
|
||||
_, stdout, stderr = await run(
|
||||
rf"find {path} -maxdepth 2 -not -path '*/\.*'"
|
||||
)
|
||||
if not stderr:
|
||||
stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n"
|
||||
return CLIResult(output=stdout, error=stderr)
|
||||
|
||||
file_content = self.read_file(path)
|
||||
init_line = 1
|
||||
if view_range:
|
||||
if len(view_range) != 2 or not all(isinstance(i, int) for i in view_range):
|
||||
raise ToolError(
|
||||
"Invalid `view_range`. It should be a list of two integers."
|
||||
)
|
||||
file_lines = file_content.split("\n")
|
||||
n_lines_file = len(file_lines)
|
||||
init_line, final_line = view_range
|
||||
if init_line < 1 or init_line > n_lines_file:
|
||||
raise ToolError(
|
||||
f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
|
||||
)
|
||||
if final_line > n_lines_file:
|
||||
raise ToolError(
|
||||
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`"
|
||||
)
|
||||
if final_line != -1 and final_line < init_line:
|
||||
raise ToolError(
|
||||
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be larger or equal than its first `{init_line}`"
|
||||
)
|
||||
|
||||
if final_line == -1:
|
||||
file_content = "\n".join(file_lines[init_line - 1 :])
|
||||
else:
|
||||
file_content = "\n".join(file_lines[init_line - 1 : final_line])
|
||||
|
||||
return CLIResult(
|
||||
output=self._make_output(file_content, str(path), init_line=init_line)
|
||||
)
|
||||
|
||||
def str_replace(self, path: Path, old_str: str, new_str: Optional[str]):
|
||||
"""Implement the str_replace command, which replaces old_str with new_str in the file content"""
|
||||
# Read the file content
|
||||
file_content = self.read_file(path).expandtabs()
|
||||
old_str = old_str.expandtabs()
|
||||
new_str = new_str.expandtabs() if new_str is not None else ""
|
||||
|
||||
# Check if old_str is unique in the file
|
||||
occurrences = file_content.count(old_str)
|
||||
if occurrences == 0:
|
||||
raise ToolError(
|
||||
f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}."
|
||||
)
|
||||
elif occurrences > 1:
|
||||
file_content_lines = file_content.split("\n")
|
||||
lines = [
|
||||
idx + 1
|
||||
for idx, line in enumerate(file_content_lines)
|
||||
if old_str in line
|
||||
]
|
||||
raise ToolError(
|
||||
f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique"
|
||||
)
|
||||
|
||||
# Replace old_str with new_str
|
||||
new_file_content = file_content.replace(old_str, new_str)
|
||||
|
||||
# Write the new content to the file
|
||||
self.write_file(path, new_file_content)
|
||||
|
||||
# Save the content to history
|
||||
self._file_history[path].append(file_content)
|
||||
|
||||
# Create a snippet of the edited section
|
||||
replacement_line = file_content.split(old_str)[0].count("\n")
|
||||
start_line = max(0, replacement_line - SNIPPET_LINES)
|
||||
end_line = replacement_line + SNIPPET_LINES + new_str.count("\n")
|
||||
snippet = "\n".join(new_file_content.split("\n")[start_line : end_line + 1])
|
||||
|
||||
# Prepare the success message
|
||||
success_msg = f"The file {path} has been edited. "
|
||||
success_msg += self._make_output(
|
||||
snippet, f"a snippet of {path}", start_line + 1
|
||||
)
|
||||
success_msg += "Review the changes and make sure they are as expected. Edit the file again if necessary."
|
||||
|
||||
return CLIResult(output=success_msg)
|
||||
|
||||
def insert(self, path: Path, insert_line: int, new_str: str):
|
||||
"""Implement the insert command, which inserts new_str at the specified line in the file content."""
|
||||
file_text = self.read_file(path).expandtabs()
|
||||
new_str = new_str.expandtabs()
|
||||
file_text_lines = file_text.split("\n")
|
||||
n_lines_file = len(file_text_lines)
|
||||
|
||||
if insert_line < 0 or insert_line > n_lines_file:
|
||||
raise ToolError(
|
||||
f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}"
|
||||
)
|
||||
|
||||
new_str_lines = new_str.split("\n")
|
||||
new_file_text_lines = (
|
||||
file_text_lines[:insert_line]
|
||||
+ new_str_lines
|
||||
+ file_text_lines[insert_line:]
|
||||
)
|
||||
snippet_lines = (
|
||||
file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
|
||||
+ new_str_lines
|
||||
+ file_text_lines[insert_line : insert_line + SNIPPET_LINES]
|
||||
)
|
||||
|
||||
new_file_text = "\n".join(new_file_text_lines)
|
||||
snippet = "\n".join(snippet_lines)
|
||||
|
||||
self.write_file(path, new_file_text)
|
||||
self._file_history[path].append(file_text)
|
||||
|
||||
success_msg = f"The file {path} has been edited. "
|
||||
success_msg += self._make_output(
|
||||
snippet,
|
||||
"a snippet of the edited file",
|
||||
max(1, insert_line - SNIPPET_LINES + 1),
|
||||
)
|
||||
success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
|
||||
return CLIResult(output=success_msg)
|
||||
|
||||
def undo_edit(self, path: Path):
|
||||
"""Implement the undo_edit command."""
|
||||
if not self._file_history[path]:
|
||||
raise ToolError(f"No edit history found for {path}.")
|
||||
|
||||
old_text = self._file_history[path].pop()
|
||||
self.write_file(path, old_text)
|
||||
|
||||
return CLIResult(
|
||||
output=f"Last edit to {path} undone successfully. {self._make_output(old_text, str(path))}"
|
||||
)
|
||||
|
||||
def read_file(self, path: Path):
|
||||
"""Read the content of a file from a given path; raise a ToolError if an error occurs."""
|
||||
try:
|
||||
return path.read_text()
|
||||
except Exception as e:
|
||||
raise ToolError(f"Ran into {e} while trying to read {path}") from None
|
||||
|
||||
def write_file(self, path: Path, file: str):
|
||||
"""Write the content of a file to a given path; raise a ToolError if an error occurs."""
|
||||
try:
|
||||
path.write_text(file)
|
||||
except Exception as e:
|
||||
raise ToolError(f"Ran into {e} while trying to write to {path}") from None
|
||||
|
||||
def _make_output(
|
||||
self,
|
||||
file_content: str,
|
||||
file_descriptor: str,
|
||||
init_line: int = 1,
|
||||
expand_tabs: bool = True,
|
||||
):
|
||||
"""Generate output for the CLI based on the content of a file."""
|
||||
file_content = maybe_truncate(file_content)
|
||||
if expand_tabs:
|
||||
file_content = file_content.expandtabs()
|
||||
file_content = "\n".join(
|
||||
[
|
||||
f"{i + init_line:6}\t{line}"
|
||||
for i, line in enumerate(file_content.split("\n"))
|
||||
]
|
||||
)
|
||||
return (
|
||||
f"Here's the result of running `cat -n` on {file_descriptor}:\n"
|
||||
+ file_content
|
||||
+ "\n"
|
||||
)
|
||||
42
mm_agents/anthropic/tools/run.py
Normal file
42
mm_agents/anthropic/tools/run.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Utility to run shell commands asynchronously with a timeout."""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
TRUNCATED_MESSAGE: str = "<response clipped><NOTE>To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.</NOTE>"
|
||||
MAX_RESPONSE_LEN: int = 16000
|
||||
|
||||
|
||||
def maybe_truncate(content: str, truncate_after: Optional[int] = MAX_RESPONSE_LEN):
|
||||
"""Truncate content and append a notice if content exceeds the specified length."""
|
||||
return (
|
||||
content
|
||||
if not truncate_after or len(content) <= truncate_after
|
||||
else content[:truncate_after] + TRUNCATED_MESSAGE
|
||||
)
|
||||
|
||||
|
||||
async def run(
|
||||
cmd: str,
|
||||
timeout: Optional[float] = 120.0, # seconds
|
||||
truncate_after: Optional[int] = MAX_RESPONSE_LEN,
|
||||
):
|
||||
"""Run a shell command asynchronously with a timeout."""
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||
return (
|
||||
process.returncode or 0,
|
||||
maybe_truncate(stdout.decode(), truncate_after=truncate_after),
|
||||
maybe_truncate(stderr.decode(), truncate_after=truncate_after),
|
||||
)
|
||||
except asyncio.TimeoutError as exc:
|
||||
try:
|
||||
process.kill()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
raise TimeoutError(
|
||||
f"Command '{cmd}' timed out after {timeout} seconds"
|
||||
) from exc
|
||||
Reference in New Issue
Block a user