fix pusht images type from float32 to uint8, update gym-pusht dependencies

This commit is contained in:
Cadene
2024-04-11 14:29:16 +00:00
parent 4216636084
commit c1a618e567
4 changed files with 6 additions and 2 deletions

View File

@@ -51,6 +51,7 @@ def test_factory(env_name, dataset_id, policy_name):
keys_ndim_required.append(
(key, 3, True),
)
assert dataset.data_dict[key].dtype == torch.uint8, f"{key}"
# test number of dimensions
for key, ndim, required in keys_ndim_required: