Improve visualize_dataset, Improve AbstractReplayBuffer, Small improvements
This commit is contained in:
@@ -4,13 +4,12 @@ import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensordict.nn import TensorDictModule
|
||||
from termcolor import colored
|
||||
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.logger import Logger
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils import format_big_number, init_logging, set_seed
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
@@ -164,7 +163,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
# log metrics to terminal and wandb
|
||||
logger = Logger(out_dir, job_name, cfg)
|
||||
|
||||
logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
||||
log_output_dir(out_dir)
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
|
||||
logging.info(f"{cfg.online_steps=}")
|
||||
@@ -212,7 +211,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
for env_step in range(cfg.online_steps):
|
||||
if env_step == 0:
|
||||
logging.info("Start online training by interacting with environment")
|
||||
# TODO: use SyncDataCollector for that?
|
||||
# TODO: add configurable number of rollout? (default=1)
|
||||
with torch.no_grad():
|
||||
rollout = env.rollout(
|
||||
@@ -268,6 +266,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
step += 1
|
||||
online_step += 1
|
||||
|
||||
logging.info("End of training")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_cli()
|
||||
|
||||
Reference in New Issue
Block a user