Remove total_episodes from default parquet path

This commit is contained in:
Simon Alibert
2024-10-23 00:03:30 +02:00
parent 237a484be0
commit c72dc23c43
2 changed files with 7 additions and 19 deletions

View File

@@ -296,13 +296,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
def get_data_file_path(self, ep_index: int) -> Path:
ep_chunk = self.get_episode_chunk(ep_index)
return self.data_path.format(
episode_chunk=ep_chunk, episode_index=ep_index, total_episodes=self.total_episodes
)
fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
return Path(fpath)
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
ep_chunk = self.get_episode_chunk(ep_index)
return self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
return Path(fpath)
def get_episode_chunk(self, ep_index: int) -> int:
ep_chunk = ep_index // self.chunks_size
@@ -678,17 +678,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Reset the buffer
self.episode_buffer = self._create_episode_buffer()
def _update_data_file_names(self) -> None:
# TODO(aliberts): remove the need for this hack by removing total_episodes part in data file names.
# Must first investigate if this doesn't break hub/datasets features like viewer etc.
for ep_idx in range(self.total_episodes):
ep_chunk = self.get_episode_chunk(ep_idx)
current_file_name = self.data_path.replace("{total_episodes:05d}", "*")
current_file_name = current_file_name.format(episode_chunk=ep_chunk, episode_index=ep_idx)
current_file_name = list(self.root.glob(current_file_name))[0]
updated_file_name = self.root / self.get_data_file_path(ep_idx)
current_file_name.rename(updated_file_name)
def _remove_image_writer(self) -> None:
if self.image_writer is not None:
self.image_writer = None
@@ -710,7 +699,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
shutil.rmtree(tmp_imgs_dir)
def consolidate(self, run_compute_stats: bool = True) -> None:
self._update_data_file_names()
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)