fix pusht images type from float32 to uint8, update gym-pusht dependencies
This commit is contained in:
@@ -145,6 +145,9 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
assert (episode_ids[idx0:idx1] == episode_id).all()
|
||||
|
||||
image = imgs[idx0:idx1]
|
||||
assert image.min() >= 0.0
|
||||
assert image.max() <= 255.0
|
||||
image = image.type(torch.uint8)
|
||||
|
||||
state = states[idx0:idx1]
|
||||
agent_pos = state[:, :2]
|
||||
|
||||
Reference in New Issue
Block a user