Ran pre-commit run --all-files

This commit is contained in:
Simon Alibert
2024-02-29 13:37:48 +01:00
parent 0e0e760e43
commit 7e024fdce6
16 changed files with 124 additions and 237 deletions

View File

@@ -37,10 +37,11 @@ def eval_policy(
tensordict = env.reset()
ep_frames = []
if save_video or (return_first_video and i == 0):
def rendering_callback(env, td=None):
ep_frames.append(env.render())
ep_frames.append(env.render()) # noqa: B023
# render first frame before rollout
rendering_callback(env)

View File

@@ -6,8 +6,6 @@ import torch
from tensordict.nn import TensorDictModule
from termcolor import colored
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
from torchrl.data.datasets.openx import OpenXExperienceReplay
from torchrl.data.replay_buffers import PrioritizedSliceSampler
from lerobot.common.datasets.factory import make_offline_buffer
@@ -27,9 +25,7 @@ def train_cli(cfg: dict):
)
def train_notebook(
out_dir=None, job_name=None, config_name="default", config_path="../configs"
):
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear()
@@ -38,7 +34,7 @@ def train_notebook(
train(cfg, out_dir=out_dir, job_name=job_name)
def log_training_metrics(L, metrics, step, online_episode_idx, start_time, is_offline):
def log_training_metrics(logger, metrics, step, online_episode_idx, start_time, is_offline):
common_metrics = {
"episode": online_episode_idx,
"step": step,
@@ -46,12 +42,10 @@ def log_training_metrics(L, metrics, step, online_episode_idx, start_time, is_of
"is_offline": float(is_offline),
}
metrics.update(common_metrics)
L.log(metrics, category="train")
logger.log(metrics, category="train")
def eval_policy_and_log(
env, td_policy, step, online_episode_idx, start_time, cfg, L, is_offline
):
def eval_policy_and_log(env, td_policy, step, online_episode_idx, start_time, cfg, logger, is_offline):
common_metrics = {
"episode": online_episode_idx,
"step": step,
@@ -65,11 +59,11 @@ def eval_policy_and_log(
return_first_video=True,
)
metrics.update(common_metrics)
L.log(metrics, category="eval")
logger.log(metrics, category="eval")
if cfg.wandb.enable:
eval_video = L._wandb.Video(first_video, fps=cfg.fps, format="mp4")
L._wandb.log({"eval_video": eval_video}, step=step)
eval_video = logger._wandb.Video(first_video, fps=cfg.fps, format="mp4")
logger._wandb.log({"eval_video": eval_video}, step=step)
def train(cfg: dict, out_dir=None, job_name=None):
@@ -116,7 +110,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
sampler=online_sampler,
)
L = Logger(out_dir, job_name, cfg)
logger = Logger(out_dir, job_name, cfg)
online_episode_idx = 0
start_time = time.time()
@@ -129,9 +123,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
metrics = policy.update(offline_buffer, step)
if step % cfg.log_freq == 0:
log_training_metrics(
L, metrics, step, online_episode_idx, start_time, is_offline=False
)
log_training_metrics(logger, metrics, step, online_episode_idx, start_time, is_offline=False)
if step > 0 and step % cfg.eval_freq == 0:
eval_policy_and_log(
@@ -141,13 +133,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
online_episode_idx,
start_time,
cfg,
L,
logger,
is_offline=True,
)
if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
print(f"Checkpoint model at step {step}")
L.save_model(policy, identifier=step)
logger.save_model(policy, identifier=step)
step += 1
@@ -164,9 +156,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
auto_cast_to_device=True,
)
assert len(rollout) <= cfg.env.episode_length
rollout["episode"] = torch.tensor(
[online_episode_idx] * len(rollout), dtype=torch.int
)
rollout["episode"] = torch.tensor([online_episode_idx] * len(rollout), dtype=torch.int)
online_buffer.extend(rollout)
ep_sum_reward = rollout["next", "reward"].sum()
@@ -188,9 +178,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
)
metrics.update(train_metrics)
if step % cfg.log_freq == 0:
log_training_metrics(
L, metrics, step, online_episode_idx, start_time, is_offline=False
)
log_training_metrics(logger, metrics, step, online_episode_idx, start_time, is_offline=False)
if step > 0 and step % cfg.eval_freq == 0:
eval_policy_and_log(
@@ -200,13 +188,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
online_episode_idx,
start_time,
cfg,
L,
logger,
is_offline=False,
)
if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
print(f"Checkpoint model at step {step}")
L.save_model(policy, identifier=step)
logger.save_model(policy, identifier=step)
step += 1

View File

@@ -1,24 +1,22 @@
import pickle
from pathlib import Path
import hydra
import imageio
import simxarm
import torch
from torchrl.data.replay_buffers import (
SamplerWithoutReplacement,
SliceSampler,
SliceSamplerWithoutReplacement,
)
from lerobot.common.datasets.factory import make_offline_buffer
NUM_EPISODES_TO_RENDER = 10
MAX_NUM_STEPS = 1000
FIRST_FRAME = 0
@hydra.main(version_base=None, config_name="default", config_path="../configs")
def visualize_dataset_cli(cfg: dict):
visualize_dataset(
cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
)
visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
def visualize_dataset(cfg: dict, out_dir=None):
@@ -33,9 +31,6 @@ def visualize_dataset(cfg: dict, out_dir=None):
offline_buffer = make_offline_buffer(cfg, sampler)
NUM_EPISODES_TO_RENDER = 10
MAX_NUM_STEPS = 1000
FIRST_FRAME = 0
for _ in range(NUM_EPISODES_TO_RENDER):
episode = offline_buffer.sample(MAX_NUM_STEPS)
@@ -57,9 +52,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
    assert ep_frames.max().item() > 1, "Not mandatory, but sanity check"
assert ep_frames.max().item() <= 255
ep_frames = ep_frames.type(torch.uint8)
imageio.mimsave(
video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps
)
imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps)
# ran out of episodes
if offline_buffer._sampler._sample_list.numel() == 0: