Fix unit tests
This commit is contained in:
@@ -13,10 +13,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
|
||||
@@ -36,8 +34,6 @@ from lerobot.common.datasets.lerobot_dataset import (
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
create_branch,
|
||||
flatten_dict,
|
||||
unflatten_dict,
|
||||
)
|
||||
from lerobot.common.envs.factory import make_env_config
|
||||
from lerobot.common.policies.factory import make_policy_config
|
||||
@@ -100,6 +96,25 @@ def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
|
||||
assert dataset.num_frames == len(dataset)
|
||||
|
||||
|
||||
# TODO(rcadene, aliberts): do not run LeRobotDataset.create, instead refactor LeRobotDatasetMetadata.create
|
||||
# and test the small resulting function that validates the features
|
||||
def test_dataset_feature_with_forward_slash_raises_error():
|
||||
# make sure dir does not exist
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
|
||||
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
|
||||
# make sure does not exist
|
||||
if dataset_dir.exists():
|
||||
dataset_dir.rmdir()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LeRobotDataset.create(
|
||||
repo_id="lerobot/test/with/slash",
|
||||
fps=30,
|
||||
features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
|
||||
)
|
||||
|
||||
|
||||
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
@@ -329,6 +344,13 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
|
||||
# - [ ] test push_to_hub
|
||||
# - [ ] test smaller methods
|
||||
|
||||
# TODO(rcadene):
|
||||
# - [ ] fix code so that old test_factory + backward pass
|
||||
# - [ ] write new unit tests to test save_episode + getitem
|
||||
# - [ ] save_episode : case where new dataset, concatenate same file, write new file (meta/episodes, data, videos)
|
||||
# - [ ]
|
||||
# - [ ] remove old tests
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_name, repo_id, policy_name",
|
||||
@@ -436,30 +458,6 @@ def test_multidataset_frames():
|
||||
assert torch.equal(sub_dataset_item[k], dataset_item[k])
|
||||
|
||||
|
||||
# TODO(aliberts): Move to more appropriate location
|
||||
def test_flatten_unflatten_dict():
|
||||
d = {
|
||||
"obs": {
|
||||
"min": 0,
|
||||
"max": 1,
|
||||
"mean": 2,
|
||||
"std": 3,
|
||||
},
|
||||
"action": {
|
||||
"min": 4,
|
||||
"max": 5,
|
||||
"mean": 6,
|
||||
"std": 7,
|
||||
},
|
||||
}
|
||||
|
||||
original_d = deepcopy(d)
|
||||
d = unflatten_dict(flatten_dict(d))
|
||||
|
||||
# test equality between nested dicts
|
||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id",
|
||||
[
|
||||
@@ -569,20 +567,3 @@ def test_create_branch():
|
||||
|
||||
# Clean
|
||||
api.delete_repo(repo_id, repo_type=repo_type)
|
||||
|
||||
|
||||
def test_dataset_feature_with_forward_slash_raises_error():
|
||||
# make sure dir does not exist
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
|
||||
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
|
||||
# make sure does not exist
|
||||
if dataset_dir.exists():
|
||||
dataset_dir.rmdir()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LeRobotDataset.create(
|
||||
repo_id="lerobot/test/with/slash",
|
||||
fps=30,
|
||||
features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
|
||||
)
|
||||
|
||||
@@ -14,12 +14,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
|
||||
from lerobot.common.datasets.utils import (
|
||||
create_lerobot_dataset_card,
|
||||
flatten_dict,
|
||||
hf_transform_to_torch,
|
||||
unflatten_dict,
|
||||
)
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
@@ -53,3 +61,26 @@ def test_calculate_episode_data_index():
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
|
||||
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
|
||||
|
||||
|
||||
def test_flatten_unflatten_dict():
|
||||
d = {
|
||||
"obs": {
|
||||
"min": 0,
|
||||
"max": 1,
|
||||
"mean": 2,
|
||||
"std": 3,
|
||||
},
|
||||
"action": {
|
||||
"min": 4,
|
||||
"max": 5,
|
||||
"mean": 6,
|
||||
"std": 7,
|
||||
},
|
||||
}
|
||||
|
||||
original_d = deepcopy(d)
|
||||
d = unflatten_dict(flatten_dict(d))
|
||||
|
||||
# test equality between nested dicts
|
||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
|
||||
|
||||
@@ -141,6 +141,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
|
||||
and for now we add tests as we see fit.
|
||||
"""
|
||||
policy_kwargs["device"] = DEVICE
|
||||
|
||||
train_cfg = TrainPipelineConfig(
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
|
||||
Reference in New Issue
Block a user