backup wip
@@ -155,11 +155,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())

-    td_policy = TensorDictModule(
-        policy,
-        in_keys=["observation", "step_count"],
-        out_keys=["action"],
-    )
+    td_policy = TensorDictModule(policy, in_keys=["observation", "step_count"], out_keys=["action"])

     # log metrics to terminal and wandb
     logger = Logger(out_dir, job_name, cfg)
@@ -174,19 +170,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
     logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

-    step = 0  # number of policy update (forward + backward + optim)
-
-    is_offline = True
-    for offline_step in range(cfg.offline_steps):
-        if offline_step == 0:
-            logging.info("Start offline training on a fixed dataset")
-        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
-        policy.train()
-        train_info = policy.update(offline_buffer, step)
-        if step % cfg.log_freq == 0:
-            log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
-
-        if step > 0 and step % cfg.eval_freq == 0:
+    # Note: this helper will be used in offline and online training loops.
+    def _maybe_eval_and_maybe_save(step):
+        if step % cfg.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
             eval_info, first_video = eval_policy(
                 env,
@@ -202,11 +188,27 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 logger.log_video(first_video, step, mode="eval")
             logging.info("Resume training")

-        if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
-            logging.info(f"Checkpoint policy at step {step}")
+        if cfg.save_model and step % cfg.save_freq == 0:
+            logging.info(f"Checkpoint policy after step {step}")
             logger.save_model(policy, identifier=step)
             logging.info("Resume training")

+    step = 0  # number of policy update (forward + backward + optim)
+
+    is_offline = True
+    for offline_step in range(cfg.offline_steps):
+        if offline_step == 0:
+            logging.info("Start offline training on a fixed dataset")
+        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
+        policy.train()
+        train_info = policy.update(offline_buffer, step)
+        if step % cfg.log_freq == 0:
+            log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
+
+        # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
+        # step + 1.
+        _maybe_eval_and_maybe_save(step + 1)
+
         step += 1

     demo_buffer = offline_buffer if cfg.policy.balanced_sampling else None
@@ -248,24 +250,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 train_info.update(rollout_info)
                 log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)

-            if step > 0 and step % cfg.eval_freq == 0:
-                logging.info(f"Eval policy at step {step}")
-                eval_info, first_video = eval_policy(
-                    env,
-                    td_policy,
-                    num_episodes=cfg.eval_episodes,
-                    max_steps=cfg.env.episode_length // cfg.n_action_steps,
-                    return_first_video=True,
-                )
-                log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
-                if cfg.wandb.enable:
-                    logger.log_video(first_video, step, mode="eval")
-                logging.info("Resume training")
-
-            if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
-                logging.info(f"Checkpoint policy at step {step}")
-                logger.save_model(policy, identifier=step)
-                logging.info("Resume training")
+            # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
+            # in step + 1.
+            _maybe_eval_and_maybe_save(step + 1)

             step += 1
             online_step += 1
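
For reference, stitching together the context and added lines from the hunks above, the new eval/checkpoint helper and its call site read roughly as in the sketch below. This is a reconstruction from the diff, not the verbatim file: the helper is defined inside train(), so cfg, env, td_policy, policy, logger, offline_buffer and is_offline all come from the enclosing scope rather than being passed in.

    # Sketch only: the refactored helper assembled from the hunks above.
    def _maybe_eval_and_maybe_save(step):
        if step % cfg.eval_freq == 0:
            logging.info(f"Eval policy at step {step}")
            eval_info, first_video = eval_policy(
                env,
                td_policy,
                num_episodes=cfg.eval_episodes,
                max_steps=cfg.env.episode_length // cfg.n_action_steps,
                return_first_video=True,
            )
            log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
            if cfg.wandb.enable:
                logger.log_video(first_video, step, mode="eval")
            logging.info("Resume training")

        if cfg.save_model and step % cfg.save_freq == 0:
            logging.info(f"Checkpoint policy after step {step}")
            logger.save_model(policy, identifier=step)
            logging.info("Resume training")

    # Both the offline and online loops call it after each update; step + 1 is passed
    # because the `step`th training update has already completed at that point.
    _maybe_eval_and_maybe_save(step + 1)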