diff --git a/lib_run_single.py b/lib_run_single.py
index bcf2496..d60fd7a 100644
--- a/lib_run_single.py
+++ b/lib_run_single.py
@@ -46,7 +46,7 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
             str_table.add_data(wandb.Image(data_or_path=os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), caption=f"step_{step_idx + 1}_{action_timestamp}"),
                                thisrun_a11tree,
                                response, action, action_timestamp, done)
-            run.log({"Reward": reward})
+            run.log({"Reward": reward})
             with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                 f.write(json.dumps({
                     "step_num": step_idx + 1,
@@ -62,7 +62,6 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
                 logger.info("The episode is done.")
                 break
         step_idx += 1
-    # wandb.log({"str_trajectory": str_table})
     run.log({"str_trajectory": str_table})
     result = env.evaluate()
     logger.info("Result: %.2f", result)
@@ -71,4 +70,3 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
         f.write(f"{result}\n")
     env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
     run.log({"Result": result})
-    # wandb.log({"Result": result})
diff --git a/run.py b/run.py
index 505ae54..728bea4 100644
--- a/run.py
+++ b/run.py
@@ -52,8 +52,7 @@ logger = logging.getLogger("desktopenv.experiment")
 
 # wandb config
 ### set your wandb api key here
-os.environ["WANDB_API_KEY"] = ""
-wandb.login(key=os.environ["WANDB_API_KEY"])
+wandb.login(key=os.environ.get("WANDB_API_KEY", None))
 
 
 def config() -> argparse.Namespace:
@@ -148,7 +147,7 @@ def test(
 
     for domain in tqdm(test_all_meta, desc="Domain"):
         for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
-            run = wandb.init(project=f"OSworld-{args.action_space}-{args.observation_type}-{args.model}", group=f"{domain}",
+            run = wandb.init(project=f"OSworld-{args.action_space}-{args.observation_type}-{args.model}", group=f"{domain}",
                              name=f"{example_id}")
             # example setting
             config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
@@ -164,7 +163,7 @@ def test(
             # wandb each example config settings
             cfg_args["instruction"] = instruction
             cfg_args["start_time"] = datetime.datetime.now().strftime("%Y:%m:%d-%H:%M:%S")
-            run.config.update(cfg_args)
+            run.config.update(cfg_args)
 
             example_result_dir = os.path.join(
                 args.result_dir,
@@ -279,10 +278,10 @@ if __name__ == '__main__':
         left_info += f"{domain}: {len(test_file_list[domain])}\n"
     logger.info(f"Left tasks:\n{left_info}")
 
-    # get_result(args.action_space,
-    #     args.model,
-    #     args.observation_type,
-    #     args.result_dir,
-    #     test_all_meta
-    # )
+    get_result(args.action_space,
+        args.model,
+        args.observation_type,
+        args.result_dir,
+        test_all_meta
+    )
     test(args, test_file_list)