Remove Prod, Tests are passind

This commit is contained in:
Cadene
2024-04-19 23:18:45 +00:00
parent 35a573c98e
commit c20cf2fbbc
12 changed files with 96 additions and 110 deletions

View File

@@ -103,6 +103,13 @@ def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dat
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
)
api.upload_file(
path_or_fileobj=info_path,
path_in_repo=str(info_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
revision=revision,
)
# stats
stats_path = meta_data_dir / "stats.safetensors"
@@ -113,6 +120,13 @@ def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dat
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
)
api.upload_file(
path_or_fileobj=stats_path,
path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
revision=revision,
)
# episode_data_index
episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
@@ -124,6 +138,13 @@ def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dat
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
)
api.upload_file(
path_or_fileobj=ep_data_idx_path,
path_in_repo=str(ep_data_idx_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
revision=revision,
)
# copy in tests folder, the first episode and the meta_data directory
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]

View File

@@ -51,8 +51,10 @@ print(f"{hf_dataset.features=}")
# display useful statistics about frames and episodes, which are sequences of frames from the same video
print(f"number of frames: {len(hf_dataset)=}")
print(f"number of episodes: {len(hf_dataset.unique('episode_id'))=}")
print(f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_id')):.3f}")
print(f"number of episodes: {len(hf_dataset.unique('episode_index'))=}")
print(
f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_index')):.3f}"
)
# select the frames belonging to episode number 5
hf_dataset = hf_dataset.filter(lambda frame: frame["episode_index"] == 5)

View File

@@ -63,8 +63,9 @@ dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_inde
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grap all the image frames.
frames = [sample["observation.image"] for sample in dataset]
# but frames are now channel first to follow pytorch convention,
# to view them, we convert to channel last
# but frames are now float32 range [0,1] channel first to follow pytorch convention,
# to view them, we convert to uint8 range [0,255] channel last
frames = [(frame * 255).type(torch.uint8) for frame in frames]
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
# and finally save them to a mp4 video

View File

@@ -1,9 +1,13 @@
from pathlib import Path
import torch
from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
class AlohaDataset(torch.utils.data.Dataset):
@@ -40,13 +44,10 @@ class AlohaDataset(torch.utils.data.Dataset):
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
if self.root is not None:
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.hf_dataset = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.hf_dataset = self.hf_dataset.with_format("torch")
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:
@@ -66,19 +67,11 @@ class AlohaDataset(torch.utils.data.Dataset):
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(item[key].ndim)
if self.transform is not None:
item = self.transform(item)

View File

@@ -4,7 +4,7 @@ from pathlib import Path
import torch
from torchvision.transforms import v2
from lerobot.common.transforms import NormalizeTransform, Prod
from lerobot.common.transforms import NormalizeTransform
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
@@ -55,7 +55,6 @@ def make_dataset(
dataset_id=cfg.dataset_id,
split="train",
root=DATA_DIR,
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
stats = stats_dataset.stats
else:
@@ -63,7 +62,6 @@ def make_dataset(
transforms = v2.Compose(
[
Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
NormalizeTransform(
stats,
in_keys=[

View File

@@ -70,15 +70,6 @@ class PushtDataset(torch.utils.data.Dataset):
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(item[key].ndim)
if self.transform is not None:
item = self.transform(item)

View File

@@ -63,7 +63,6 @@ def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
# TODO(rcadene): remove dataset_id everywhere and use repo_id instead
repo_id = f"lerobot/{dataset_id}"
hf_dataset = load_dataset(repo_id, revision=version, split=split)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
@@ -156,6 +155,7 @@ def load_previous_and_future_frames(
# load timestamps
ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
ep_timestamps = torch.stack(ep_timestamps)
# we make the assumption that the timestamps are sorted
ep_first_ts = ep_timestamps[0]
@@ -186,6 +186,7 @@ def load_previous_and_future_frames(
# load frames modality
item[key] = hf_dataset.select_columns(key)[data_ids][key]
item[key] = torch.stack(item[key])
item[f"{key}_is_pad"] = is_pad
return item
@@ -251,8 +252,7 @@ def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
hf_dataset,
num_workers=4,
batch_size=batch_size,
shuffle=False,
# pin_memory=cfg.device != "cpu",
shuffle=True,
drop_last=False,
)
return dataloader

View File

@@ -1,9 +1,13 @@
from pathlib import Path
import torch
from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
class XarmDataset(torch.utils.data.Dataset):
@@ -40,13 +44,10 @@ class XarmDataset(torch.utils.data.Dataset):
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
if self.root is not None:
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.hf_dataset = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.hf_dataset = self.hf_dataset.with_format("torch")
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:
@@ -66,19 +67,11 @@ class XarmDataset(torch.utils.data.Dataset):
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(item[key].ndim)
if self.transform is not None:
item = self.transform(item)

View File

@@ -15,8 +15,19 @@ def preprocess_observation(observation, transform=None):
for imgkey, img in imgs.items():
img = torch.from_numpy(img)
# convert to (b c h w) torch format
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w")
img = img.type(torch.float32)
img /= 255
obs[imgkey] = img
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"

View File

@@ -1,4 +1,3 @@
import torch
from torchvision.transforms.v2 import Compose, Transform
@@ -12,40 +11,6 @@ def apply_inverse_transform(item, transform):
return item
class Prod(Transform):
invertible = True
def __init__(self, in_keys: list[str], prod: float):
super().__init__()
self.in_keys = in_keys
self.prod = prod
self.original_dtypes = {}
def forward(self, item):
for key in self.in_keys:
if key not in item:
continue
self.original_dtypes[key] = item[key].dtype
item[key] = item[key].type(torch.float32) * self.prod
return item
def inverse_transform(self, item):
for key in self.in_keys:
if key not in item:
continue
item[key] = (item[key] / self.prod).type(self.original_dtypes[key])
return item
# def transform_observation_spec(self, obs_spec):
# for key in self.in_keys:
# if obs_spec.get(key, None) is None:
# continue
# obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
# obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
# obs_spec[key].dtype = torch.float32
# return obs_spec
class NormalizeTransform(Transform):
invertible = True

View File

@@ -249,9 +249,23 @@ def eval_policy(
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
# c h w -> h w c
img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
for img in ep_dict[key]:
# sanity check that images are channel first
c, h, w = img.shape
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
# sanity check that images are float32 in range [0,1]
assert img.dtype == torch.float32, f"expect torch.float32, but instead {img.dtype=}"
assert img.max() <= 1, f"expect pixels lower than 1, but instead {img.max()=}"
assert img.min() >= 0, f"expect pixels greater than 1, but instead {img.min()=}"
# from float32 in range [0,1] to uint8 in range [0,255]
img *= 255
img = img.type(torch.uint8)
# convert to channel last and numpy as expected by PIL
img = PILImage.fromarray(img.permute(1, 2, 0).numpy())
data_dict[key].append(img)
data_dict["index"] = torch.arange(0, total_frames, 1)

View File

@@ -15,6 +15,7 @@ from lerobot.common.datasets.utils import (
compute_stats,
flatten_dict,
get_stats_einops_patterns,
hf_transform_to_torch,
load_previous_and_future_frames,
unflatten_dict,
)
@@ -51,12 +52,6 @@ def test_factory(env_name, dataset_id, policy_name):
("next.done", 0, False),
]
for key in image_keys:
keys_ndim_required.append(
(key, 3, True),
)
assert dataset.hf_dataset[key].dtype == torch.uint8, f"{key}"
# test number of dimensions
for key, ndim, required in keys_ndim_required:
if key not in item:
@@ -126,6 +121,7 @@ def test_compute_stats_on_xarm():
# compute stats based on all frames from the dataset without any batching
expected_stats = {}
for k, pattern in stats_patterns.items():
full_batch[k] = full_batch[k].float()
expected_stats[k] = {}
expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean")
expected_stats[k]["std"] = torch.sqrt(
@@ -142,14 +138,15 @@ def test_compute_stats_on_xarm():
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
# load stats used during training which are expected to match the ones returned by computed_stats
loaded_stats = dataset.stats
loaded_stats = dataset.stats # noqa: F841
# test loaded stats match expected stats
for k in stats_patterns:
assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
# TODO(rcadene): we can't test this because expected_stats is computed on a subset
# # test loaded stats match expected stats
# for k in stats_patterns:
# assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
# assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
# assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
def test_load_previous_and_future_frames_within_tolerance():
@@ -160,7 +157,7 @@ def test_load_previous_and_future_frames_within_tolerance():
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
@@ -182,7 +179,7 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
@@ -202,7 +199,7 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),