chore: replace hard-coded obs values with constants throughout all the source code (#2037)
* chore: replace hard-coded OBS values with constants throughout all the source code * chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
@@ -21,6 +21,7 @@ from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
@@ -96,14 +97,14 @@ def test_merge_multiple_groups_order_and_dedup():
|
||||
def test_non_vector_last_wins_for_images():
|
||||
# Non-vector (images) with same name should be overwritten by the last image specified
|
||||
g1 = {
|
||||
"observation.images.front": {
|
||||
f"{OBS_IMAGES}.front": {
|
||||
"dtype": "image",
|
||||
"shape": (3, 480, 640),
|
||||
"names": ["channels", "height", "width"],
|
||||
}
|
||||
}
|
||||
g2 = {
|
||||
"observation.images.front": {
|
||||
f"{OBS_IMAGES}.front": {
|
||||
"dtype": "image",
|
||||
"shape": (3, 720, 1280),
|
||||
"names": ["channels", "height", "width"],
|
||||
@@ -111,8 +112,8 @@ def test_non_vector_last_wins_for_images():
|
||||
}
|
||||
|
||||
out = combine_feature_dicts(g1, g2)
|
||||
assert out["observation.images.front"]["shape"] == (3, 720, 1280)
|
||||
assert out["observation.images.front"]["dtype"] == "image"
|
||||
assert out[f"{OBS_IMAGES}.front"]["shape"] == (3, 720, 1280)
|
||||
assert out[f"{OBS_IMAGES}.front"]["dtype"] == "image"
|
||||
|
||||
|
||||
def test_dtype_mismatch_raises():
|
||||
|
||||
Reference in New Issue
Block a user