Add Claude Sonnet 4.5 support and improve action handling (#362)
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -13,6 +13,7 @@ import time
|
||||
from typing import List
|
||||
from multiprocessing import Process, Manager, current_process
|
||||
import lib_run_single
|
||||
from lib_results_logger import log_task_error
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
from mm_agents.anthropic import AnthropicAgent
|
||||
|
||||
@@ -67,17 +68,27 @@ def config() -> argparse.Namespace:
|
||||
)
|
||||
|
||||
# lm config
|
||||
parser.add_argument("--model", type=str, default="claude-4-sonnet-20250514")
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--top_p", type=float, default=0.9)
|
||||
parser.add_argument("--max_tokens", type=int, default=1500)
|
||||
parser.add_argument("--model", type=str, default="")
|
||||
parser.add_argument("--temperature", type=float, default=None)
|
||||
parser.add_argument("--top_p", type=float, default=None)
|
||||
parser.add_argument("--max_tokens", type=int, default=3000)
|
||||
parser.add_argument("--stop_token", type=str, default=None)
|
||||
|
||||
# thinking mode config
|
||||
parser.add_argument("--no-thinking", action="store_true",
|
||||
help="Disable thinking mode (no scratchpad)")
|
||||
parser.add_argument("--use-isp", action="store_true",
|
||||
help="Use interleaved scratchpad (ISP) mode")
|
||||
|
||||
# example config
|
||||
parser.add_argument("--domain", type=str, default="all")
|
||||
parser.add_argument(
|
||||
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--specific_task_id", type=str, default=None,
|
||||
help="Run only a specific task ID (overrides domain filtering)"
|
||||
)
|
||||
|
||||
# logging related
|
||||
parser.add_argument("--result_dir", type=str, default="./results")
|
||||
@@ -95,6 +106,37 @@ def config() -> argparse.Namespace:
|
||||
|
||||
args = config() # Get command line arguments first
|
||||
|
||||
# Validate that model is specified to prevent accidental usage with empty model
|
||||
if not args.model or args.model.strip() == "":
|
||||
print("ERROR: Model must be specified. Use --model <model_name>")
|
||||
print("Example: --model claude-sonnet-4-5-20250929")
|
||||
sys.exit(1)
|
||||
|
||||
# Validate model support before proceeding
|
||||
from mm_agents.anthropic.utils import validate_model_support
|
||||
|
||||
# Pass same temperature/top_p and thinking parameters as will be used by the agent
|
||||
validation_kwargs = {}
|
||||
if args.temperature is not None:
|
||||
validation_kwargs['temperature'] = args.temperature
|
||||
if args.top_p is not None:
|
||||
validation_kwargs['top_p'] = args.top_p
|
||||
validation_kwargs['no_thinking'] = args.no_thinking
|
||||
validation_kwargs['use_isp'] = args.use_isp
|
||||
|
||||
if not validate_model_support(args.model, **validation_kwargs):
|
||||
print(f"\n💥 Model '{args.model}' api sample failed")
|
||||
sys.exit(1)
|
||||
|
||||
# Validate thinking mode options are mutually exclusive
|
||||
if args.no_thinking and args.use_isp:
|
||||
print("ERROR: --no-thinking and --use-isp are mutually exclusive")
|
||||
print("Choose one of:")
|
||||
print(" (default): Regular scratchpad mode")
|
||||
print(" --no-thinking: Disable thinking/scratchpad")
|
||||
print(" --use-isp: Use interleaved scratchpad (ISP)")
|
||||
sys.exit(1)
|
||||
|
||||
logger = logging.getLogger()
|
||||
log_level = getattr(logging, args.log_level.upper())
|
||||
logger.setLevel(log_level)
|
||||
@@ -182,7 +224,7 @@ def run_env_tasks(task_queue, args, shared_scores):
|
||||
headless=args.headless,
|
||||
os_type="Ubuntu",
|
||||
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
enable_proxy=True,
|
||||
enable_proxy=False,
|
||||
client_password=args.client_password
|
||||
)
|
||||
active_environments.append(env)
|
||||
@@ -196,8 +238,9 @@ def run_env_tasks(task_queue, args, shared_scores):
|
||||
observation_type=args.observation_type,
|
||||
max_trajectory_length=args.max_trajectory_length,
|
||||
provider_name=args.provider_name,
|
||||
screen_width=args.screen_width,
|
||||
screen_height=args.screen_height,
|
||||
screen_size=(args.screen_width, args.screen_height),
|
||||
no_thinking=getattr(args, 'no_thinking', False),
|
||||
use_isp=getattr(args, 'use_isp', False),
|
||||
)
|
||||
logger.info(f"Process {current_process().name} started.")
|
||||
while True:
|
||||
@@ -239,6 +282,14 @@ def run_env_tasks(task_queue, args, shared_scores):
|
||||
import traceback
|
||||
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# Log error to results.json
|
||||
try:
|
||||
example = {"id": example_id} # Create minimal example dict for error logging
|
||||
log_task_error(example, str(e), example_result_dir, args)
|
||||
except Exception as log_e:
|
||||
logger.error(f"Failed to log error to results.json: {log_e}")
|
||||
|
||||
try:
|
||||
env.controller.end_recording(
|
||||
os.path.join(example_result_dir, "recording.mp4")
|
||||
@@ -479,7 +530,28 @@ if __name__ == "__main__":
|
||||
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
||||
test_all_meta = json.load(f)
|
||||
|
||||
if args.domain != "all":
|
||||
# Filter for specific task ID if provided
|
||||
if args.specific_task_id:
|
||||
logger.info(f"Filtering for specific task ID: {args.specific_task_id}")
|
||||
filtered_meta = {}
|
||||
task_found = False
|
||||
|
||||
for domain, task_ids in test_all_meta.items():
|
||||
for task_id in task_ids:
|
||||
if task_id == args.specific_task_id:
|
||||
filtered_meta[domain] = [task_id]
|
||||
task_found = True
|
||||
logger.info(f"Found task {args.specific_task_id} in domain: {domain}")
|
||||
break
|
||||
if task_found:
|
||||
break
|
||||
|
||||
if not task_found:
|
||||
logger.error(f"Task ID {args.specific_task_id} not found in test file!")
|
||||
sys.exit(1)
|
||||
|
||||
test_all_meta = filtered_meta
|
||||
elif args.domain != "all":
|
||||
test_all_meta = {args.domain: test_all_meta[args.domain]}
|
||||
|
||||
test_file_list = get_unfinished(
|
||||
|
||||
Reference in New Issue
Block a user