Improve visualize_dataset, Improve AbstractReplayBuffer, Small improvements

This commit is contained in:
Remi Cadene
2024-03-06 10:14:03 +00:00
parent 2f80d71c3e
commit f95ecd66fc
7 changed files with 195 additions and 150 deletions

View File

@@ -9,13 +9,13 @@ import numpy as np
import torch
import tqdm
from tensordict.nn import TensorDictModule
from termcolor import colored
from torchrl.envs import EnvBase
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import set_seed
from lerobot.common.utils import init_logging, set_seed
def write_video(video_path, stacked_frames, fps):
@@ -109,10 +109,18 @@ def eval(cfg: dict, out_dir=None):
if out_dir is None:
raise NotImplementedError()
assert torch.cuda.is_available()
init_logging()
if cfg.device == "cuda":
assert torch.cuda.is_available()
else:
logging.warning("Using CPU, this will be slow.")
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_seed(cfg.seed)
print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir)
log_output_dir(out_dir)
logging.info("make_offline_buffer")
offline_buffer = make_offline_buffer(cfg)
@@ -142,6 +150,8 @@ def eval(cfg: dict, out_dir=None):
)
print(metrics)
logging.info("End of eval")
if __name__ == "__main__":
eval_cli()