Add pusht test artifact

This commit is contained in:
Simon Alibert
2024-03-08 15:54:30 +01:00
parent 7dbdbb051c
commit 89eaab140b
19 changed files with 42 additions and 0 deletions

View File

@@ -125,6 +125,9 @@ class PushtExperienceReplay(AbstractExperienceReplay):
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
total_frames = dataset_dict["action"].shape[0]
# to create test artifact
# num_episodes = 1
# total_frames = 50
assert len(
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
), "Some data type dont have the same number of total frames."
@@ -142,6 +145,8 @@ class PushtExperienceReplay(AbstractExperienceReplay):
idxtd = 0
for episode_id in tqdm.tqdm(range(num_episodes)):
idx1 = dataset_dict.meta["episode_ends"][episode_id]
# to create test artifact
# idx1 = 51
num_frames = idx1 - idx0