test_datasets.py are passing!

This commit is contained in:
Cadene
2024-04-08 14:02:03 +00:00
parent e1ac5dc62f
commit 70aaf1c4cb
109 changed files with 90 additions and 228 deletions

View File

@@ -26,16 +26,23 @@ import torch
def mock_dataset(in_data_dir, out_data_dir, num_frames):
in_data_dir = Path(in_data_dir)
out_data_dir = Path(out_data_dir)
out_data_dir.mkdir(exist_ok=True, parents=True)
# copy the first `n` frames for each data key so that we have real data
in_data_dict = torch.load(in_data_dir / "data_dict.pth")
out_data_dict = {key: in_data_dict[key][:num_frames].clone() for key in in_data_dict}
torch.save(out_data_dict, out_data_dir / "data_dict.pth")
# copy the full mapping between data_id and episode since it's small
in_ids_per_ep_path = in_data_dir / "data_ids_per_episode.pth"
out_ids_per_ep_path = out_data_dir / "data_ids_per_episode.pth"
shutil.copy(in_ids_per_ep_path, out_ids_per_ep_path)
# recreate data_ids_per_episode that corresponds to the subset
episodes = in_data_dict["episode"][:num_frames].tolist()
data_ids_per_episode = {}
for idx, ep_id in enumerate(episodes):
if ep_id not in data_ids_per_episode:
data_ids_per_episode[ep_id] = []
data_ids_per_episode[ep_id].append(idx)
for ep_id in data_ids_per_episode:
data_ids_per_episode[ep_id] = torch.tensor(data_ids_per_episode[ep_id])
torch.save(data_ids_per_episode, out_data_dir / "data_ids_per_episode.pth")
# copy the full statistics of dataset since it's small
in_stats_path = in_data_dir / "stats.pth"