Dataset v2.0 (#461)

Author: Simon Alibert
Co-authored-by: Remi <remi.cadene@huggingface.co>
Committed by: GitHub, 2024-11-29 19:04:00 +01:00
Parent: 96c7052777
Commit: 32eb0cec8f
71 changed files with 6115 additions and 2235 deletions


@@ -29,7 +29,6 @@ python lerobot/scripts/control_robot.py teleoperate \
```bash
python lerobot/scripts/control_robot.py record \
--fps 30 \
--root tmp/data \
--repo-id $USER/koch_test \
--num-episodes 1 \
--run-compute-stats 0
@@ -38,7 +37,6 @@ python lerobot/scripts/control_robot.py record \
- Visualize dataset:
```bash
python lerobot/scripts/visualize_dataset.py \
--root tmp/data \
--repo-id $USER/koch_test \
--episode-index 0
```
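
In the commands above, the `--root tmp/data` argument is gone: a dataset is now addressed by its `repo_id` and, unless `--root` is given, resolved from the Hugging Face cache or downloaded from the Hub (see the updated `--root` help text further down in this diff). A minimal sketch of the same lookup from Python, with a placeholder repo id:

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# "hf_username/koch_test" is a placeholder; any dataset recorded with the script above works.
dataset = LeRobotDataset("hf_username/koch_test")
print(dataset.num_episodes, dataset.num_frames, dataset.fps)
```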
@@ -47,7 +45,6 @@ python lerobot/scripts/visualize_dataset.py \
```bash
python lerobot/scripts/control_robot.py replay \
--fps 30 \
--root tmp/data \
--repo-id $USER/koch_test \
--episode 0
```
@@ -57,7 +54,6 @@ python lerobot/scripts/control_robot.py replay \
```bash
python lerobot/scripts/control_robot.py record \
--fps 30 \
--root data \
--repo-id $USER/koch_pick_place_lego \
--num-episodes 50 \
--warmup-time-s 2 \
@@ -77,7 +73,7 @@ To avoid resuming by deleting the dataset, use `--force-override 1`.
- Train on this dataset with the ACT policy:
```bash
DATA_DIR=data python lerobot/scripts/train.py \
python lerobot/scripts/train.py \
policy=act_koch_real \
env=koch_real \
dataset_repo_id=$USER/koch_pick_place_lego \
@@ -88,7 +84,6 @@ DATA_DIR=data python lerobot/scripts/train.py \
```bash
python lerobot/scripts/control_robot.py record \
--fps 30 \
--root data \
--repo-id $USER/eval_act_koch_real \
--num-episodes 10 \
--warmup-time-s 2 \
@@ -106,12 +101,6 @@ from typing import List
# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.populate_dataset import (
create_lerobot_dataset,
delete_current_episode,
init_dataset,
save_current_episode,
)
from lerobot.common.robot_devices.control_utils import (
control_loop,
has_method,
@@ -121,6 +110,7 @@ from lerobot.common.robot_devices.control_utils import (
record_episode,
reset_environment,
sanity_check_dataset_name,
sanity_check_dataset_robot_compatibility,
stop_recording,
warmup_record,
)
@@ -196,25 +186,28 @@ def teleoperate(
@safe_disconnect
def record(
robot: Robot,
root: str,
root: Path,
repo_id: str,
single_task: str,
pretrained_policy_name_or_path: str | None = None,
policy_overrides: List[str] | None = None,
fps: int | None = None,
warmup_time_s=2,
episode_time_s=10,
reset_time_s=5,
num_episodes=50,
video=True,
run_compute_stats=True,
push_to_hub=True,
tags=None,
num_image_writer_processes=0,
num_image_writer_threads_per_camera=4,
force_override=False,
display_cameras=True,
play_sounds=True,
):
warmup_time_s: int | float = 2,
episode_time_s: int | float = 10,
reset_time_s: int | float = 5,
num_episodes: int = 50,
video: bool = True,
run_compute_stats: bool = True,
push_to_hub: bool = True,
tags: list[str] | None = None,
num_image_writer_processes: int = 0,
num_image_writer_threads_per_camera: int = 4,
display_cameras: bool = True,
play_sounds: bool = True,
resume: bool = False,
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
local_files_only: bool = False,
) -> LeRobotDataset:
# TODO(rcadene): Add option to record logs
listener = None
events = None
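
For reference, a hypothetical call of the reworked `record()` signature. The `robot` object is assumed to be an already constructed and connected `Robot`, and the repo id is a placeholder:

```python
from pathlib import Path

from lerobot.scripts.control_robot import record

dataset = record(
    robot=robot,                       # assumed: a connected Robot instance built elsewhere
    root=Path("data"),                 # now a Path instead of a str
    repo_id="hf_username/koch_test",   # placeholder
    single_task="Grasp the lego block and put it in the bin.",
    fps=30,
    num_episodes=2,
    resume=False,                      # replaces the old force_override flag
    push_to_hub=False,
)
```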
@@ -222,6 +215,11 @@ def record(
device = None
use_amp = None
if single_task:
task = single_task
else:
raise NotImplementedError("Only single-task recording is supported for now")
# Load pretrained policy
if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
@@ -234,18 +232,29 @@ def record(
f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})."
)
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)
dataset = init_dataset(
repo_id,
root,
force_override,
fps,
video,
write_images=robot.has_camera,
num_image_writer_processes=num_image_writer_processes,
num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
)
if resume:
dataset = LeRobotDataset(
repo_id,
root=root,
local_files_only=local_files_only,
)
dataset.start_image_writer(
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
)
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
else:
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)
dataset = LeRobotDataset.create(
repo_id,
fps,
root=root,
robot=robot,
use_videos=video,
image_writer_processes=num_image_writer_processes,
image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
)
if not robot.is_connected:
robot.connect()
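
Condensed, the resume path amounts to reopening the local dataset, restarting the async image writer, and checking that it matches the robot. A sketch with an assumed `robot` object and a placeholder repo id:

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.robot_devices.control_utils import sanity_check_dataset_robot_compatibility

dataset = LeRobotDataset("hf_username/koch_test", local_files_only=True)  # previously recorded episodes
dataset.start_image_writer(num_processes=0, num_threads=4 * len(robot.cameras))  # async frame writes
sanity_check_dataset_robot_compatibility(dataset, robot, 30, True)  # (dataset, robot, fps, video)
```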
@@ -263,12 +272,17 @@ def record(
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()
recorded_episodes = 0
while True:
if dataset["num_episodes"] >= num_episodes:
if recorded_episodes >= num_episodes:
break
episode_index = dataset["num_episodes"]
log_say(f"Recording episode {episode_index}", play_sounds)
# TODO(aliberts): add task prompt for multitask here. Might need to temporarily disable event if
# input() messes with them.
# if multi_task:
# task = input("Enter your task description: ")
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
record_episode(
dataset=dataset,
robot=robot,
@@ -286,7 +300,7 @@ def record(
# TODO(rcadene): add an option to enable teleoperation during reset
# Skip reset for the last episode to be recorded
if not events["stop_recording"] and (
(episode_index < num_episodes - 1) or events["rerecord_episode"]
(dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"]
):
log_say("Reset the environment", play_sounds)
reset_environment(robot, events, reset_time_s)
@@ -295,11 +309,11 @@ def record(
log_say("Re-record episode", play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
delete_current_episode(dataset)
dataset.clear_episode_buffer()
continue
# Increment by one dataset["current_episode_index"]
save_current_episode(dataset)
dataset.save_episode(task)
recorded_episodes += 1
if events["stop_recording"]:
break
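
Within the loop, the buffered-episode API boils down to two calls: `clear_episode_buffer()` to discard a take and `save_episode(task)` to keep it. A hypothetical helper condensing that branch:

```python
def finalize_episode(dataset, events, task: str) -> bool:
    """Hypothetical helper mirroring the branch above; returns True when an episode was kept."""
    if events["rerecord_episode"]:
        events["rerecord_episode"] = False
        events["exit_early"] = False
        dataset.clear_episode_buffer()  # drop the frames buffered for this attempt
        return False
    dataset.save_episode(task)          # persist the buffered frames under the given task
    return True
```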
@@ -307,35 +321,42 @@ def record(
log_say("Stop recording", play_sounds, blocking=True)
stop_recording(robot, listener, display_cameras)
lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
if run_compute_stats:
logging.info("Computing dataset statistics")
dataset.consolidate(run_compute_stats)
if push_to_hub:
dataset.push_to_hub(tags=tags)
log_say("Exiting", play_sounds)
return lerobot_dataset
return dataset
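
After the loop, the dataset object carries the finishing steps itself. A hedged wrapper over the calls shown above:

```python
import logging


def finish_recording(dataset, run_compute_stats: bool = True, push: bool = True, tags: list | None = None):
    """Hypothetical wrapper over the post-recording calls above."""
    if run_compute_stats:
        logging.info("Computing dataset statistics")
    dataset.consolidate(run_compute_stats)  # finalize files; optionally compute per-feature stats
    if push:
        dataset.push_to_hub(tags=tags)
    return dataset
```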
@safe_disconnect
def replay(
robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True
robot: Robot,
root: Path,
repo_id: str,
episode: int,
fps: int | None = None,
play_sounds: bool = True,
local_files_only: bool = True,
):
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
# TODO(rcadene): Add option to record logs
local_dir = Path(root) / repo_id
if not local_dir.exists():
raise ValueError(local_dir)
dataset = LeRobotDataset(repo_id, root=root)
items = dataset.hf_dataset.select_columns("action")
from_idx = dataset.episode_data_index["from"][episode].item()
to_idx = dataset.episode_data_index["to"][episode].item()
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
actions = dataset.hf_dataset.select_columns("action")
if not robot.is_connected:
robot.connect()
log_say("Replaying episode", play_sounds, blocking=True)
for idx in range(from_idx, to_idx):
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
action = items[idx]["action"]
action = actions[idx]["action"]
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t
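
`replay()` now loads only the requested episode and indexes its frames directly. The same pattern from an interactive session might look like this (placeholder repo id, robot call omitted):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("hf_username/koch_test", episodes=[0], local_files_only=True)
actions = dataset.hf_dataset.select_columns("action")
for idx in range(dataset.num_frames):
    action = actions[idx]["action"]  # action tensor for frame idx of the selected episode
    # a robot would receive it via robot.send_action(action)
```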
@@ -384,13 +405,25 @@ if __name__ == "__main__":
)
parser_record = subparsers.add_parser("record", parents=[base_parser])
task_args = parser_record.add_mutually_exclusive_group(required=True)
parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
task_args.add_argument(
"--single-task",
type=str,
help="A short but accurate description of the task performed during the recording.",
)
# TODO(aliberts): add multi-task support
# task_args.add_argument(
# "--multi-task",
# type=int,
# help="You will need to enter the task performed at the start of each episode.",
# )
parser_record.add_argument(
"--root",
type=Path,
default="data",
default=None,
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
)
parser_record.add_argument(
@@ -458,10 +491,10 @@ if __name__ == "__main__":
),
)
parser_record.add_argument(
"--force-override",
"--resume",
type=int,
default=0,
help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
help="Resume recording on an existing dataset.",
)
parser_record.add_argument(
"-p",
@@ -486,7 +519,7 @@ if __name__ == "__main__":
parser_replay.add_argument(
"--root",
type=Path,
default="data",
default=None,
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
)
parser_replay.add_argument(


@@ -484,7 +484,7 @@ def main(
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
else:
# Note: We need the dataset stats to pass to the policy's normalization modules.
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).meta.stats)
assert isinstance(policy, nn.Module)
policy.eval()
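
Dataset statistics now live on the metadata object rather than on the dataset itself. A small sketch (repo id taken from the help text elsewhere in this diff; the exact stat keys depend on the dataset's features):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

ds = LeRobotDataset("lerobot/pusht")
stats = ds.meta.stats                       # was ds.stats before this change
print(stats["observation.state"]["mean"])   # per-feature stats feeding the policy's normalization
```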


@@ -117,10 +117,14 @@ def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str
def push_dataset_card_to_hub(
repo_id: str, revision: str | None, tags: list | None = None, text: str | None = None
repo_id: str,
revision: str | None,
tags: list | None = None,
license: str = "apache-2.0",
**card_kwargs,
):
"""Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
card = create_lerobot_dataset_card(tags=tags, text=text)
card = create_lerobot_dataset_card(tags=tags, license=license, **card_kwargs)
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
@@ -260,7 +264,7 @@ def push_dataset_to_hub(
episode_index = 0
tests_videos_dir = tests_data_dir / repo_id / "videos"
tests_videos_dir.mkdir(parents=True, exist_ok=True)
for key in lerobot_dataset.video_frame_keys:
for key in lerobot_dataset.camera_keys:
fname = f"{key}_episode_{episode_index:06d}.mp4"
shutil.copy(videos_dir / fname, tests_videos_dir / fname)


@@ -171,9 +171,9 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_online):
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
num_samples = (step + 1) * cfg.training.batch_size
avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / dataset.num_samples
num_epochs = num_samples / dataset.num_frames
log_items = [
f"step:{format_big_number(step)}",
# number of samples seen during training
@@ -208,9 +208,9 @@ def log_eval_info(logger, info, step, cfg, dataset, is_online):
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
num_samples = (step + 1) * cfg.training.batch_size
avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / dataset.num_samples
num_epochs = num_samples / dataset.num_frames
log_items = [
f"step:{format_big_number(step)}",
# number of samples seen during training
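
With `num_samples` renamed to `num_frames`, the bookkeeping above reads as the following self-contained sketch (numbers are made up):

```python
def training_progress(num_frames: int, num_episodes: int, batch_size: int, step: int) -> tuple[float, float]:
    """Sketch of the sample/episode/epoch bookkeeping above, using the renamed num_frames."""
    num_samples = (step + 1) * batch_size
    avg_samples_per_ep = num_frames / num_episodes
    return num_samples / avg_samples_per_ep, num_samples / num_frames  # (episodes seen, epochs seen)


print(training_progress(num_frames=5000, num_episodes=50, batch_size=8, step=999))
```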
@@ -328,7 +328,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("make_policy")
policy = make_policy(
hydra_cfg=cfg,
dataset_stats=offline_dataset.stats if not cfg.resume else None,
dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)
assert isinstance(policy, nn.Module)
@@ -349,7 +349,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
logging.info(f"{cfg.training.online_steps=}")
logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})")
logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
logging.info(f"{offline_dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
@@ -573,7 +573,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
online_sampling_ratio=cfg.training.online_sampling_ratio,
)
sampler.num_samples = len(concat_dataset)
sampler.num_frames = len(concat_dataset)
update_online_buffer_s = time.perf_counter() - start_update_buffer_time


@@ -100,7 +100,7 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
def visualize_dataset(
repo_id: str,
dataset: LeRobotDataset,
episode_index: int,
batch_size: int = 32,
num_workers: int = 0,
@@ -108,7 +108,6 @@ def visualize_dataset(
web_port: int = 9090,
ws_port: int = 9087,
save: bool = False,
root: Path | None = None,
output_dir: Path | None = None,
) -> Path | None:
if save:
@@ -116,8 +115,7 @@ def visualize_dataset(
output_dir is not None
), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id, root=root)
repo_id = dataset.repo_id
logging.info("Loading dataloader")
episode_sampler = EpisodeSampler(dataset, episode_index)
@@ -153,7 +151,7 @@ def visualize_dataset(
rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
# display each camera image
for key in dataset.camera_keys:
for key in dataset.meta.camera_keys:
# TODO(rcadene): add `.compress()`? is it lossless?
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
@@ -209,11 +207,17 @@ def main():
required=True,
help="Episode to visualize.",
)
parser.add_argument(
"--local-files-only",
type=int,
default=0,
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
)
parser.add_argument(
"--root",
type=Path,
default=None,
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
help="Root directory for the dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
)
parser.add_argument(
"--output-dir",
@@ -268,7 +272,15 @@ def main():
)
args = parser.parse_args()
visualize_dataset(**vars(args))
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
root = kwargs.pop("root")
local_files_only = kwargs.pop("local_files_only")
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
visualize_dataset(dataset, **vars(args))
if __name__ == "__main__":
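
Since the script now receives an already constructed dataset, the same visualization can be driven from Python. A sketch, assuming the function is importable from the script's module path:

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.scripts.visualize_dataset import visualize_dataset  # assumed import path

dataset = LeRobotDataset("lerobot/pusht")
visualize_dataset(dataset, episode_index=0)  # by default this opens a local rerun viewer
```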


@@ -93,18 +93,17 @@ def run_server(
def show_episode(dataset_namespace, dataset_name, episode_id):
dataset_info = {
"repo_id": dataset.repo_id,
"num_samples": dataset.num_samples,
"num_samples": dataset.num_frames,
"num_episodes": dataset.num_episodes,
"fps": dataset.fps,
}
video_paths = get_episode_video_paths(dataset, episode_id)
language_instruction = get_episode_language_instruction(dataset, episode_id)
video_paths = [dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys]
tasks = dataset.meta.episodes[episode_id]["tasks"]
videos_info = [
{"url": url_for("static", filename=video_path), "filename": Path(video_path).name}
{"url": url_for("static", filename=video_path), "filename": video_path.name}
for video_path in video_paths
]
if language_instruction:
videos_info[0]["language_instruction"] = language_instruction
videos_info[0]["language_instruction"] = tasks
ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
return render_template(
@@ -131,16 +130,16 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]
has_state = "observation.state" in dataset.hf_dataset.features
has_action = "action" in dataset.hf_dataset.features
has_state = "observation.state" in dataset.features
has_action = "action" in dataset.features
# init header of csv with state and action names
header = ["timestamp"]
if has_state:
dim_state = len(dataset.hf_dataset["observation.state"][0])
dim_state = dataset.meta.shapes["observation.state"][0]
header += [f"state_{i}" for i in range(dim_state)]
if has_action:
dim_action = len(dataset.hf_dataset["action"][0])
dim_action = dataset.meta.shapes["action"][0]
header += [f"action_{i}" for i in range(dim_action)]
columns = ["timestamp"]
@@ -172,27 +171,12 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
return [
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
for key in dataset.video_frame_keys
for key in dataset.meta.video_keys
]
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
# check if the dataset has language instructions
if "language_instruction" not in dataset.hf_dataset.features:
return None
# get first frame index
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
# with the tf.tensor appearing in the string
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
def visualize_dataset_html(
repo_id: str,
root: Path | None = None,
dataset: LeRobotDataset,
episodes: list[int] = None,
output_dir: Path | None = None,
serve: bool = True,
@@ -202,13 +186,11 @@ def visualize_dataset_html(
) -> Path | None:
init_logging()
dataset = LeRobotDataset(repo_id, root=root)
if not dataset.video:
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
if len(dataset.meta.image_keys) > 0:
raise NotImplementedError(f"Image keys ({dataset.meta.image_keys=}) are currently not supported.")
if output_dir is None:
output_dir = f"outputs/visualize_dataset_html/{repo_id}"
output_dir = f"outputs/visualize_dataset_html/{dataset.repo_id}"
output_dir = Path(output_dir)
if output_dir.exists():
@@ -225,7 +207,7 @@ def visualize_dataset_html(
static_dir.mkdir(parents=True, exist_ok=True)
ln_videos_dir = static_dir / "videos"
if not ln_videos_dir.exists():
ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
template_dir = Path(__file__).resolve().parent.parent / "templates"
@@ -252,6 +234,12 @@ def main():
required=True,
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
)
parser.add_argument(
"--local-files-only",
type=int,
default=0,
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
)
parser.add_argument(
"--root",
type=Path,
@@ -297,7 +285,13 @@ def main():
)
args = parser.parse_args()
visualize_dataset_html(**vars(args))
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
root = kwargs.pop("root")
local_files_only = kwargs.pop("local_files_only")
dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
visualize_dataset_html(dataset, **kwargs)
if __name__ == "__main__":
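
The HTML visualizer now pulls everything it needs from `dataset.meta`. A short sketch of those accessors (using `lerobot/pusht` as in the help text; available keys depend on the dataset):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

ds = LeRobotDataset("lerobot/pusht")
ep = 0
video_paths = [ds.meta.get_video_file_path(ep, key) for key in ds.meta.video_keys]
tasks = ds.meta.episodes[ep]["tasks"]               # task string(s) recorded for the episode
dim_state = ds.meta.shapes["observation.state"][0]  # feature shapes now live on the metadata
```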


@@ -157,7 +157,7 @@ def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5):
output_dir.mkdir(parents=True, exist_ok=True)
# Get 1st frame from 1st camera of 1st episode
original_frame = dataset[0][dataset.camera_keys[0]]
original_frame = dataset[0][dataset.meta.camera_keys[0]]
to_pil(original_frame).save(output_dir / "original_frame.png", quality=100)
print("\nOriginal frame saved to:")
print(f" {output_dir / 'original_frame.png'}.")