Fix image writer
This commit is contained in:
@@ -25,7 +25,6 @@ import datasets
|
||||
import torch
|
||||
import torch.utils
|
||||
from datasets import load_dataset
|
||||
from datasets.table import embed_table_storage
|
||||
from huggingface_hub import snapshot_download, upload_folder
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
|
||||
@@ -51,6 +50,7 @@ from lerobot.common.datasets.utils import (
|
||||
load_stats,
|
||||
load_tasks,
|
||||
write_json,
|
||||
write_parquet,
|
||||
write_stats,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
@@ -354,7 +354,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
"""Number of samples/frames in selected episodes."""
|
||||
return len(self.hf_dataset)
|
||||
return len(self.hf_dataset) if self.hf_dataset is not None else self.total_frames
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
@@ -584,9 +584,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if frame_index == 0:
|
||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.image_writer.async_save_image(
|
||||
image=frame[cam_key],
|
||||
file_path=img_path,
|
||||
self.image_writer.save_image(
|
||||
image_array=frame[cam_key],
|
||||
fpath=img_path,
|
||||
)
|
||||
|
||||
if cam_key in self.image_keys:
|
||||
@@ -640,14 +640,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
|
||||
ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
|
||||
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Embed image bytes into the table before saving to parquet
|
||||
format = ep_dataset.format
|
||||
ep_dataset = ep_dataset.with_format("arrow")
|
||||
ep_dataset = ep_dataset.map(embed_table_storage, batched=False)
|
||||
ep_dataset = ep_dataset.with_format(**format)
|
||||
|
||||
ep_dataset.to_parquet(ep_data_path)
|
||||
write_parquet(ep_dataset, ep_data_path)
|
||||
|
||||
def _save_episode_to_metadata(
|
||||
self, episode_index: int, episode_length: int, task: str, task_index: int
|
||||
@@ -709,13 +702,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
|
||||
"""
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.shutdown()
|
||||
self.image_writer.stop()
|
||||
self.image_writer = None
|
||||
|
||||
def _wait_image_writer(self) -> None:
|
||||
"""Wait for asynchronous image writer to finish."""
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.wait()
|
||||
self.image_writer.wait_until_done()
|
||||
|
||||
def encode_videos(self) -> None:
|
||||
# Use ffmpeg to convert frames stored as png into mp4 videos
|
||||
@@ -754,7 +747,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self._write_video_info()
|
||||
|
||||
if not keep_image_files and self.image_writer is not None:
|
||||
shutil.rmtree(self.image_writer.dir)
|
||||
shutil.rmtree(self.image_writer.write_dir)
|
||||
|
||||
if run_compute_stats:
|
||||
self.stop_image_writer()
|
||||
|
||||
Reference in New Issue
Block a user