fix pusht images type from float32 to uint8, update gym-pusht dependencies

This commit is contained in:
Cadene
2024-04-11 14:29:16 +00:00
parent 4216636084
commit c1a618e567
4 changed files with 6 additions and 2 deletions

View File

@@ -145,6 +145,9 @@ class PushtDataset(torch.utils.data.Dataset):
assert (episode_ids[idx0:idx1] == episode_id).all()
image = imgs[idx0:idx1]
assert image.min() >= 0.0
assert image.max() <= 255.0
image = image.type(torch.uint8)
state = states[idx0:idx1]
agent_pos = state[:, :2]