Add frame level task (#693)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Remi
2025-02-14 14:22:22 +01:00
committed by GitHub
parent d67ca342e9
commit 9d6886dd08
6 changed files with 105 additions and 50 deletions

View File

@@ -93,6 +93,24 @@ def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
assert dataset.num_frames == len(dataset)
def test_add_frame_no_task(tmp_path):
features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
with pytest.raises(ValueError, match="The mandatory feature 'task' wasn't found in `frame` dictionnary."):
dataset.add_frame({"1d": torch.randn(1)})
def test_add_frame(tmp_path):
features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
dataset.add_frame({"1d": torch.randn(1), "task": "dummy"})
dataset.save_episode(encode_videos=False)
dataset.consolidate(run_compute_stats=False)
assert len(dataset) == 1
assert dataset[0]["task"] == "dummy"
assert dataset[0]["task_index"] == 0
# TODO(aliberts):
# - [ ] test various attributes & state from init and create
# - [ ] test init with episodes and check num_frames