diff --git a/lib_run_single.py b/lib_run_single.py index d60fd7a..82b2dd3 100644 --- a/lib_run_single.py +++ b/lib_run_single.py @@ -2,7 +2,7 @@ import datetime import json import logging import os -import wandb +# import wandb from wrapt_timeout_decorator import * @@ -15,13 +15,13 @@ with open("./settings.json", "r") as file: time_limit = data["time_limit"] @timeout(time_limit, use_signals=False) -def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores, run): +def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores): agent.reset() obs = env.reset(task_config=example) done = False step_idx = 0 env.controller.start_recording() - str_table = wandb.Table(columns=["Screenshot", "A11T", "Modle Response", "Action", "Action timestamp", "Done"]) + # str_table = wandb.Table(columns=["Screenshot", "A11T", "Modle Response", "Action", "Action timestamp", "Done"]) while not done and step_idx < max_steps: response, actions = agent.predict( instruction, @@ -43,10 +43,10 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl _f.write(screenshot) # get a11tree and save to wandb thisrun_a11tree = env.controller.get_accessibility_tree() - str_table.add_data(wandb.Image(data_or_path=os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), caption=f"step_{step_idx + 1}_{action_timestamp}"), - thisrun_a11tree, - response, action, action_timestamp, done) - run.log({"Reward": reward}) + # str_table.add_data(wandb.Image(data_or_path=os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), caption=f"step_{step_idx + 1}_{action_timestamp}"), + # thisrun_a11tree, + # response, action, action_timestamp, done) + # run.log({"Reward": reward}) with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: f.write(json.dumps({ "step_num": step_idx + 1, @@ -62,11 +62,11 @@ def run_single_example(agent, env, example, max_steps, 
instruction, args, exampl logger.info("The episode is done.") break step_idx += 1 - run.log({"str_trajectory": str_table}) + # run.log({"str_trajectory": str_table}) result = env.evaluate() logger.info("Result: %.2f", result) scores.append(result) with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f: f.write(f"{result}\n") env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) - run.log({"Result": result}) + # run.log({"Result": result}) diff --git a/run.py b/run.py index 5212bc0..92e989a 100644 --- a/run.py +++ b/run.py @@ -8,7 +8,7 @@ import logging import os import random import sys -import wandb +# import wandb from tqdm import tqdm @@ -52,7 +52,8 @@ logger = logging.getLogger("desktopenv.experiment") # wandb config ### set your wandb api key here -wandb.login(key=os.environ.get("WANDB_API_KEY", None)) +# os.environ["WANDB_API_KEY"] = "<your-wandb-api-key>" +# wandb.login(key=os.environ.get("WANDB_API_KEY", None)) def config() -> argparse.Namespace: @@ -147,8 +148,8 @@ def test( for domain in tqdm(test_all_meta, desc="Domain"): for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False): - run = wandb.init(project=f"OSworld-{args.action_space}-{args.observation_type}-{args.model}", group=f"{domain}", - name=f"{example_id}") + # run = wandb.init(project=f"OSworld-{args.action_space}-{args.observation_type}-{args.model}", group=f"{domain}", + # name=f"{example_id}") # example setting config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json") with open(config_file, "r", encoding="utf-8") as f: @@ -163,7 +164,7 @@ def test( # wandb each example config settings cfg_args["instruction"] = instruction cfg_args["start_time"] = datetime.datetime.now().strftime("%Y:%m:%d-%H:%M:%S") - run.config.update(cfg_args) + # run.config.update(cfg_args) example_result_dir = os.path.join( args.result_dir, @@ -177,10 +178,10 @@ def test( # example start
running try: lib_run_single.run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, - scores, run) + scores) except Exception as e: logger.error(f"Exception in {domain}/{example_id}: {e}") - wandb.log({"Exception": wandb.Table(data=[[f"Exception in {domain}/{example_id}: {e}"]], columns=["Error"])}) + # wandb.log({"Exception": wandb.Table(data=[[f"Exception in {domain}/{example_id}: {e}"]], columns=["Error"])}) env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: f.write(json.dumps({ @@ -188,11 +189,11 @@ def test( })) f.write("\n") # wandb settings - os.mkdir(os.path.join(wandb.run.dir, "results/")) - for file in os.listdir(example_result_dir): - # move file to just under the root dir - os.rename(os.path.join(example_result_dir, file), os.path.join(wandb.run.dir, f"./results/{file}")) - wandb.finish() + # os.mkdir(os.path.join(wandb.run.dir, "results/")) + # for file in os.listdir(example_result_dir): + # # move file to just under the root dir + # os.rename(os.path.join(example_result_dir, file), os.path.join(wandb.run.dir, f"./results/{file}")) + # wandb.finish() env.close() logger.info(f"Average score: {sum(scores) / len(scores)}")