Merge remote-tracking branch 'origin/user/rcadene/2024_03_31_remove_torchrl' into user/rcadene/2024_03_31_remove_torchrl
This commit is contained in:
@@ -158,8 +158,7 @@ class AlohaDataset(torch.utils.data.Dataset):
|
||||
self.data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
idx0 = idx1 = 0
|
||||
logging.info("Initialize and feed offline buffer")
|
||||
frame_idx = 0
|
||||
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
|
||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
@@ -191,15 +190,13 @@ class AlohaDataset(torch.utils.data.Dataset):
|
||||
ep_dict[f"observation.images.{cam}"] = image[:-1]
|
||||
# ep_dict[f"next.observation.images.{cam}"] = image[1:]
|
||||
|
||||
assert isinstance(ep_id, int)
|
||||
self.data_ids_per_episode[ep_id] = torch.arange(frame_idx, frame_idx + num_frames, 1)
|
||||
assert len(self.data_ids_per_episode[ep_id]) == num_frames
|
||||
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
idx1 += num_frames
|
||||
|
||||
assert isinstance(ep_id, int)
|
||||
self.data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1)
|
||||
assert len(self.data_ids_per_episode[ep_id]) == num_frames
|
||||
|
||||
idx0 = idx1
|
||||
frame_idx += num_frames
|
||||
|
||||
self.data_dict = {}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user