Feat/claude cua support (#253)
* feat: add claude support * feat: add script for end-to-end evaluation with logging and task distribution * feat&fix: add tool result handling and update model default in evaluation script * chore: remove run_test_env.py script * feat&fix: implement action parsing for tool calls and update default action space * fix: update text formatting in action parsing and replace logger import * feat&fix: implement action parsing for tool calls and add screen size handling * feat: add setup instructions for Anthropic API integration * feat: add notice about image size limitations for Anthropic API * Delete test_env/logger.py * Delete test_env/utils.py
This commit is contained in:
69
mm_agents/anthropic/tools/base.py
Normal file
69
mm_agents/anthropic/tools/base.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass, fields, replace
|
||||
from typing import Any, Optional
|
||||
|
||||
from anthropic.types.beta import BetaToolUnionParam
|
||||
|
||||
|
||||
class BaseAnthropicTool(metaclass=ABCMeta):
|
||||
"""Abstract base class for Anthropic-defined tools."""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, **kwargs) -> Any:
|
||||
"""Executes the tool with the given arguments."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def to_params(
|
||||
self,
|
||||
) -> BetaToolUnionParam:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass(frozen=True) #kw_only=True,
|
||||
class ToolResult:
|
||||
"""Represents the result of a tool execution."""
|
||||
|
||||
output: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
base64_image: Optional[str] = None
|
||||
system: Optional[str] = None
|
||||
|
||||
def __bool__(self):
|
||||
return any(getattr(self, field.name) for field in fields(self))
|
||||
|
||||
def __add__(self, other: "ToolResult"):
|
||||
def combine_fields(
|
||||
field: Optional[str], other_field: Optional[str], concatenate: bool = True
|
||||
):
|
||||
if field and other_field:
|
||||
if concatenate:
|
||||
return field + other_field
|
||||
raise ValueError("Cannot combine tool results")
|
||||
return field or other_field
|
||||
|
||||
return ToolResult(
|
||||
output=combine_fields(self.output, other.output),
|
||||
error=combine_fields(self.error, other.error),
|
||||
base64_image=combine_fields(self.base64_image, other.base64_image, False),
|
||||
system=combine_fields(self.system, other.system),
|
||||
)
|
||||
|
||||
def replace(self, **kwargs):
|
||||
"""Returns a new ToolResult with the given fields replaced."""
|
||||
return replace(self, **kwargs)
|
||||
|
||||
|
||||
class CLIResult(ToolResult):
|
||||
"""A ToolResult that can be rendered as a CLI output."""
|
||||
|
||||
|
||||
class ToolFailure(ToolResult):
|
||||
"""A ToolResult that represents a failure."""
|
||||
|
||||
|
||||
class ToolError(Exception):
|
||||
"""Raised when a tool encounters an error."""
|
||||
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
Reference in New Issue
Block a user