Refactor pusht_zarr

This commit is contained in:
Simon Alibert
2024-11-25 18:23:04 +01:00
parent 3b5af7eb38
commit 6ad84a6561
2 changed files with 110 additions and 132 deletions

View File

@@ -678,7 +678,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
"})',\n"
)
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
return {
"size": 0,
@@ -709,7 +709,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# check the dtype and shape matches, etc.
if self.episode_buffer is None:
self.episode_buffer = self._create_episode_buffer()
self.episode_buffer = self.create_episode_buffer()
frame_index = self.episode_buffer["size"]
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
@@ -795,7 +795,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_buffer[key] = video_paths[key]
if not episode_data: # Reset the buffer
self.episode_buffer = self._create_episode_buffer()
self.episode_buffer = self.create_episode_buffer()
self.consolidated = False
@@ -817,7 +817,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
shutil.rmtree(img_dir)
# Reset the buffer
self.episode_buffer = self._create_episode_buffer()
self.episode_buffer = self.create_episode_buffer()
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
if isinstance(self.image_writer, AsyncImageWriter):
@@ -941,7 +941,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.start_image_writer(image_writer_processes, image_writer_threads)
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj._create_episode_buffer()
obj.episode_buffer = obj.create_episode_buffer()
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
# is used to know when certain operations are need (for instance, computing dataset statistics). In