forked from tangger/lerobot
most unit tests passing (TODO: convert datasets)
This commit is contained in:
11
tests/fixtures/dataset_factories.py
vendored
11
tests/fixtures/dataset_factories.py
vendored
@@ -230,6 +230,8 @@ def episodes_factory(tasks_factory, stats_factory):
|
||||
"meta/episodes/file_index": [],
|
||||
"data/chunk_index": [],
|
||||
"data/file_index": [],
|
||||
"dataset_from_index": [],
|
||||
"dataset_to_index": [],
|
||||
"tasks": [],
|
||||
"length": [],
|
||||
}
|
||||
@@ -241,6 +243,7 @@ def episodes_factory(tasks_factory, stats_factory):
|
||||
for stats_key in flatten_dict({"stats": stats_factory(features)}):
|
||||
d[stats_key] = []
|
||||
|
||||
num_frames = 0
|
||||
remaining_tasks = list(tasks.index)
|
||||
for ep_idx in range(total_episodes):
|
||||
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
|
||||
@@ -256,6 +259,8 @@ def episodes_factory(tasks_factory, stats_factory):
|
||||
d["meta/episodes/file_index"].append(0)
|
||||
d["data/chunk_index"].append(0)
|
||||
d["data/file_index"].append(0)
|
||||
d["dataset_from_index"].append(num_frames)
|
||||
d["dataset_to_index"].append(num_frames + lengths[ep_idx])
|
||||
d["tasks"].append(episode_tasks)
|
||||
d["length"].append(lengths[ep_idx])
|
||||
|
||||
@@ -268,6 +273,8 @@ def episodes_factory(tasks_factory, stats_factory):
|
||||
for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items():
|
||||
d[stats_key].append(stats)
|
||||
|
||||
num_frames += lengths[ep_idx]
|
||||
|
||||
return Dataset.from_dict(d)
|
||||
|
||||
return _create_episodes
|
||||
@@ -283,10 +290,10 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
||||
) -> datasets.Dataset:
|
||||
if tasks is None:
|
||||
tasks = tasks_factory()
|
||||
if episodes is None:
|
||||
episodes = episodes_factory()
|
||||
if features is None:
|
||||
features = features_factory()
|
||||
if episodes is None:
|
||||
episodes = episodes_factory(features)
|
||||
|
||||
timestamp_col = np.array([], dtype=np.float32)
|
||||
frame_index_col = np.array([], dtype=np.int64)
|
||||
|
||||
4
tests/fixtures/hub.py
vendored
4
tests/fixtures/hub.py
vendored
@@ -10,7 +10,7 @@ from lerobot.common.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_TASKS_PATH,
|
||||
INFO_PATH,
|
||||
LEGACY_STATS_PATH,
|
||||
STATS_PATH,
|
||||
)
|
||||
from tests.fixtures.constants import LEROBOT_TEST_DIR
|
||||
|
||||
@@ -70,7 +70,7 @@ def mock_snapshot_download_factory(
|
||||
# List all possible files
|
||||
all_files = [
|
||||
INFO_PATH,
|
||||
LEGACY_STATS_PATH,
|
||||
STATS_PATH,
|
||||
# TODO(rcadene): remove naive chunk 0 file 0 ?
|
||||
DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0),
|
||||
DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0),
|
||||
|
||||
Reference in New Issue
Block a user