fix pusht images type from float32 to uint8, update gym-pusht dependencies
This commit is contained in:
@@ -51,6 +51,7 @@ def test_factory(env_name, dataset_id, policy_name):
|
||||
keys_ndim_required.append(
|
||||
(key, 3, True),
|
||||
)
|
||||
assert dataset.data_dict[key].dtype == torch.uint8, f"{key}"
|
||||
|
||||
# test number of dimensions
|
||||
for key, ndim, required in keys_ndim_required:
|
||||
|
||||
Reference in New Issue
Block a user