diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index 66ea46e7..770ea392 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -2,6 +2,7 @@ from torchrl.envs.transforms import StepCounter, TransformedEnv
 
 from lerobot.common.envs.pusht import PushtEnv
 from lerobot.common.envs.simxarm import SimxarmEnv
+from lerobot.common.envs.transforms import Prod
 
 
 def make_env(cfg):
@@ -25,6 +26,10 @@ def make_env(cfg):
     # limit rollout to max_steps
     env = TransformedEnv(env, StepCounter(max_steps=cfg.episode_length))
+    if cfg.env == "pusht":
+        # to ensure pusht is in [0,255] like simxarm
+        env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0))
+
+    return env
diff --git a/lerobot/common/envs/pusht.py b/lerobot/common/envs/pusht.py
index 1f505c64..6b9bbb51 100644
--- a/lerobot/common/envs/pusht.py
+++ b/lerobot/common/envs/pusht.py
@@ -48,23 +48,16 @@ class PushtEnv(EnvBase):
         if not _has_gym:
             raise ImportError("Cannot import gym.")
 
-        import gym
         from diffusion_policy.env.pusht.pusht_env import PushTEnv
+
+        if not from_pixels:
+            raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
         from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
         from gym.wrappers import TimeLimit
 
         self._env = PushTImageEnv(render_size=self.image_size)
         self._env = TimeLimit(self._env, self.max_episode_length)
 
-        # MAX_NUM_ACTIONS = 4
-        # num_actions = len(TASKS[self.task]["action_space"])
-        # self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
-        # self._action_padding = np.zeros(
-        #     (MAX_NUM_ACTIONS - num_actions), dtype=np.float32
-        # )
-        # if "w" not in TASKS[self.task]["action_space"]:
-        #     self._action_padding[-1] = 1.0
-
         self._make_spec()
         self.set_seed(seed)
 
diff --git a/lerobot/common/envs/transforms.py b/lerobot/common/envs/transforms.py
new file mode 100644
index 00000000..f1e6657b
--- /dev/null
+++ b/lerobot/common/envs/transforms.py
@@ -0,0 +1,39 @@
+from typing import Sequence
+
+from tensordict.utils import NestedKey
+from torchrl.envs.transforms import ObservationTransform
+
+
+class Prod(ObservationTransform):
+    """Multiply selected observation entries by a constant factor.
+
+    Args:
+        in_keys: keys of the tensordict entries, and of the matching
+            observation spec entries, to scale.
+        prod: factor each selected entry is multiplied by.
+    """
+
+    # NOTE(review): only `_call` is overridden here; confirm that the installed
+    # torchrl version also applies it to reset observations.
+
+    def __init__(self, in_keys: Sequence[NestedKey], prod: float):
+        # Let the base class register the keys rather than overwriting them afterwards.
+        super().__init__(in_keys=in_keys)
+        self.prod = prod
+
+    def _call(self, td):
+        """Scale every `in_keys` entry of `td` by `prod` and return `td`."""
+        for key in self.in_keys:
+            # Out-of-place multiply, so the tensor handed over by the wrapped
+            # env is never mutated.
+            td[key] = td[key] * self.prod
+        return td
+
+    def transform_observation_spec(self, obs_spec):
+        """Scale the bounds of every `in_keys` spec to match `_call`."""
+        for key in self.in_keys:
+            # Both bounds are scaled; assumes a bounded spec and prod > 0
+            # (a negative factor would swap low and high).
+            obs_spec[key].space.low *= self.prod
+            obs_spec[key].space.high *= self.prod
+        return obs_spec
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 391a8e4d..a85e298e 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -50,6 +50,8 @@ def train(cfg: dict):
     offline_buffer = make_offline_buffer(cfg)
 
     if cfg.balanced_sampling:
+        num_traj_per_batch = cfg.batch_size
+
         online_sampler = PrioritizedSliceSampler(
             max_capacity=100_000,
             alpha=cfg.per_alpha,
diff --git a/test/test_envs.py b/test/test_envs.py
index e9fffef1..b5c730e3 100644
--- a/test/test_envs.py
+++ b/test/test_envs.py
@@ -2,6 +2,7 @@ import pytest
 from tensordict import TensorDict
 from torchrl.envs.utils import check_env_specs, step_mdp
 
+from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.pusht import PushtEnv
 from lerobot.common.envs.simxarm import SimxarmEnv
 
@@ -54,7 +55,7 @@ def test_simxarm(task, from_pixels, pixels_only):
         pixels_only=pixels_only,
         image_size=84 if from_pixels else None,
     )
-    print_spec_rollout(env)
+    # print_spec_rollout(env)
     check_env_specs(env)
 
 
@@ -70,5 +71,26 @@ def test_pusht(from_pixels, pixels_only):
         pixels_only=pixels_only,
         image_size=96 if from_pixels else None,
     )
-    print_spec_rollout(env)
+    # print_spec_rollout(env)
+    check_env_specs(env)
+
+
+@pytest.mark.parametrize(
+    "config_name",
+    [
+        "default",
+        "pusht",
+    ],
+)
+def test_factory(config_name):
+    import hydra
+    from hydra import compose, initialize
+
+    config_path = "../lerobot/configs"
+    hydra.core.global_hydra.GlobalHydra.instance().clear()
+    initialize(config_path=config_path)
+    cfg = compose(config_name=config_name)
+
+    env = make_env(cfg)
+    check_env_specs(env)