Progress on aggregate_datasets
This commit is contained in:
9
tests/fixtures/dataset_factories.py
vendored
9
tests/fixtures/dataset_factories.py
vendored
@@ -212,6 +212,7 @@ def tasks_factory():
|
||||
def episodes_factory(tasks_factory, stats_factory):
|
||||
def _create_episodes(
|
||||
features: dict[str],
|
||||
fps: int = DEFAULT_FPS,
|
||||
total_episodes: int = 3,
|
||||
total_frames: int = 400,
|
||||
video_keys: list[str] | None = None,
|
||||
@@ -252,6 +253,8 @@ def episodes_factory(tasks_factory, stats_factory):
|
||||
for video_key in video_keys:
|
||||
d[f"videos/{video_key}/chunk_index"] = []
|
||||
d[f"videos/{video_key}/file_index"] = []
|
||||
d[f"videos/{video_key}/from_timestamp"] = []
|
||||
d[f"videos/{video_key}/to_timestamp"] = []
|
||||
|
||||
for stats_key in flatten_dict({"stats": stats_factory(features)}):
|
||||
d[stats_key] = []
|
||||
@@ -281,6 +284,8 @@ def episodes_factory(tasks_factory, stats_factory):
|
||||
for video_key in video_keys:
|
||||
d[f"videos/{video_key}/chunk_index"].append(0)
|
||||
d[f"videos/{video_key}/file_index"].append(0)
|
||||
d[f"videos/{video_key}/from_timestamp"].append(num_frames / fps)
|
||||
d[f"videos/{video_key}/to_timestamp"].append((num_frames + lengths[ep_idx]) / fps)
|
||||
|
||||
# Add stats columns like "stats/action/max"
|
||||
for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items():
|
||||
@@ -306,7 +311,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
||||
if features is None:
|
||||
features = features_factory()
|
||||
if episodes is None:
|
||||
episodes = episodes_factory(features)
|
||||
episodes = episodes_factory(features, fps)
|
||||
|
||||
timestamp_col = np.array([], dtype=np.float32)
|
||||
frame_index_col = np.array([], dtype=np.int64)
|
||||
@@ -379,6 +384,7 @@ def lerobot_dataset_metadata_factory(
|
||||
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
|
||||
episodes = episodes_factory(
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
video_keys=video_keys,
|
||||
@@ -441,6 +447,7 @@ def lerobot_dataset_factory(
|
||||
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
|
||||
episodes_metadata = episodes_factory(
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
video_keys=video_keys,
|
||||
|
||||
Reference in New Issue
Block a user