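Fix tests: add num_prev_obs to the make_env kwargs and give PushtEnv observation specs a leading dimension of num_prev_obs + 1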

This commit is contained in:
Remi Cadene
2024-03-06 13:55:12 +00:00
committed by Simon Alibert
parent c2c0ef9927
commit 524d29aa80
4 changed files with 23 additions and 14 deletions

View File

@@ -9,6 +9,7 @@ def make_env(cfg, transform=None):
"image_size": cfg.env.image_size,
# TODO(rcadene): do we want a specific eval_env_seed?
"seed": cfg.seed,
"num_prev_obs": cfg.n_obs_steps - 1,
}
if cfg.env.name == "simxarm":

View File

@@ -2,6 +2,7 @@ import importlib
from collections import deque
from typing import Optional
import einops
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
@@ -28,7 +29,7 @@ class PushtEnv(EnvBase):
image_size=None,
seed=1337,
device="cpu",
num_prev_obs=1,
num_prev_obs=0,
num_prev_action=0,
):
super().__init__(device=device, batch_size=[])
@@ -65,7 +66,8 @@ class PushtEnv(EnvBase):
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
if self.num_prev_action > 0:
self._prev_action_queue = deque(maxlen=self.num_prev_action)
raise NotImplementedError()
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
def render(self, mode="rgb_array", width=384, height=384):
if width != height:
@@ -133,7 +135,7 @@ class PushtEnv(EnvBase):
sum_reward = 0
if action.ndim == 1:
action = action.repeat(self.frame_skip, 1)
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else:
if self.frame_skip > 1:
raise NotImplementedError()
@@ -172,7 +174,7 @@ class PushtEnv(EnvBase):
if self.from_pixels:
image_shape = (3, self.image_size, self.image_size)
if self.num_prev_obs > 0:
image_shape = (self.num_prev_obs, *image_shape)
image_shape = (self.num_prev_obs + 1, *image_shape)
obs["image"] = BoundedTensorSpec(
low=0,
@@ -184,12 +186,12 @@ class PushtEnv(EnvBase):
if not self.pixels_only:
state_shape = self._env.observation_space["agent_pos"].shape
if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs, *state_shape)
state_shape = (self.num_prev_obs + 1, *state_shape)
obs["state"] = BoundedTensorSpec(
low=0,
high=512,
shape=self._env.observation_space["agent_pos"].shape,
shape=state_shape,
dtype=torch.float32,
device=self.device,
)
@@ -197,11 +199,11 @@ class PushtEnv(EnvBase):
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
state_shape = self._env.observation_space["observation"].shape
if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs, *state_shape)
state_shape = (self.num_prev_obs + 1, *state_shape)
obs["state"] = UnboundedContinuousTensorSpec(
# TODO:
shape=self._env.observation_space["observation"].shape,
shape=state_shape,
dtype=torch.float32,
device=self.device,
)