Wxy/opencua (#260)
* OpenCUA Agent code base * update url * debug, modify url input * debug opencua * show result * debug agent history overlap * modify opencua agent; add comment lines
This commit is contained in:
@@ -20,8 +20,6 @@ active_environments = []
|
||||
processes = []
|
||||
is_terminating = False
|
||||
|
||||
# import wandb
|
||||
|
||||
# load the environment variables from .env file
|
||||
if os.path.exists(".env"):
|
||||
from dotenv import load_dotenv
|
||||
@@ -47,17 +45,8 @@ def config() -> argparse.Namespace:
|
||||
default="screenshot",
|
||||
help="Observation type",
|
||||
)
|
||||
parser.add_argument("--screen_width", type=int, default=1920)
|
||||
parser.add_argument("--screen_height", type=int, default=1080)
|
||||
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
|
||||
parser.add_argument("--max_steps", type=int, default=15)
|
||||
|
||||
# agent config
|
||||
parser.add_argument("--cot_level", type=str, default="l2", help="CoT version: l0, l1, l2, l3")
|
||||
parser.add_argument("--history_type", type=str, default="action_history", help="History: action history")
|
||||
parser.add_argument("--coordinate_type", type=str, default="relative", help="type of coordinate", choices=["relative", "qwen25"])
|
||||
parser.add_argument("--max_image_history_length", type=int, default=3)
|
||||
parser.add_argument("--detail_history_length", type=int, default=0, help="length of detail history")
|
||||
|
||||
# evaluation config
|
||||
parser.add_argument(
|
||||
@@ -71,6 +60,12 @@ def config() -> argparse.Namespace:
|
||||
parser.add_argument("--max_tokens", type=int, default=1500)
|
||||
parser.add_argument("--stop_token", type=str, default=None)
|
||||
|
||||
# OpenCUAagent config
|
||||
parser.add_argument("--cot_level", type=str, default="l2", help="CoT version: l1, l2, l3. Default is l2 includes 'thought' and 'action'")
|
||||
parser.add_argument("--history_type", type=str, default="action_history", help="Use action to represent history steps", choices=["action_history", "thought_history", "observation_history"])
|
||||
parser.add_argument("--coordinate_type", type=str, default="relative", help="Type of coordinate: Qwen2-VL or Kimi-VL based models use 'relative'; Qwen2.5-VL based models use 'qwen25'", choices=["relative", "qwen25"])
|
||||
parser.add_argument("--max_image_history_length", type=int, default=3, help="The max number of images in the history.")
|
||||
|
||||
# example config
|
||||
parser.add_argument("--domain", type=str, default="all")
|
||||
parser.add_argument(
|
||||
@@ -86,6 +81,18 @@ def config() -> argparse.Namespace:
|
||||
parser.add_argument(
|
||||
"--region", type=str, default="us-east-1", help="AWS region for the VM"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--client_password", type=str, default="", help="Client password"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--screen_width", type=int, default=1920, help="Screen width"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--screen_height", type=int, default=1080, help="Screen height"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
@@ -187,36 +194,24 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
|
||||
signal.signal(signal.SIGTERM, lambda signum, frame: process_signal_handler(signum, frame, env_idx))
|
||||
|
||||
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||
REGION = "us-east-1"
|
||||
REGION = args.region
|
||||
screen_size = (args.screen_width, args.screen_height)
|
||||
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
|
||||
env = DesktopEnv(
|
||||
path_to_vm=args.path_to_vm,
|
||||
action_space=args.action_space,
|
||||
|
||||
provider_name="aws",
|
||||
provider_name=args.provider_name,
|
||||
region=REGION,
|
||||
snapshot_name=IMAGE_ID_MAP[REGION],
|
||||
|
||||
screen_size=(args.screen_width, args.screen_height),
|
||||
snapshot_name=ami_id,
|
||||
screen_size=screen_size,
|
||||
headless=args.headless,
|
||||
os_type="Ubuntu",
|
||||
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
enable_proxy=True,
|
||||
client_password=args.client_password
|
||||
)
|
||||
active_environments.append(env)
|
||||
agent = OpenCUAAgent(
|
||||
env=env,
|
||||
model=args.model,
|
||||
max_tokens=args.max_tokens,
|
||||
top_p=args.top_p,
|
||||
temperature=args.temperature,
|
||||
action_space=args.action_space,
|
||||
observation_type=args.observation_type,
|
||||
cot_level=args.cot_level,
|
||||
history_type=args.history_type,
|
||||
screen_size=(args.screen_width, args.screen_height),
|
||||
coordinate_type=args.coordinate_type,
|
||||
max_image_history_length=args.max_image_history_length,
|
||||
detail_history_length=args.detail_history_length,
|
||||
)
|
||||
|
||||
logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
|
||||
|
||||
try:
|
||||
@@ -242,6 +237,21 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
|
||||
)
|
||||
os.makedirs(example_result_dir, exist_ok=True)
|
||||
|
||||
agent = OpenCUAAgent(
|
||||
env=env,
|
||||
model=args.model,
|
||||
max_tokens=args.max_tokens,
|
||||
top_p=args.top_p,
|
||||
temperature=args.temperature,
|
||||
action_space=args.action_space,
|
||||
observation_type=args.observation_type,
|
||||
cot_level=args.cot_level,
|
||||
history_type=args.history_type,
|
||||
screen_size=(args.screen_width, args.screen_height),
|
||||
coordinate_type=args.coordinate_type,
|
||||
max_image_history_length=args.max_image_history_length,
|
||||
)
|
||||
|
||||
try:
|
||||
lib_run_single.run_single_example_opencua(
|
||||
agent,
|
||||
|
||||
Reference in New Issue
Block a user