Refactor env to add keyword arguments from config yaml (#223)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Author: Remi
Date: 2024-05-30 13:45:22 +02:00
Committed by: GitHub
Parent: 2c2e4e14ed
Commit: 265b0ec44d
7 changed files with 59 additions and 36 deletions


@@ -281,8 +281,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("make_dataset")
offline_dataset = make_dataset(cfg)
logging.info("make_env")
eval_env = make_env(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
# using the eval.py instead, with gym_dora environment and dora-rs.
if cfg.training.eval_freq > 0:
logging.info("make_env")
eval_env = make_env(cfg)
logging.info("make_policy")
policy = make_policy(
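The make_env refactor named in the commit title lives in the other changed files, not in this hunk. As a rough sketch only: a factory along these lines would forward extra keyword arguments from the config yaml to gym.make, with n_envs overriding the config's default env count (cfg.env.gym, cfg.env.task, and cfg.eval.batch_size are assumed names for illustration, not confirmed by this diff):

    import gymnasium as gym
    from omegaconf import DictConfig, OmegaConf

    def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv:
        """Vectorized env factory; `n_envs` overrides the config's default count."""
        if n_envs is not None and n_envs < 1:
            raise ValueError("`n_envs` must be at least 1")
        num_envs = n_envs if n_envs is not None else cfg.eval.batch_size
        # Any extra keyword arguments declared under `env.gym` in the yaml are
        # converted to a plain dict and forwarded to gym.make.
        gym_cfg = cfg.env.get("gym")
        gym_kwargs = OmegaConf.to_container(gym_cfg, resolve=True) if gym_cfg is not None else {}
        return gym.vector.AsyncVectorEnv(
            [lambda: gym.make(cfg.env.task, **gym_kwargs) for _ in range(num_envs)]
        )

Under this assumption, the make_env(cfg, n_envs=1) call added further down yields a single-instance env for online rollouts while evaluation keeps a config-sized batch of envs.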
@@ -315,7 +319,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     # Note: this helper will be used in offline and online training loops.
     def evaluate_and_checkpoint_if_needed(step):
-        if step % cfg.training.eval_freq == 0:
+        if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
             with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
                 eval_info = eval_policy(
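The rewritten condition is not just cosmetic: with evaluation disabled (eval_freq set to 0), the bare modulo would raise ZeroDivisionError on the first call. A minimal illustration of the short-circuit:

    def should_eval(step: int, eval_freq: int) -> bool:
        # `and` short-circuits, so `step % 0` is never evaluated when eval is disabled.
        return eval_freq > 0 and step % eval_freq == 0

    assert should_eval(200, 100)      # evaluate every 100 steps
    assert not should_eval(150, 100)  # between evaluation points
    assert not should_eval(150, 0)    # eval disabled: no ZeroDivisionError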
@@ -349,7 +353,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     # create dataloader for offline training
     dataloader = torch.utils.data.DataLoader(
         offline_dataset,
-        num_workers=4,
+        num_workers=cfg.training.num_workers,
         batch_size=cfg.training.batch_size,
         shuffle=True,
         pin_memory=device.type != "cpu",
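Replacing the hardcoded 4 means the worker count now comes from the config, so it can be tuned per machine, e.g. via a Hydra CLI override such as training.num_workers=0 (assuming the matching key was added to the default yaml among the other changed files). A self-contained sketch of the same wiring with stand-in values:

    import torch
    from omegaconf import OmegaConf

    # Stand-in config mirroring the assumed yaml keys.
    cfg = OmegaConf.create({"training": {"num_workers": 2, "batch_size": 8}})

    if __name__ == "__main__":  # guard required since workers run in subprocesses
        offline_dataset = torch.utils.data.TensorDataset(torch.zeros(32, 1))  # toy dataset
        dataloader = torch.utils.data.DataLoader(
            offline_dataset,
            num_workers=cfg.training.num_workers,  # previously hardcoded to 4
            batch_size=cfg.training.batch_size,
            shuffle=True,
        )
        print(next(iter(dataloader))[0].shape)  # torch.Size([8, 1])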
@@ -386,6 +390,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         step += 1
 
     logging.info("End of offline training")
 
     if cfg.training.online_steps == 0:
+        if cfg.training.eval_freq > 0:
+            eval_env.close()
         return
 
+    # create an env dedicated to online episode collection from policy rollouts
+    online_training_env = make_env(cfg, n_envs=1)
+
     # create an empty online dataset similar to offline dataset
     online_dataset = deepcopy(offline_dataset)
     online_dataset.hf_dataset = {}
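The deepcopy-then-empty pattern gives the online buffer the offline dataset's metadata and schema without carrying over its frames. A toy illustration of the idea (the class is hypothetical; only the hf_dataset attribute name comes from the diff):

    from copy import deepcopy

    class ToyDataset:
        def __init__(self):
            self.fps = 10                        # metadata the copy should keep
            self.hf_dataset = {"action": [0.0]}  # underlying frames, emptied below

    offline_dataset = ToyDataset()
    online_dataset = deepcopy(offline_dataset)
    online_dataset.hf_dataset = {}  # same metadata and schema, zero frames

    assert online_dataset.fps == offline_dataset.fps
    assert not online_dataset.hf_dataset and offline_dataset.hf_dataset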
@@ -406,8 +420,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         drop_last=False,
     )
 
-    eval_env.close()
-    logging.info("End of training")
+    logging.info("End of online training")
+
+    if cfg.training.eval_freq > 0:
+        eval_env.close()
+    online_training_env.close()
 
 
 @hydra.main(version_base="1.2", config_name="default", config_path="../configs")
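Teardown mirrors setup here: eval_env is only bound when eval_freq > 0, so the previously unconditional close would now hit an undefined name, while the dedicated online env always exists on this code path and is always closed. A minimal reproduction of the hazard the guard avoids:

    eval_freq = 0
    if eval_freq > 0:
        eval_env = object()  # stand-in for make_env(cfg)
    try:
        eval_env.close()  # unguarded close, as before this commit
    except NameError:
        print("eval_env was never created; the close must be guarded")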