Add pusht test artifact
This commit is contained in:
@@ -125,6 +125,9 @@ class PushtExperienceReplay(AbstractExperienceReplay):
|
||||
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
||||
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
||||
total_frames = dataset_dict["action"].shape[0]
|
||||
# to create test artifact
|
||||
# num_episodes = 1
|
||||
# total_frames = 50
|
||||
assert len(
|
||||
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
|
||||
), "Some data type dont have the same number of total frames."
|
||||
@@ -142,6 +145,8 @@ class PushtExperienceReplay(AbstractExperienceReplay):
|
||||
idxtd = 0
|
||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||
idx1 = dataset_dict.meta["episode_ends"][episode_id]
|
||||
# to create test artifact
|
||||
# idx1 = 51
|
||||
|
||||
num_frames = idx1 - idx0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user