most unit tests passing (TODO: convert datasets)

This commit is contained in:
Remi Cadene
2025-04-16 21:30:58 +02:00
parent c2a05a1fde
commit 6b6a990f4c
22 changed files with 150 additions and 136 deletions

View File

@@ -230,6 +230,8 @@ def episodes_factory(tasks_factory, stats_factory):
"meta/episodes/file_index": [],
"data/chunk_index": [],
"data/file_index": [],
"dataset_from_index": [],
"dataset_to_index": [],
"tasks": [],
"length": [],
}
@@ -241,6 +243,7 @@ def episodes_factory(tasks_factory, stats_factory):
for stats_key in flatten_dict({"stats": stats_factory(features)}):
d[stats_key] = []
num_frames = 0
remaining_tasks = list(tasks.index)
for ep_idx in range(total_episodes):
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
@@ -256,6 +259,8 @@ def episodes_factory(tasks_factory, stats_factory):
d["meta/episodes/file_index"].append(0)
d["data/chunk_index"].append(0)
d["data/file_index"].append(0)
d["dataset_from_index"].append(num_frames)
d["dataset_to_index"].append(num_frames + lengths[ep_idx])
d["tasks"].append(episode_tasks)
d["length"].append(lengths[ep_idx])
@@ -268,6 +273,8 @@ def episodes_factory(tasks_factory, stats_factory):
for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items():
d[stats_key].append(stats)
num_frames += lengths[ep_idx]
return Dataset.from_dict(d)
return _create_episodes
@@ -283,10 +290,10 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
) -> datasets.Dataset:
if tasks is None:
tasks = tasks_factory()
if episodes is None:
episodes = episodes_factory()
if features is None:
features = features_factory()
if episodes is None:
episodes = episodes_factory(features)
timestamp_col = np.array([], dtype=np.float32)
frame_index_col = np.array([], dtype=np.int64)

View File

@@ -10,7 +10,7 @@ from lerobot.common.datasets.utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_TASKS_PATH,
INFO_PATH,
LEGACY_STATS_PATH,
STATS_PATH,
)
from tests.fixtures.constants import LEROBOT_TEST_DIR
@@ -70,7 +70,7 @@ def mock_snapshot_download_factory(
# List all possible files
all_files = [
INFO_PATH,
LEGACY_STATS_PATH,
STATS_PATH,
# TODO(rcadene): remove naive chunk 0 file 0 ?
DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0),
DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0),

View File

@@ -47,17 +47,23 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
)
# save 2 first frames of first episode
i = dataset.episode_data_index["from"][0].item()
i = dataset.meta.episodes["dataset_from_index"][0].item()
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
# save 2 frames at the middle of first episode
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
i = int(
(
dataset.meta.episodes["dataset_to_index"][0].item()
- dataset.meta.episodes["dataset_from_index"][0].item()
)
/ 2
)
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
# save 2 last frames of first episode
i = dataset.episode_data_index["to"][0].item()
i = dataset.meta.episodes["dataset_to_index"][0].item()
save_file(dataset[i - 2], repo_dir / f"frame_{i - 2}.safetensors")
save_file(dataset[i - 1], repo_dir / f"frame_{i - 1}.safetensors")
@@ -65,17 +71,17 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
# We currently cant because our test dataset only contains the first episode
# # save 2 first frames of second episode
# i = dataset.episode_data_index["from"][1].item()
# i = dataset.meta.episodes["dataset_from_index"][1].item()
# save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
# save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
# # save 2 last frames of second episode
# i = dataset.episode_data_index["to"][1].item()
# i = dataset.meta.episodes["dataset_to_index"][1].item()
# save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
# save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
# # save 2 last frames of last episode
# i = dataset.episode_data_index["to"][-1].item()
# i = dataset.meta.episodes["dataset_to_index"][-1].item()
# save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
# save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")

View File

@@ -507,17 +507,23 @@ def test_backward_compatibility(repo_id):
)
# test2 first frames of first episode
i = dataset.episode_data_index["from"][0].item()
i = dataset.meta.episodes["dataset_from_index"][0].item()
load_and_compare(i)
load_and_compare(i + 1)
# test 2 frames at the middle of first episode
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
i = int(
(
dataset.meta.episodes["dataset_to_index"][0].item()
- dataset.meta.episodes["dataset_from_index"][0].item()
)
/ 2
)
load_and_compare(i)
load_and_compare(i + 1)
# test 2 last frames of first episode
i = dataset.episode_data_index["to"][0].item()
i = dataset.meta.episodes["dataset_to_index"][0].item()
load_and_compare(i - 2)
load_and_compare(i - 1)
@@ -525,17 +531,17 @@ def test_backward_compatibility(repo_id):
# We currently cant because our test dataset only contains the first episode
# # test 2 first frames of second episode
# i = dataset.episode_data_index["from"][1].item()
# i = dataset.meta.episodes["dataset_from_index"][1].item()
# load_and_compare(i)
# load_and_compare(i + 1)
# # test 2 last frames of second episode
# i = dataset.episode_data_index["to"][1].item()
# i = dataset.meta.episodes["dataset_to_index"][1].item()
# load_and_compare(i - 2)
# load_and_compare(i - 1)
# # test 2 last frames of last episode
# i = dataset.episode_data_index["to"][-1].item()
# i = dataset.meta.episodes["dataset_to_index"][-1].item()
# load_and_compare(i - 2)
# load_and_compare(i - 1)

View File

@@ -43,8 +43,8 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
def synced_timestamps_factory(hf_dataset_factory):
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
hf_dataset = hf_dataset_factory(fps=fps)
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(hf_dataset["episode_index"]).numpy()
timestamps = hf_dataset["timestamp"].numpy()
episode_indices = hf_dataset["episode_index"].numpy()
episode_data_index = calculate_episode_data_index(hf_dataset)
return timestamps, episode_indices, episode_data_index

View File

@@ -68,7 +68,11 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
},
}
info = info_factory(
total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features
total_episodes=1,
total_frames=1,
total_tasks=1,
camera_features=camera_features,
motor_features=motor_features,
)
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
return ds_meta

View File

@@ -32,7 +32,7 @@ def test_drop_n_first_frames():
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], drop_n_first_frames=1)
assert sampler.indices == [1, 4, 5]
assert len(sampler) == 3
assert list(sampler) == [1, 4, 5]
@@ -48,7 +48,7 @@ def test_drop_n_last_frames():
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], drop_n_last_frames=1)
assert sampler.indices == [0, 3, 4]
assert len(sampler) == 3
assert list(sampler) == [0, 3, 4]
@@ -64,7 +64,9 @@ def test_episode_indices_to_use():
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
sampler = EpisodeAwareSampler(
episode_data_index["from"], episode_data_index["to"], episode_indices_to_use=[0, 2]
)
assert sampler.indices == [0, 1, 3, 4, 5]
assert len(sampler) == 5
assert list(sampler) == [0, 1, 3, 4, 5]
@@ -80,11 +82,11 @@ def test_shuffle():
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, shuffle=False)
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], shuffle=False)
assert sampler.indices == [0, 1, 2, 3, 4, 5]
assert len(sampler) == 6
assert list(sampler) == [0, 1, 2, 3, 4, 5]
sampler = EpisodeAwareSampler(episode_data_index, shuffle=True)
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], shuffle=True)
assert sampler.indices == [0, 1, 2, 3, 4, 5]
assert len(sampler) == 6
assert set(sampler) == {0, 1, 2, 3, 4, 5}

View File

@@ -1,21 +0,0 @@
import torch
from datasets import Dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
def test_calculate_episode_data_index():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))