test_datasets.py are passing!

This commit is contained in:
Cadene
2024-04-08 14:02:03 +00:00
parent e1ac5dc62f
commit 70aaf1c4cb
109 changed files with 90 additions and 228 deletions

View File

@@ -193,8 +193,6 @@ class PushtDataset(torch.utils.data.Dataset):
idx0 = 0
for episode_id in tqdm.tqdm(range(num_episodes)):
idx1 = dataset_dict.meta["episode_ends"][episode_id]
# to create test artifact
# idx1 = 51
num_frames = idx1 - idx0
@@ -207,9 +205,9 @@ class PushtDataset(torch.utils.data.Dataset):
block_pos = state[:, 2:4]
block_angle = state[:, 4]
reward = torch.zeros(num_frames, 1)
success = torch.zeros(num_frames, 1, dtype=torch.bool)
done = torch.zeros(num_frames, 1, dtype=torch.bool)
reward = torch.zeros(num_frames)
success = torch.zeros(num_frames, dtype=torch.bool)
done = torch.zeros(num_frames, dtype=torch.bool)
for i in range(num_frames):
space = pymunk.Space()
space.gravity = 0, 0