Fix unit tests
This commit is contained in:
@@ -14,12 +14,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
|
||||
from lerobot.common.datasets.utils import (
|
||||
create_lerobot_dataset_card,
|
||||
flatten_dict,
|
||||
hf_transform_to_torch,
|
||||
unflatten_dict,
|
||||
)
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
@@ -53,3 +61,26 @@ def test_calculate_episode_data_index():
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
|
||||
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
|
||||
|
||||
|
||||
def test_flatten_unflatten_dict():
|
||||
d = {
|
||||
"obs": {
|
||||
"min": 0,
|
||||
"max": 1,
|
||||
"mean": 2,
|
||||
"std": 3,
|
||||
},
|
||||
"action": {
|
||||
"min": 4,
|
||||
"max": 5,
|
||||
"mean": 6,
|
||||
"std": 7,
|
||||
},
|
||||
}
|
||||
|
||||
original_d = deepcopy(d)
|
||||
d = unflatten_dict(flatten_dict(d))
|
||||
|
||||
# test equality between nested dicts
|
||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
|
||||
|
||||
Reference in New Issue
Block a user