Add AbstractEnv, Refactor AlohaEnv, Add rendering_hook in env, Minor modifications, (TODO: Refactor Pusht and Simxarm)

This commit is contained in:
Cadene
2024-03-10 22:00:48 +00:00
parent b49f7b70e2
commit 7bf36cd413
11 changed files with 131 additions and 59 deletions

View File

@@ -15,8 +15,8 @@ from torchrl.data.tensor_specs import (
DiscreteTensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import EnvBase
from lerobot.common.envs.abstract import AbstractEnv
from lerobot.common.envs.aloha.constants import (
ACTIONS,
ASSETS_DIR,
@@ -28,14 +28,13 @@ from lerobot.common.envs.aloha.tasks.sim_end_effector import (
InsertionEndEffectorTask,
TransferCubeEndEffectorTask,
)
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
from lerobot.common.utils import set_seed
from .utils import sample_box_pose, sample_insertion_pose
_has_gym = importlib.util.find_spec("gym") is not None
class AlohaEnv(EnvBase):
class AlohaEnv(AbstractEnv):
def __init__(
self,
task,
@@ -48,20 +47,17 @@ class AlohaEnv(EnvBase):
num_prev_obs=1,
num_prev_action=0,
):
super().__init__(device=device, batch_size=[])
self.task = task
self.frame_skip = frame_skip
self.from_pixels = from_pixels
self.pixels_only = pixels_only
self.image_size = image_size
self.num_prev_obs = num_prev_obs
self.num_prev_action = num_prev_action
if pixels_only:
assert from_pixels
if from_pixels:
assert image_size
super().__init__(
task=task,
frame_skip=frame_skip,
from_pixels=from_pixels,
pixels_only=pixels_only,
image_size=image_size,
seed=seed,
device=device,
num_prev_obs=num_prev_obs,
num_prev_action=num_prev_action,
)
if not _has_gym:
raise ImportError("Cannot import gym.")
@@ -70,16 +66,6 @@ class AlohaEnv(EnvBase):
self._env = self._make_env_task(task)
self._make_spec()
self._current_seed = self.set_seed(seed)
if self.num_prev_obs > 0:
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
if self.num_prev_action > 0:
raise NotImplementedError()
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
def render(self, mode="rgb_array", width=640, height=480):
# TODO(rcadene): render and visualizer several cameras (e.g. angle, front_close)
image = self._env.physics.render(height=height, width=width, camera_id="top")
@@ -172,6 +158,8 @@ class AlohaEnv(EnvBase):
)
else:
raise NotImplementedError()
self.call_rendering_hooks()
return td
def _step(self, tensordict: TensorDict):
@@ -207,6 +195,8 @@ class AlohaEnv(EnvBase):
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
self.call_rendering_hooks()
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),