Fix unit tests

This commit is contained in:
Remi Cadene
2025-04-22 10:35:20 +02:00
parent 601b5fdbfe
commit 367d9bda7d
5 changed files with 95 additions and 75 deletions

View File

@@ -13,10 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import re
from copy import deepcopy
from itertools import chain
from pathlib import Path
@@ -36,8 +34,6 @@ from lerobot.common.datasets.lerobot_dataset import (
)
from lerobot.common.datasets.utils import (
create_branch,
flatten_dict,
unflatten_dict,
)
from lerobot.common.envs.factory import make_env_config
from lerobot.common.policies.factory import make_policy_config
@@ -100,6 +96,25 @@ def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
assert dataset.num_frames == len(dataset)
# TODO(rcadene, aliberts): do not run LeRobotDataset.create, instead refactor LeRobotDatasetMetadata.create
# and test the small resulting function that validates the features
def test_dataset_feature_with_forward_slash_raises_error():
# make sure dir does not exist
from lerobot.common.constants import HF_LEROBOT_HOME
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
# make sure does not exist
if dataset_dir.exists():
dataset_dir.rmdir()
with pytest.raises(ValueError):
LeRobotDataset.create(
repo_id="lerobot/test/with/slash",
fps=30,
features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
)
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
@@ -329,6 +344,13 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
# - [ ] test push_to_hub
# - [ ] test smaller methods
# TODO(rcadene):
# - [ ] fix code so that old test_factory + backward pass
# - [ ] write new unit tests to test save_episode + getitem
# - [ ] save_episode : case where new dataset, concatenate same file, write new file (meta/episodes, data, videos)
# - [ ]
# - [ ] remove old tests
@pytest.mark.parametrize(
"env_name, repo_id, policy_name",
@@ -436,30 +458,6 @@ def test_multidataset_frames():
assert torch.equal(sub_dataset_item[k], dataset_item[k])
# TODO(aliberts): Move to more appropriate location
def test_flatten_unflatten_dict():
d = {
"obs": {
"min": 0,
"max": 1,
"mean": 2,
"std": 3,
},
"action": {
"min": 4,
"max": 5,
"mean": 6,
"std": 7,
},
}
original_d = deepcopy(d)
d = unflatten_dict(flatten_dict(d))
# test equality between nested dicts
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
@pytest.mark.parametrize(
"repo_id",
[
@@ -569,20 +567,3 @@ def test_create_branch():
# Clean
api.delete_repo(repo_id, repo_type=repo_type)
def test_dataset_feature_with_forward_slash_raises_error():
# make sure dir does not exist
from lerobot.common.constants import HF_LEROBOT_HOME
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
# make sure does not exist
if dataset_dir.exists():
dataset_dir.rmdir()
with pytest.raises(ValueError):
LeRobotDataset.create(
repo_id="lerobot/test/with/slash",
fps=30,
features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
)

View File

@@ -14,12 +14,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from copy import deepcopy
import torch
from datasets import Dataset
from huggingface_hub import DatasetCard
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
from lerobot.common.datasets.utils import (
create_lerobot_dataset_card,
flatten_dict,
hf_transform_to_torch,
unflatten_dict,
)
def test_default_parameters():
@@ -53,3 +61,26 @@ def test_calculate_episode_data_index():
episode_data_index = calculate_episode_data_index(dataset)
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
def test_flatten_unflatten_dict():
d = {
"obs": {
"min": 0,
"max": 1,
"mean": 2,
"std": 3,
},
"action": {
"min": 4,
"max": 5,
"mean": 6,
"std": 7,
},
}
original_d = deepcopy(d)
d = unflatten_dict(flatten_dict(d))
# test equality between nested dicts
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"

View File

@@ -141,6 +141,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
and for now we add tests as we see fit.
"""
policy_kwargs["device"] = DEVICE
train_cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download