wip - still need to verify full training run

This commit is contained in:
Alexander Soare
2024-03-11 18:45:21 +00:00
parent 304355c917
commit 87fcc536f9
3 changed files with 9 additions and 7 deletions

View File

@@ -25,7 +25,7 @@ class PushTImageEnv(PushTEnv):
img = super()._render_frame(mode="rgb_array")
agent_pos = np.array(self.agent.position)
img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0)
img_obs = np.moveaxis(img.astype(np.float32), -1, 0)
obs = {"image": img_obs, "agent_pos": agent_pos}
# draw action

View File

@@ -123,6 +123,8 @@ class MultiImageObsEncoder(ModuleAttrMixin):
if imagenet_norm:
# TODO(rcadene): move normalizer to dataset and env
this_normalizer = torchvision.transforms.Normalize(
# Note: This matches the normalization in the original impl. for PushT Image. This may not be
# the case for other tasks.
mean=[127.5, 127.5, 127.5],
std=[127.5, 127.5, 127.5],
)