Improve visualize_dataset, Improve AbstractReplayBuffer, Small improvements

This commit is contained in:
Remi Cadene
2024-03-06 10:14:03 +00:00
parent 2f80d71c3e
commit f95ecd66fc
7 changed files with 195 additions and 150 deletions

View File

@@ -4,13 +4,12 @@ import hydra
import numpy as np
import torch
from tensordict.nn import TensorDictModule
from termcolor import colored
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import PrioritizedSliceSampler
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import format_big_number, init_logging, set_seed
from lerobot.scripts.eval import eval_policy
@@ -164,7 +163,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
# log metrics to terminal and wandb
logger = Logger(out_dir, job_name, cfg)
logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
log_output_dir(out_dir)
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
logging.info(f"{cfg.online_steps=}")
@@ -212,7 +211,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
for env_step in range(cfg.online_steps):
if env_step == 0:
logging.info("Start online training by interacting with environment")
# TODO: use SyncDataCollector for that?
# TODO: add configurable number of rollout? (default=1)
with torch.no_grad():
rollout = env.rollout(
@@ -268,6 +266,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
step += 1
online_step += 1
logging.info("End of training")
if __name__ == "__main__":
train_cli()