Add download_and_upload_dataset.py in script, update all datasets, update online training

This commit is contained in:
Cadene
2024-04-15 21:26:33 +00:00
parent c6aca7fe44
commit 67d79732f9
8 changed files with 493 additions and 519 deletions

View File

@@ -1,5 +1,4 @@
import logging
import os
from pathlib import Path
import torch
@@ -8,11 +7,6 @@ from torchvision.transforms import v2
from lerobot.common.datasets.utils import compute_stats
from lerobot.common.transforms import NormalizeTransform, Prod
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
# to load a subset of our datasets for faster continuous integration.
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
def make_dataset(
cfg,
@@ -57,12 +51,11 @@ def make_dataset(
# instantiate a one frame dataset with light transform
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
root=DATA_DIR,
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
# load stats if the file exists already or compute stats and save it
precomputed_stats_path = stats_dataset.data_dir / "stats.pth"
precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
if precomputed_stats_path.exists():
stats = torch.load(precomputed_stats_path)
else:
@@ -94,7 +87,6 @@ def make_dataset(
dataset = clsfunc(
dataset_id=cfg.dataset_id,
root=DATA_DIR,
delta_timestamps=delta_timestamps,
transform=transforms,
)