test_datasets.py are passing!

This commit is contained in:
Cadene
2024-04-08 14:02:03 +00:00
parent e1ac5dc62f
commit 70aaf1c4cb
109 changed files with 90 additions and 228 deletions

View File

@@ -158,6 +158,7 @@ class AlohaDataset(torch.utils.data.Dataset):
self.data_ids_per_episode = {}
ep_dicts = []
idx0 = idx1 = 0
logging.info("Initialize and feed offline buffer")
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
@@ -165,7 +166,7 @@ class AlohaDataset(torch.utils.data.Dataset):
num_frames = ep["/action"].shape[0]
# last step of demonstration is considered done
done = torch.zeros(num_frames, 1, dtype=torch.bool)
done = torch.zeros(num_frames, dtype=torch.bool)
done[-1] = True
state = torch.from_numpy(ep["/observations/qpos"][:])
@@ -192,6 +193,14 @@ class AlohaDataset(torch.utils.data.Dataset):
ep_dicts.append(ep_dict)
idx1 += num_frames
assert isinstance(ep_id, int)
self.data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1)
assert len(self.data_ids_per_episode[ep_id]) == num_frames
idx0 = idx1
self.data_dict = {}
keys = ep_dicts[0].keys()

View File

@@ -193,8 +193,6 @@ class PushtDataset(torch.utils.data.Dataset):
idx0 = 0
for episode_id in tqdm.tqdm(range(num_episodes)):
idx1 = dataset_dict.meta["episode_ends"][episode_id]
# to create test artifact
# idx1 = 51
num_frames = idx1 - idx0
@@ -207,9 +205,9 @@ class PushtDataset(torch.utils.data.Dataset):
block_pos = state[:, 2:4]
block_angle = state[:, 4]
reward = torch.zeros(num_frames, 1)
success = torch.zeros(num_frames, 1, dtype=torch.bool)
done = torch.zeros(num_frames, 1, dtype=torch.bool)
reward = torch.zeros(num_frames)
success = torch.zeros(num_frames, dtype=torch.bool)
done = torch.zeros(num_frames, dtype=torch.bool)
for i in range(num_frames):
space = pymunk.Space()
space.gravity = 0, 0

View File

@@ -92,11 +92,11 @@ def load_data_with_delta_timestamps(
# TODO(rcadene): synchronize timestamps + interpolation if needed
tol = 0.02
tol = 0.04
is_pad = min_ > tol
assert is_contiguously_true_or_false(is_pad), (
"One or several timestamps unexpectedly violate the tolerance."
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=})."
"This might be due to synchronization issues with timestamps during data collection."
)