test_datasets.py are passing!
This commit is contained in:
@@ -9,43 +9,76 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_name,dataset_id",
|
||||
"env_name,dataset_id,policy_name",
|
||||
[
|
||||
("simxarm", "lift"),
|
||||
("pusht", "pusht"),
|
||||
("aloha", "sim_insertion_human"),
|
||||
("aloha", "sim_insertion_scripted"),
|
||||
("aloha", "sim_transfer_cube_human"),
|
||||
("aloha", "sim_transfer_cube_scripted"),
|
||||
("simxarm", "xarm_lift_medium", "tdmpc"),
|
||||
("pusht", "pusht", "diffusion"),
|
||||
("aloha", "aloha_sim_insertion_human", "act"),
|
||||
("aloha", "aloha_sim_insertion_scripted", "act"),
|
||||
("aloha", "aloha_sim_transfer_cube_human", "act"),
|
||||
("aloha", "aloha_sim_transfer_cube_scripted", "act"),
|
||||
],
|
||||
)
|
||||
def test_factory(env_name, dataset_id):
|
||||
def test_factory(env_name, dataset_id, policy_name):
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"]
|
||||
overrides=[f"env={env_name}", f"dataset_id={dataset_id}", f"policy={policy_name}", f"device={DEVICE}"]
|
||||
)
|
||||
dataset = make_dataset(cfg)
|
||||
delta_timestamps = dataset.delta_timestamps
|
||||
image_keys = dataset.image_keys
|
||||
|
||||
item = dataset[0]
|
||||
|
||||
assert "action" in item
|
||||
assert "episode" in item
|
||||
assert "frame_id" in item
|
||||
assert "timestamp" in item
|
||||
assert "next.done" in item
|
||||
# TODO(rcadene): should we rename it agent_pos?
|
||||
assert "observation.state" in item
|
||||
for key in dataset.image_keys:
|
||||
img = item.get(key)
|
||||
assert img.dtype == torch.float32
|
||||
# TODO(rcadene): we assume for now that image normalization takes place in the model
|
||||
assert img.max() <= 1.0
|
||||
assert img.min() >= 0.0
|
||||
keys_ndim_required = [
|
||||
("action", 1, True),
|
||||
("episode", 0, True),
|
||||
("frame_id", 0, True),
|
||||
("timestamp", 0, True),
|
||||
# TODO(rcadene): should we rename it agent_pos?
|
||||
("observation.state", 1, True),
|
||||
("next.reward", 0, False),
|
||||
("next.done", 0, False),
|
||||
]
|
||||
|
||||
if "next.reward" not in item:
|
||||
logging.warning(f'Missing "next.reward" key in dataset {dataset}.')
|
||||
if "next.done" not in item:
|
||||
logging.warning(f'Missing "next.done" key in dataset {dataset}.')
|
||||
for key in image_keys:
|
||||
keys_ndim_required.append(
|
||||
(key, 3, True),
|
||||
)
|
||||
|
||||
# test number of dimensions
|
||||
for key, ndim, required in keys_ndim_required:
|
||||
if key not in item:
|
||||
if required:
|
||||
assert key in item, f"{key}"
|
||||
else:
|
||||
logging.warning(f'Missing key in dataset: "{key}" not in {dataset}.')
|
||||
continue
|
||||
|
||||
if delta_timestamps is not None and key in delta_timestamps:
|
||||
assert item[key].ndim == ndim + 1, f"{key}"
|
||||
assert item[key].shape[0] == len(delta_timestamps[key]), f"{key}"
|
||||
else:
|
||||
assert item[key].ndim == ndim, f"{key}"
|
||||
|
||||
if key in image_keys:
|
||||
assert item[key].dtype == torch.float32, f"{key}"
|
||||
# TODO(rcadene): we assume for now that image normalization takes place in the model
|
||||
assert item[key].max() <= 1.0, f"{key}"
|
||||
assert item[key].min() >= 0.0, f"{key}"
|
||||
|
||||
if delta_timestamps is not None and key in delta_timestamps:
|
||||
# test t,c,h,w
|
||||
assert item[key].shape[1] == 3, f"{key}"
|
||||
else:
|
||||
# test c,h,w
|
||||
assert item[key].shape[0] == 3, f"{key}"
|
||||
|
||||
|
||||
if delta_timestamps is not None:
|
||||
# test missing keys in delta_timestamps
|
||||
for key in delta_timestamps:
|
||||
assert key in item, f"{key}"
|
||||
|
||||
|
||||
# def test_compute_stats():
|
||||
|
||||
Reference in New Issue
Block a user