Ran `pre-commit run --all-files`

This commit is contained in:
Simon Alibert
2024-02-29 13:37:48 +01:00
parent 0e0e760e43
commit 7e024fdce6
16 changed files with 124 additions and 237 deletions

View File

@@ -6,8 +6,6 @@ import torch
from tensordict.nn import TensorDictModule
from termcolor import colored
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
from torchrl.data.datasets.openx import OpenXExperienceReplay
from torchrl.data.replay_buffers import PrioritizedSliceSampler
from lerobot.common.datasets.factory import make_offline_buffer
@@ -27,9 +25,7 @@ def train_cli(cfg: dict):
)
def train_notebook(
out_dir=None, job_name=None, config_name="default", config_path="../configs"
):
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear()
@@ -38,7 +34,7 @@ def train_notebook(
train(cfg, out_dir=out_dir, job_name=job_name)
def log_training_metrics(L, metrics, step, online_episode_idx, start_time, is_offline):
def log_training_metrics(logger, metrics, step, online_episode_idx, start_time, is_offline):
common_metrics = {
"episode": online_episode_idx,
"step": step,
@@ -46,12 +42,10 @@ def log_training_metrics(L, metrics, step, online_episode_idx, start_time, is_of
"is_offline": float(is_offline),
}
metrics.update(common_metrics)
L.log(metrics, category="train")
logger.log(metrics, category="train")
def eval_policy_and_log(
env, td_policy, step, online_episode_idx, start_time, cfg, L, is_offline
):
def eval_policy_and_log(env, td_policy, step, online_episode_idx, start_time, cfg, logger, is_offline):
common_metrics = {
"episode": online_episode_idx,
"step": step,
@@ -65,11 +59,11 @@ def eval_policy_and_log(
return_first_video=True,
)
metrics.update(common_metrics)
L.log(metrics, category="eval")
logger.log(metrics, category="eval")
if cfg.wandb.enable:
eval_video = L._wandb.Video(first_video, fps=cfg.fps, format="mp4")
L._wandb.log({"eval_video": eval_video}, step=step)
eval_video = logger._wandb.Video(first_video, fps=cfg.fps, format="mp4")
logger._wandb.log({"eval_video": eval_video}, step=step)
def train(cfg: dict, out_dir=None, job_name=None):
@@ -116,7 +110,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
sampler=online_sampler,
)
L = Logger(out_dir, job_name, cfg)
logger = Logger(out_dir, job_name, cfg)
online_episode_idx = 0
start_time = time.time()
@@ -129,9 +123,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
metrics = policy.update(offline_buffer, step)
if step % cfg.log_freq == 0:
log_training_metrics(
L, metrics, step, online_episode_idx, start_time, is_offline=False
)
log_training_metrics(logger, metrics, step, online_episode_idx, start_time, is_offline=False)
if step > 0 and step % cfg.eval_freq == 0:
eval_policy_and_log(
@@ -141,13 +133,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
online_episode_idx,
start_time,
cfg,
L,
logger,
is_offline=True,
)
if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
print(f"Checkpoint model at step {step}")
L.save_model(policy, identifier=step)
logger.save_model(policy, identifier=step)
step += 1
@@ -164,9 +156,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
auto_cast_to_device=True,
)
assert len(rollout) <= cfg.env.episode_length
rollout["episode"] = torch.tensor(
[online_episode_idx] * len(rollout), dtype=torch.int
)
rollout["episode"] = torch.tensor([online_episode_idx] * len(rollout), dtype=torch.int)
online_buffer.extend(rollout)
ep_sum_reward = rollout["next", "reward"].sum()
@@ -188,9 +178,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
)
metrics.update(train_metrics)
if step % cfg.log_freq == 0:
log_training_metrics(
L, metrics, step, online_episode_idx, start_time, is_offline=False
)
log_training_metrics(logger, metrics, step, online_episode_idx, start_time, is_offline=False)
if step > 0 and step % cfg.eval_freq == 0:
eval_policy_and_log(
@@ -200,13 +188,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
online_episode_idx,
start_time,
cfg,
L,
logger,
is_offline=False,
)
if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
print(f"Checkpoint model at step {step}")
L.save_model(policy, identifier=step)
logger.save_model(policy, identifier=step)
step += 1