diff --git a/desktop_env/controllers/python.py b/desktop_env/controllers/python.py index 60a4bb4..4159cde 100644 --- a/desktop_env/controllers/python.py +++ b/desktop_env/controllers/python.py @@ -263,16 +263,19 @@ class PythonController: """ Ends recording the screen. """ - response = requests.post(self.http_server + "/end_recording") - if response.status_code == 200: - logger.info("Recording stopped successfully") - with open(dest, 'wb') as f: - for chunk in response.iter_content(chunk_size=8192): - if chunk: - f.write(chunk) - else: - logger.error("Failed to stop recording. Status code: %d", response.status_code) - return None + try: + response = requests.post(self.http_server + "/end_recording") + if response.status_code == 200: + logger.info("Recording stopped successfully") + with open(dest, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + else: + logger.error("Failed to stop recording. Status code: %d", response.status_code) + return None + except Exception as e: + logger.error("An error occurred while trying to download the recording: %s", e) # Additional info def get_vm_platform(self): diff --git a/lib_run_single.py b/lib_run_single.py index e492736..ff9972d 100644 --- a/lib_run_single.py +++ b/lib_run_single.py @@ -2,6 +2,7 @@ import datetime import json import logging import os +import wandb from wrapt_timeout_decorator import * @@ -13,7 +14,6 @@ with open("./settings.json", "r") as file: data = json.load(file) time_limit = data["time_limit"] - @timeout(time_limit, use_signals=False) def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores): agent.reset() @@ -21,9 +21,9 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl done = False step_idx = 0 env.controller.start_recording() - + str_table = wandb.Table(columns=["Screenshot", "A11T", "Modle Response", "Action", "Action timestamp", "Done"]) while not done and step_idx < max_steps: - actions = agent.predict( + response, actions = agent.predict( instruction, obs ) @@ -31,20 +31,22 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl # Capture the timestamp before executing the action action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") logger.info("Step %d: %s", step_idx + 1, action) - obs, reward, done, info = env.step(action, args.sleep_after_execution) logger.info("Reward: %.2f", reward) logger.info("Done: %s", done) - logger.info("Info: %s", info) - # Save screenshot and trajectory information with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), "wb") as _f: with open(obs['screenshot'], "rb") as __f: screenshot = __f.read() _f.write(screenshot) - + # get a11tree and save to wandb + thisrun_a11tree = env.controller.get_accessibility_tree() + str_table.add_data(wandb.Image(data_or_path=os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), caption=f"step_{step_idx + 1}_{action_timestamp}"), + thisrun_a11tree, + response, action, action_timestamp, done) + wandb.log({"Reward": reward}) with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: f.write(json.dumps({ "step_num": step_idx + 1, @@ -56,14 +58,15 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png" })) f.write("\n") - if done: logger.info("The episode is done.") break step_idx += 1 + wandb.log({"str_trajectory": str_table}) result = env.evaluate() logger.info("Result: %.2f", result) scores.append(result) with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f: f.write(f"{result}\n") env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) + wandb.log({"Result": result}) diff --git a/mm_agents/agent.py b/mm_agents/agent.py index 7599b02..cb0ba85 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -15,6 +15,7 @@ import backoff import dashscope import google.generativeai as genai import requests +import wandb from PIL import Image from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes @@ -441,7 +442,7 @@ class PromptAgent: actions = None self.thoughts.append("") - return actions + return response, actions @backoff.on_exception( backoff.expo, diff --git a/run.py b/run.py index 3014e87..28563c8 100644 --- a/run.py +++ b/run.py @@ -7,6 +7,7 @@ import json import logging import os import sys +import wandb from tqdm import tqdm @@ -48,6 +49,11 @@ logger.addHandler(sdebug_handler) logger = logging.getLogger("desktopenv.experiment") +# wandb config +### set your wandb api key here +os.environ["WANDB_API_KEY"] = "" +wandb.login(key=os.environ["WANDB_API_KEY"]) + def config() -> argparse.Namespace: parser = argparse.ArgumentParser( @@ -104,6 +110,25 @@ def test( # log args logger.info("Args: %s", args) + # set wandb project + cfg_args = \ + { + "path_to_vm": args.path_to_vm, + "headless": args.headless, + "action_space": args.action_space, + "observation_type": args.observation_type, + "screen_width": args.screen_width, + "screen_height": args.screen_height, + "sleep_after_execution": args.sleep_after_execution, + "max_steps": args.max_steps, + "max_trajectory_length": args.max_trajectory_length, + "model": args.model, + "temperature": args.temperature, + "top_p": args.top_p, + "max_tokens": args.max_tokens, + "stop_token": args.stop_token, + "result_dir": args.result_dir + } agent = PromptAgent( model=args.model, @@ -122,6 +147,8 @@ def test( for domain in tqdm(test_all_meta, desc="Domain"): for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False): + wandb.init(project=f"OSworld-{args.action_space}-{args.observation_type}-{args.model}", group=f"{domain}", + name=f"{example_id}") # example setting config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json") with open(config_file, "r", encoding="utf-8") as f: @@ -133,6 +160,10 @@ def test( instruction = example["instruction"] logger.info(f"[Instruction]: {instruction}") + # wandb each example config settings + cfg_args["instruction"] = instruction + cfg_args["start_time"] = datetime.datetime.now().strftime("%Y:%m:%d-%H:%M:%S") + wandb.config.update(cfg_args) example_result_dir = os.path.join( args.result_dir, @@ -148,13 +179,20 @@ def test( lib_run_single.run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores) except Exception as e: + logger.error(f"Exception in {domain}/{example_id}: {e}") + wandb.log({"Exception": wandb.Table(data=[[f"Exception in {domain}/{example_id}: {e}"]], columns=["Error"])}) env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) - logger.error(f"Time limit exceeded in {domain}/{example_id}") with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: f.write(json.dumps({ "Error": f"Time limit exceeded in {domain}/{example_id}" })) f.write("\n") + # wandb settings + os.mkdir(os.path.join(wandb.run.dir, "results/")) + for file in os.listdir(example_result_dir): + # move file to just under the root dir + os.rename(os.path.join(example_result_dir, file), os.path.join(wandb.run.dir, f"./results/{file}")) + wandb.finish() env.close() logger.info(f"Average score: {sum(scores) / len(scores)}") @@ -235,11 +273,10 @@ if __name__ == '__main__': left_info += f"{domain}: {len(test_file_list[domain])}\n" logger.info(f"Left tasks:\n{left_info}") - get_result(args.action_space, - args.model, - args.observation_type, - args.result_dir, - test_all_meta - ) - - # test(args, test_all_meta) + # get_result(args.action_space, + # args.model, + # args.observation_type, + # args.result_dir, + # test_all_meta + # ) + test(args, test_file_list) diff --git a/settings.json b/settings.json index 469579c..23bab77 100644 --- a/settings.json +++ b/settings.json @@ -1,3 +1,3 @@ { - "time_limit": "1200" + "time_limit": "10" } \ No newline at end of file