forked from tangger/lerobot
Online finetuning runs (sometimes crash because of nans)
This commit is contained in:
@@ -77,18 +77,16 @@ class SimxarmEnv(EnvBase):
|
||||
|
||||
def _format_raw_obs(self, raw_obs):
|
||||
if self.from_pixels:
|
||||
camera = self.render(
|
||||
image = self.render(
|
||||
mode="rgb_array", width=self.image_size, height=self.image_size
|
||||
)
|
||||
camera = camera.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
|
||||
camera = torch.tensor(camera.copy(), dtype=torch.uint8)
|
||||
image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
|
||||
image = torch.tensor(image.copy(), dtype=torch.uint8)
|
||||
|
||||
obs = {"camera": camera}
|
||||
obs = {"image": image}
|
||||
|
||||
if not self.pixels_only:
|
||||
obs["robot_state"] = torch.tensor(
|
||||
self._env.robot_state, dtype=torch.float32
|
||||
)
|
||||
obs["state"] = torch.tensor(self._env.robot_state, dtype=torch.float32)
|
||||
else:
|
||||
obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}
|
||||
|
||||
@@ -136,7 +134,7 @@ class SimxarmEnv(EnvBase):
|
||||
def _make_spec(self):
|
||||
obs = {}
|
||||
if self.from_pixels:
|
||||
obs["camera"] = BoundedTensorSpec(
|
||||
obs["image"] = BoundedTensorSpec(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(3, self.image_size, self.image_size),
|
||||
@@ -144,7 +142,7 @@ class SimxarmEnv(EnvBase):
|
||||
device=self.device,
|
||||
)
|
||||
if not self.pixels_only:
|
||||
obs["robot_state"] = UnboundedContinuousTensorSpec(
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
shape=(len(self._env.robot_state),),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
|
||||
Reference in New Issue
Block a user