forked from tangger/lerobot
Add frame level task (#693)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
@@ -93,6 +93,24 @@ def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
|
||||
assert dataset.num_frames == len(dataset)
|
||||
|
||||
|
||||
def test_add_frame_no_task(tmp_path):
|
||||
features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
|
||||
with pytest.raises(ValueError, match="The mandatory feature 'task' wasn't found in `frame` dictionnary."):
|
||||
dataset.add_frame({"1d": torch.randn(1)})
|
||||
|
||||
|
||||
def test_add_frame(tmp_path):
|
||||
features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"1d": torch.randn(1), "task": "dummy"})
|
||||
dataset.save_episode(encode_videos=False)
|
||||
dataset.consolidate(run_compute_stats=False)
|
||||
assert len(dataset) == 1
|
||||
assert dataset[0]["task"] == "dummy"
|
||||
assert dataset[0]["task_index"] == 0
|
||||
|
||||
|
||||
# TODO(aliberts):
|
||||
# - [ ] test various attributes & state from init and create
|
||||
# - [ ] test init with episodes and check num_frames
|
||||
|
||||
Reference in New Issue
Block a user