Add replay_buffer directory in pusht datasets + aloha (WIP)

This commit is contained in:
Cadene
2024-03-19 15:49:45 +00:00
parent 099a465367
commit 6a1a29386a
20 changed files with 53 additions and 8 deletions

View File

@@ -10,12 +10,15 @@ from pathlib import Path
def mock_dataset(in_data_dir, out_data_dir, num_frames=50):
in_data_dir = Path(in_data_dir)
out_data_dir = Path(out_data_dir)
# load full dataset as a tensor dict
in_td_data = TensorDict.load_memmap(in_data_dir)
in_td_data = TensorDict.load_memmap(in_data_dir / "replay_buffer")
# use 1 frame to know the specification of the dataset
# and copy it over `n` frames in the test artifact directory
out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_data_dir)
out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_data_dir / "replay_buffer")
# copy the first `n` frames so that we have real data
out_td_data[:num_frames] = in_td_data[:num_frames].clone()
@@ -24,8 +27,8 @@ def mock_dataset(in_data_dir, out_data_dir, num_frames=50):
out_td_data.lock_()
# copy the full statistics of dataset since it's pretty small
in_stats_path = Path(in_data_dir) / "stats.pth"
out_stats_path = Path(out_data_dir) / "stats.pth"
in_stats_path = in_data_dir / "stats.pth"
out_stats_path = out_data_dir / "stats.pth"
shutil.copy(in_stats_path, out_stats_path)