Ran pre-commit run --all-files

This commit is contained in:
Simon Alibert
2024-02-29 13:37:48 +01:00
parent 0e0e760e43
commit 7e024fdce6
16 changed files with 124 additions and 237 deletions

View File

@@ -1,24 +1,22 @@
import pickle
from pathlib import Path
import hydra
import imageio
import simxarm
import torch
from torchrl.data.replay_buffers import (
SamplerWithoutReplacement,
SliceSampler,
SliceSamplerWithoutReplacement,
)
from lerobot.common.datasets.factory import make_offline_buffer
NUM_EPISODES_TO_RENDER = 10
MAX_NUM_STEPS = 1000
FIRST_FRAME = 0
@hydra.main(version_base=None, config_name="default", config_path="../configs")
def visualize_dataset_cli(cfg: dict):
visualize_dataset(
cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
)
visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
def visualize_dataset(cfg: dict, out_dir=None):
@@ -33,9 +31,6 @@ def visualize_dataset(cfg: dict, out_dir=None):
offline_buffer = make_offline_buffer(cfg, sampler)
NUM_EPISODES_TO_RENDER = 10
MAX_NUM_STEPS = 1000
FIRST_FRAME = 0
for _ in range(NUM_EPISODES_TO_RENDER):
episode = offline_buffer.sample(MAX_NUM_STEPS)
@@ -57,9 +52,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
assert ep_frames.max().item() > 1, "Not mendatory, but sanity check"
assert ep_frames.max().item() <= 255
ep_frames = ep_frames.type(torch.uint8)
imageio.mimsave(
video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps
)
imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps)
# ran out of episodes
if offline_buffer._sampler._sample_list.numel() == 0: