Online finetuning runs (sometimes crash because of nans)

This commit is contained in:
Cadene
2024-02-16 15:13:24 +00:00
parent 228c045674
commit c202c2b3c2
5 changed files with 165 additions and 110 deletions

View File

@@ -77,18 +77,16 @@ class SimxarmEnv(EnvBase):
def _format_raw_obs(self, raw_obs):
if self.from_pixels:
camera = self.render(
image = self.render(
mode="rgb_array", width=self.image_size, height=self.image_size
)
camera = camera.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
camera = torch.tensor(camera.copy(), dtype=torch.uint8)
image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
image = torch.tensor(image.copy(), dtype=torch.uint8)
obs = {"camera": camera}
obs = {"image": image}
if not self.pixels_only:
obs["robot_state"] = torch.tensor(
self._env.robot_state, dtype=torch.float32
)
obs["state"] = torch.tensor(self._env.robot_state, dtype=torch.float32)
else:
obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}
@@ -136,7 +134,7 @@ class SimxarmEnv(EnvBase):
def _make_spec(self):
obs = {}
if self.from_pixels:
obs["camera"] = BoundedTensorSpec(
obs["image"] = BoundedTensorSpec(
low=0,
high=255,
shape=(3, self.image_size, self.image_size),
@@ -144,7 +142,7 @@ class SimxarmEnv(EnvBase):
device=self.device,
)
if not self.pixels_only:
obs["robot_state"] = UnboundedContinuousTensorSpec(
obs["state"] = UnboundedContinuousTensorSpec(
shape=(len(self._env.robot_state),),
dtype=torch.float32,
device=self.device,