backup wip
This commit is contained in:
@@ -58,6 +58,7 @@ class AlohaEnv(AbstractEnv):
|
||||
num_prev_obs=num_prev_obs,
|
||||
num_prev_action=num_prev_action,
|
||||
)
|
||||
self._reset_warning_issued = False
|
||||
|
||||
def _make_env(self):
|
||||
if not _has_gym:
|
||||
@@ -120,47 +121,47 @@ class AlohaEnv(AbstractEnv):
|
||||
return obs
|
||||
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
td = tensordict
|
||||
if td is None or td.is_empty():
|
||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||
self._current_seed += 1
|
||||
self.set_seed(self._current_seed)
|
||||
if tensordict is not None and not self._reset_warning_issued:
|
||||
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
|
||||
self._reset_warning_issued = True
|
||||
|
||||
# TODO(rcadene): do not use global variable for this
|
||||
if "sim_transfer_cube" in self.task:
|
||||
BOX_POSE[0] = sample_box_pose() # used in sim reset
|
||||
elif "sim_insertion" in self.task:
|
||||
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
|
||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||
self._current_seed += 1
|
||||
self.set_seed(self._current_seed)
|
||||
|
||||
raw_obs = self._env.reset()
|
||||
# TODO(rcadene): add assert
|
||||
# assert self._current_seed == self._env._seed
|
||||
# TODO(rcadene): do not use global variable for this
|
||||
if "sim_transfer_cube" in self.task:
|
||||
BOX_POSE[0] = sample_box_pose() # used in sim reset
|
||||
elif "sim_insertion" in self.task:
|
||||
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
|
||||
|
||||
obs = self._format_raw_obs(raw_obs.observation)
|
||||
raw_obs = self._env.reset()
|
||||
# TODO(rcadene): add assert
|
||||
# assert self._current_seed == self._env._seed
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue = deque(
|
||||
[obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue = deque(
|
||||
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
obs = self._format_raw_obs(raw_obs.observation)
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue = deque(
|
||||
[obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue = deque(
|
||||
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
|
||||
self.call_rendering_hooks()
|
||||
return td
|
||||
|
||||
@@ -1,31 +1,20 @@
|
||||
from torchrl.envs import SerialEnv
|
||||
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
|
||||
|
||||
|
||||
def make_env(cfg, transform=None):
|
||||
"""
|
||||
Provide seed to override the seed in the cfg (useful for batched environments).
|
||||
Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying
|
||||
environments. The env therefore returns batches.`
|
||||
"""
|
||||
# assert cfg.rollout_batch_size == 1, \
|
||||
# """
|
||||
# For the time being, rollout batch sizes of > 1 are not supported. This is because the SerialEnv rollout does not
|
||||
# correctly handle terminated environments. If you really want to use a larger batch size, read on...
|
||||
|
||||
# When calling `EnvBase.rollout` with `break_when_any_done == True` all environments stop rolling out as soon as the
|
||||
# first is terminated or truncated. This almost certainly results in incorrect success metrics, as all but the first
|
||||
# environment get an opportunity to reach the goal. A possible work around is to comment out `if any_done: break`
|
||||
# inf `EnvBase._rollout_stop_early`. One potential downside is that the environments `step` function will continue
|
||||
# to be called and the outputs will continue to be added to the rollout.
|
||||
|
||||
# When calling `EnvBase.rollout` with `break_when_any_done == False` environments are reset when done.
|
||||
# """
|
||||
|
||||
kwargs = {
|
||||
"frame_skip": cfg.env.action_repeat,
|
||||
"from_pixels": cfg.env.from_pixels,
|
||||
"pixels_only": cfg.env.pixels_only,
|
||||
"image_size": cfg.env.image_size,
|
||||
"num_prev_obs": cfg.n_obs_steps - 1,
|
||||
"seed": cfg.seed,
|
||||
"num_prev_obs": cfg.n_obs_steps - 1,
|
||||
}
|
||||
|
||||
if cfg.env.name == "simxarm":
|
||||
@@ -67,13 +56,14 @@ def make_env(cfg, transform=None):
|
||||
|
||||
return env
|
||||
|
||||
# return SerialEnv(
|
||||
# cfg.rollout_batch_size,
|
||||
# create_env_fn=_make_env,
|
||||
# create_env_kwargs={
|
||||
# "seed": env_seed for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
|
||||
# },
|
||||
# )
|
||||
return SerialEnv(
|
||||
cfg.rollout_batch_size,
|
||||
create_env_fn=_make_env,
|
||||
create_env_kwargs={
|
||||
"seed": env_seed # noqa: B035
|
||||
for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# def make_env(env_name, frame_skip, device, is_test=False):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import importlib
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
@@ -42,6 +43,7 @@ class PushtEnv(AbstractEnv):
|
||||
num_prev_obs=num_prev_obs,
|
||||
num_prev_action=num_prev_action,
|
||||
)
|
||||
self._reset_warning_issued = False
|
||||
|
||||
def _make_env(self):
|
||||
if not _has_gym:
|
||||
@@ -79,39 +81,39 @@ class PushtEnv(AbstractEnv):
|
||||
return obs
|
||||
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
td = tensordict
|
||||
if td is None or td.is_empty():
|
||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||
self._current_seed += 1
|
||||
self.set_seed(self._current_seed)
|
||||
raw_obs = self._env.reset()
|
||||
assert self._current_seed == self._env._seed
|
||||
if tensordict is not None and not self._reset_warning_issued:
|
||||
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
|
||||
self._reset_warning_issued = True
|
||||
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||
self._current_seed += 1
|
||||
self.set_seed(self._current_seed)
|
||||
raw_obs = self._env.reset()
|
||||
assert self._current_seed == self._env._seed
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue = deque(
|
||||
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue = deque(
|
||||
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue = deque(
|
||||
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue = deque(
|
||||
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
|
||||
self.call_rendering_hooks()
|
||||
return td
|
||||
|
||||
Reference in New Issue
Block a user