Compare commits

...

4 Commits

Author SHA1 Message Date
Remi Cadene  49ae3e19e1  Add clone, delete, WIP on remove_episode, drop_frame  2024-12-27 17:49:47 +01:00
Remi Cadene  ebe0bfad77  WIP  2024-12-11 09:09:14 -08:00
Remi Cadene  c6e9a3dc24  nit  2024-12-03 17:24:55 +01:00
Remi Cadene  afbd42d082  Add manage_dataset  2024-12-03 17:16:47 +01:00
2 changed files with 225 additions and 2 deletions

lerobot/common/datasets/lerobot_dataset.py

@@ -873,7 +873,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return video_paths
 
-    def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
+    def consolidate(
+        self,
+        run_compute_stats: bool = True,
+        keep_image_files: bool = False,
+        batch_size: int = 8,
+        num_workers: int = 8,
+    ) -> None:
         self.hf_dataset = self.load_hf_dataset()
         self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
         check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
@@ -896,7 +902,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if run_compute_stats:
             self.stop_image_writer()
             # TODO(aliberts): refactor stats in save_episodes
-            self.meta.stats = compute_stats(self)
+            self.meta.stats = compute_stats(self, batch_size=batch_size, num_workers=num_workers)
             serialized_stats = serialize_dict(self.meta.stats)
             write_json(serialized_stats, self.root / STATS_PATH)
             self.consolidated = True
@@ -958,6 +964,35 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.video_backend = video_backend if video_backend is not None else "pyav"
         return obj
 
+    def clone(self, new_repo_id: str, new_root: str | Path | None = None) -> "LeRobotDataset":
+        return LeRobotDataset.create(
+            repo_id=new_repo_id,
+            fps=self.fps,
+            root=new_root,
+            robot=self.robot,
+            robot_type=self.robot_type,
+            features=self.features,
+            use_videos=self.use_videos,
+            tolerance_s=self.tolerance_s,
+            image_writer_processes=self.image_writer_processes,
+            image_writer_threads=self.image_writer_threads,
+            video_backend=self.video_backend,
+        )
+
+    def delete(self):
+        """Delete the dataset locally. If it was pushed to the hub, you can still access it by downloading it again."""
+        shutil.rmtree(self.root)
+
+    def remove_episode(self, episode: int | list[int]):
+        # WIP (per the commit message): episode removal is not implemented yet.
+        if isinstance(episode, int):
+            episode = [episode]
+        for ep in episode:
+            self.meta.info
+
+    def drop_frame(self, episode_range: dict[int, tuple[int]]):
+        # WIP (per the commit message): frame dropping is not implemented yet.
+        pass
+
 
 class MultiLeRobotDataset(torch.utils.data.Dataset):
     """A dataset consisting of multiple underlying `LeRobotDataset`s.

lerobot/scripts/manage_dataset.py

@@ -0,0 +1,188 @@
"""
Utilities to manage a dataset.
Examples of usage:
- Consolidate a dataset by encoding images into videos and computing statistics:
```bash
python lerobot/scripts/manage_dataset.py consolidate \
--repo-id $USER/koch_test
```
- Consolidate a dataset that has not been uploaded to the hub yet:
```bash
python lerobot/scripts/manage_dataset.py consolidate \
--repo-id $USER/koch_test \
--local-files-only 1
```
- Upload a dataset to the hub:
```bash
python lerobot/scripts/manage_dataset.py push_to_hub \
--repo-id $USER/koch_test
```
"""
import argparse
from pathlib import Path
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


def parse_episode_range_string(ep_range_str):
    parts = ep_range_str.split("-")
    if len(parts) != 3:
        raise ValueError(
            f"Invalid episode range string '{ep_range_str}'. Expected format: 'EP-FROM-TO', e.g., '1-5-10'."
        )
    ep, start, end = parts
    return int(ep), int(start), int(end)


def parse_episode_range_strings(ep_range_strings):
    ep_ranges = {}
    for ep_range_str in ep_range_strings:
        ep, start, end = parse_episode_range_string(ep_range_str)
        if ep not in ep_ranges:
            ep_ranges[ep] = []
        ep_ranges[ep].append((start, end))
    return ep_ranges
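
# Example: parse_episode_range_strings(["0-0-10", "3-5-20"]) returns
# {0: [(0, 10)], 3: [(5, 20)]}, i.e. it maps each episode index to a list of (start, end) frame ranges.
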
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="mode", required=True)

    # Set common options for all the subparsers
    base_parser = argparse.ArgumentParser(add_help=False)
    base_parser.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory where the dataset is stored (e.g. 'dataset/path').",
    )
    base_parser.add_argument(
        "--repo-id",
        type=str,
        default="lerobot/test",
        help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
    )
    base_parser.add_argument(
        "--local-files-only",
        type=int,
        default=0,
        help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
    )

    ############################################################################
    # consolidate
    parser_conso = subparsers.add_parser("consolidate", parents=[base_parser])
    parser_conso.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Batch size used by the DataLoader when computing the dataset statistics.",
    )
    parser_conso.add_argument(
        "--num-workers",
        type=int,
        default=8,
        help="Number of DataLoader worker processes used when computing the dataset statistics.",
    )

    ############################################################################
    # push_to_hub
    parser_push = subparsers.add_parser("push_to_hub", parents=[base_parser])
    parser_push.add_argument(
        "--tags",
        type=str,
        nargs="*",
        default=None,
        help="Optional additional tags to categorize the dataset on the Hugging Face Hub. Use space-separated values (e.g. 'so100 indoor'). The tag 'LeRobot' will always be added.",
    )
    parser_push.add_argument(
        "--license",
        type=str,
        default="apache-2.0",
        help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to apache-2.0.",
    )
    parser_push.add_argument(
        "--private",
        type=int,
        default=0,
        help="Create a private dataset repository on the Hugging Face Hub. Push publicly by default.",
    )

    ############################################################################
    # clone
    parser_clone = subparsers.add_parser("clone", parents=[base_parser])
    # `--root` is already defined by the common options above, so the clone destination
    # gets its own flag, which maps onto the `new_root` argument of `LeRobotDataset.clone`.
    parser_clone.add_argument(
        "--new-root",
        type=Path,
        default=None,
        help="New root directory where the cloned dataset is stored (e.g. 'dataset/path').",
    )
    parser_clone.add_argument(
        "--new-repo-id",
        type=str,
        help="New dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
    )

    ############################################################################
    # delete
    parser_del = subparsers.add_parser("delete", parents=[base_parser])

    ############################################################################
    # remove_episode
    parser_rm_ep = subparsers.add_parser("remove_episode", parents=[base_parser])
    parser_rm_ep.add_argument(
        "--episode",
        type=int,
        nargs="*",
        help="List of one or several episodes to be removed from the dataset locally.",
    )

    ############################################################################
    # drop_frame
    parser_drop_frame = subparsers.add_parser("drop_frame", parents=[base_parser])
    parser_drop_frame.add_argument(
        "--episode-range",
        type=str,
        nargs="*",
        help="List of one or several frame ranges per episode to be removed from the dataset locally. For instance, `--episode-range 0-0-10 3-5-20` removes frames 0 to 10 (end excluded) from episode 0, and frames 5 to 20 from episode 3.",
    )

    args = parser.parse_args()
    kwargs = vars(args)
    mode = kwargs.pop("mode")
    repo_id = kwargs.pop("repo_id")
    root = kwargs.pop("root")
    local_files_only = kwargs.pop("local_files_only")

    dataset = LeRobotDataset(
        repo_id=repo_id,
        root=root,
        local_files_only=local_files_only,
    )

    if mode == "consolidate":
        dataset.consolidate(**kwargs)
    elif mode == "push_to_hub":
        private = kwargs.pop("private") == 1
        dataset.push_to_hub(private=private, **kwargs)
    elif mode == "clone":
        dataset.clone(**kwargs)
    elif mode == "delete":
        dataset.delete(**kwargs)
    elif mode == "remove_episode":
        dataset.remove_episode(**kwargs)
    elif mode == "drop_frame":
        ep_range = parse_episode_range_strings(kwargs.pop("episode_range"))
        dataset.drop_frame(episode_range=ep_range, **kwargs)
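
Based on the subcommands and flags defined above, the new modes would be invoked roughly as follows. The repo ids are placeholders, and `remove_episode` and `drop_frame` are not functional yet per the commit messages:

```bash
# Clone a local dataset under a new repo id.
python lerobot/scripts/manage_dataset.py clone \
    --repo-id $USER/koch_test \
    --new-repo-id $USER/koch_test_copy \
    --local-files-only 1

# Delete the local copy of a dataset (a pushed dataset can be re-downloaded from the hub).
python lerobot/scripts/manage_dataset.py delete \
    --repo-id $USER/koch_test \
    --local-files-only 1

# Remove episodes 2 and 5 from the local dataset (WIP).
python lerobot/scripts/manage_dataset.py remove_episode \
    --repo-id $USER/koch_test \
    --episode 2 5 \
    --local-files-only 1

# Drop frames 0 to 10 of episode 0 and frames 5 to 20 of episode 3 (WIP).
python lerobot/scripts/manage_dataset.py drop_frame \
    --repo-id $USER/koch_test \
    --episode-range 0-0-10 3-5-20 \
    --local-files-only 1
```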