forked from tangger/lerobot
Add manage_dataset
lerobot/common/datasets/lerobot_dataset.py
@@ -873,7 +873,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         return video_paths
 
-    def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
+    def consolidate(
+        self,
+        run_compute_stats: bool = True,
+        keep_image_files: bool = False,
+        batch_size: int = 8,
+        num_workers: int = 8,
+    ) -> None:
         self.hf_dataset = self.load_hf_dataset()
         self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
         check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
@@ -896,7 +902,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if run_compute_stats:
             self.stop_image_writer()
             # TODO(aliberts): refactor stats in save_episodes
-            self.meta.stats = compute_stats(self)
+            self.meta.stats = compute_stats(self, batch_size=batch_size, num_workers=num_workers)
             serialized_stats = serialize_dict(self.meta.stats)
             write_json(serialized_stats, self.root / STATS_PATH)
             self.consolidated = True
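For reference, a minimal sketch of exercising the extended `consolidate()` signature directly from Python; the repo id and argument values here are illustrative, not part of the commit:

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Load an already-recorded dataset from the local cache without contacting the hub.
dataset = LeRobotDataset(repo_id="lerobot/test", local_files_only=1)

# batch_size and num_workers are forwarded to compute_stats, which uses them
# to configure the DataLoader pass over the dataset.
dataset.consolidate(run_compute_stats=True, batch_size=32, num_workers=8)
```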
lerobot/scripts/manage_dataset.py (new file, 110 lines)
@@ -0,0 +1,110 @@
+"""
+Utilities to manage a dataset.
+
+Examples of usage:
+
+- Consolidate a dataset by encoding images into videos and computing statistics:
+```bash
+python lerobot/scripts/manage_dataset.py consolidate \
+    --repo-id $USER/koch_test
+```
+
+- Consolidate a dataset that has not been uploaded to the hub yet:
+```bash
+python lerobot/scripts/manage_dataset.py consolidate \
+    --repo-id $USER/koch_test \
+    --local-files-only 1
+```
+
+- Upload a dataset to the hub:
+```bash
+python lerobot/scripts/manage_dataset.py push_to_hub \
+    --repo-id $USER/koch_test
+```
+"""
+
+import argparse
+from pathlib import Path
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    subparsers = parser.add_subparsers(dest="mode", required=True)
+
+    # Set common options for all the subparsers
+    base_parser = argparse.ArgumentParser(add_help=False)
+    base_parser.add_argument(
+        "--root",
+        type=Path,
+        default=None,
+        help="Root directory where the dataset will be stored (e.g. 'dataset/path').",
+    )
+    base_parser.add_argument(
+        "--repo-id",
+        type=str,
+        default="lerobot/test",
+        help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
+    )
+    base_parser.add_argument(
+        "--local-files-only",
+        type=int,
+        default=0,
+        help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
+    )
+
+    parser_conso = subparsers.add_parser("consolidate", parents=[base_parser])
+    parser_conso.add_argument(
+        "--batch-size",
+        type=int,
+        default=32,
+        help="Batch size used by the DataLoader when computing the dataset statistics.",
+    )
+    parser_conso.add_argument(
+        "--num-workers",
+        type=int,
+        default=8,
+        help="Number of DataLoader worker processes used when computing the dataset statistics.",
+    )
+
+    parser_push = subparsers.add_parser("push_to_hub", parents=[base_parser])
+    parser_push.add_argument(
+        "--tags",
+        type=str,
+        nargs="*",
+        default=None,
+        help="Optional additional tags to categorize the dataset on the Hugging Face Hub. Use space-separated values (e.g. 'so100 indoor'). The tag 'LeRobot' will always be added.",
+    )
+    parser_push.add_argument(
+        "--license",
+        type=str,
+        default="apache-2.0",
+        help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to apache-2.0.",
+    )
+    parser_push.add_argument(
+        "--private",
+        type=int,
+        default=0,
+        help="Create a private dataset repository on the Hugging Face Hub. Set to 1 to enable.",
+    )
+
+    args = parser.parse_args()
+    kwargs = vars(args)
+
+    mode = kwargs.pop("mode")
+    repo_id = kwargs.pop("repo_id")
+    root = kwargs.pop("root")
+    local_files_only = kwargs.pop("local_files_only")
+
+    dataset = LeRobotDataset(
+        repo_id=repo_id,
+        root=root,
+        local_files_only=local_files_only,
+    )
+
+    if mode == "consolidate":
+        dataset.consolidate(**kwargs)
+
+    elif mode == "push_to_hub":
+        private = kwargs.pop("private") == 1
+        dataset.push_to_hub(private=private, **kwargs)
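The module docstring above does not show the subcommand-specific options; as a sketch (the flag values are illustrative, not from the commit), the new flags can be combined like this:

```bash
# Consolidate, controlling the DataLoader used for statistics computation.
python lerobot/scripts/manage_dataset.py consolidate \
    --repo-id $USER/koch_test \
    --batch-size 64 \
    --num-workers 4

# Push to the hub as a private repo, with extra tags and an explicit license.
python lerobot/scripts/manage_dataset.py push_to_hub \
    --repo-id $USER/koch_test \
    --tags so100 indoor \
    --license apache-2.0 \
    --private 1
```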