most unit tests passing (TODO: convert datasets)

This commit is contained in:
Remi Cadene
2025-04-16 21:30:58 +02:00
parent c2a05a1fde
commit 6b6a990f4c
22 changed files with 150 additions and 136 deletions

View File

@@ -230,6 +230,8 @@ def episodes_factory(tasks_factory, stats_factory):
"meta/episodes/file_index": [],
"data/chunk_index": [],
"data/file_index": [],
"dataset_from_index": [],
"dataset_to_index": [],
"tasks": [],
"length": [],
}
@@ -241,6 +243,7 @@ def episodes_factory(tasks_factory, stats_factory):
for stats_key in flatten_dict({"stats": stats_factory(features)}):
d[stats_key] = []
num_frames = 0
remaining_tasks = list(tasks.index)
for ep_idx in range(total_episodes):
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
@@ -256,6 +259,8 @@ def episodes_factory(tasks_factory, stats_factory):
d["meta/episodes/file_index"].append(0)
d["data/chunk_index"].append(0)
d["data/file_index"].append(0)
d["dataset_from_index"].append(num_frames)
d["dataset_to_index"].append(num_frames + lengths[ep_idx])
d["tasks"].append(episode_tasks)
d["length"].append(lengths[ep_idx])
@@ -268,6 +273,8 @@ def episodes_factory(tasks_factory, stats_factory):
for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items():
d[stats_key].append(stats)
num_frames += lengths[ep_idx]
return Dataset.from_dict(d)
return _create_episodes
@@ -283,10 +290,10 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
) -> datasets.Dataset:
if tasks is None:
tasks = tasks_factory()
if episodes is None:
episodes = episodes_factory()
if features is None:
features = features_factory()
if episodes is None:
episodes = episodes_factory(features)
timestamp_col = np.array([], dtype=np.float32)
frame_index_col = np.array([], dtype=np.int64)

View File

@@ -10,7 +10,7 @@ from lerobot.common.datasets.utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_TASKS_PATH,
INFO_PATH,
LEGACY_STATS_PATH,
STATS_PATH,
)
from tests.fixtures.constants import LEROBOT_TEST_DIR
@@ -70,7 +70,7 @@ def mock_snapshot_download_factory(
# List all possible files
all_files = [
INFO_PATH,
LEGACY_STATS_PATH,
STATS_PATH,
# TODO(rcadene): remove naive chunk 0 file 0 ?
DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0),
DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0),