diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py new file mode 100644 index 00000000..f5535804 --- /dev/null +++ b/lerobot/common/datasets/factory.py @@ -0,0 +1,47 @@ +import torch + +from lerobot.common.datasets.pusht import PushtExperienceReplay +from lerobot.common.datasets.simxarm import SimxarmExperienceReplay +from rl.torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler + + +def make_offline_buffer(cfg): + + num_traj_per_batch = cfg.batch_size # // cfg.horizon + # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size. + # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size. + sampler = PrioritizedSliceSampler( + max_capacity=100_000, + alpha=cfg.per_alpha, + beta=cfg.per_beta, + num_slices=num_traj_per_batch, + strict_length=False, + ) + + if cfg.env == "simxarm": + # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here + offline_buffer = SimxarmExperienceReplay( + f"xarm_{cfg.task}_medium", + # download="force", + download=True, + streaming=False, + root="data", + sampler=sampler, + ) + elif cfg.env == "pusht": + offline_buffer = PushtExperienceReplay( + f"xarm_{cfg.task}_medium", + # download="force", + download=True, + streaming=False, + root="data", + sampler=sampler, + ) + else: + raise ValueError(cfg.env) + + num_steps = len(offline_buffer) + index = torch.arange(0, num_steps, 1) + sampler.extend(index) + + return offline_buffer diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py new file mode 100644 index 00000000..76848daf --- /dev/null +++ b/lerobot/common/datasets/pusht.py @@ -0,0 +1,192 @@ +import os +import pickle +from pathlib import Path +from typing import Any, Callable, Dict, Tuple + +import torch +import torchrl +import tqdm +from tensordict import TensorDict +from torchrl.data.datasets.utils import _get_root_dir +from torchrl.data.replay_buffers.replay_buffers import ( + TensorDictPrioritizedReplayBuffer, + TensorDictReplayBuffer, +) +from torchrl.data.replay_buffers.samplers import ( + Sampler, + SliceSampler, + SliceSamplerWithoutReplacement, +) +from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id +from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer + + +class PushtExperienceReplay(TensorDictReplayBuffer): + + available_datasets = [ + "xarm_lift_medium", + ] + + def __init__( + self, + dataset_id, + batch_size: int = None, + *, + shuffle: bool = True, + num_slices: int = None, + slice_len: int = None, + pad: float = None, + replacement: bool = None, + streaming: bool = False, + root: Path = None, + download: bool = False, + sampler: Sampler = None, + writer: Writer = None, + collate_fn: Callable = None, + pin_memory: bool = False, + prefetch: int = None, + transform: "torchrl.envs.Transform" = None, # noqa-F821 + split_trajs: bool = False, + strict_length: bool = True, + ): + # TODO + raise NotImplementedError() + self.download = download + if streaming: + raise NotImplementedError + self.streaming = streaming + self.dataset_id = dataset_id + self.split_trajs = split_trajs + self.shuffle = shuffle + self.num_slices = num_slices + self.slice_len = slice_len + self.pad = pad + + self.strict_length = strict_length + if (self.num_slices is not None) and (self.slice_len is not None): + raise ValueError("num_slices or slice_len can be not None, but not both.") + if split_trajs: + raise NotImplementedError + + if root is None: + root = 
_get_root_dir("simxarm") + os.makedirs(root, exist_ok=True) + self.root = Path(root) + if self.download == "force" or (self.download and not self._is_downloaded()): + storage = self._download_and_preproc() + else: + storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id)) + + if num_slices is not None or slice_len is not None: + if sampler is not None: + raise ValueError( + "`num_slices` and `slice_len` are exclusive with the `sampler` argument." + ) + + if replacement: + if not self.shuffle: + raise RuntimeError( + "shuffle=False can only be used when replacement=False." + ) + sampler = SliceSampler( + num_slices=num_slices, + slice_len=slice_len, + strict_length=strict_length, + ) + else: + sampler = SliceSamplerWithoutReplacement( + num_slices=num_slices, + slice_len=slice_len, + strict_length=strict_length, + shuffle=self.shuffle, + ) + + if writer is None: + writer = ImmutableDatasetWriter() + if collate_fn is None: + collate_fn = _collate_id + + super().__init__( + storage=storage, + sampler=sampler, + writer=writer, + collate_fn=collate_fn, + pin_memory=pin_memory, + prefetch=prefetch, + batch_size=batch_size, + transform=transform, + ) + + @property + def data_path_root(self): + if self.streaming: + return None + return self.root / self.dataset_id + + def _is_downloaded(self): + return os.path.exists(self.data_path_root) + + def _download_and_preproc(self): + # download + # TODO(rcadene) + + # load + dataset_dir = Path("data") / self.dataset_id + dataset_path = dataset_dir / f"buffer.pkl" + print(f"Using offline dataset '{dataset_path}'") + with open(dataset_path, "rb") as f: + dataset_dict = pickle.load(f) + + total_frames = dataset_dict["actions"].shape[0] + + idx0 = 0 + idx1 = 0 + episode_id = 0 + for i in tqdm.tqdm(range(total_frames)): + idx1 += 1 + + if not dataset_dict["dones"][i]: + continue + + num_frames = idx1 - idx0 + + image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1]) + state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1]) + next_image = torch.tensor( + dataset_dict["next_observations"]["rgb"][idx0:idx1] + ) + next_state = torch.tensor( + dataset_dict["next_observations"]["state"][idx0:idx1] + ) + next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1]) + next_done = torch.tensor(dataset_dict["dones"][idx0:idx1]) + + episode = TensorDict( + { + ("observation", "image"): image, + ("observation", "state"): state, + "action": torch.tensor(dataset_dict["actions"][idx0:idx1]), + "episode": torch.tensor([episode_id] * num_frames, dtype=torch.int), + "frame_id": torch.arange(0, num_frames, 1), + ("next", "observation", "image"): next_image, + ("next", "observation", "state"): next_state, + ("next", "observation", "reward"): next_reward, + ("next", "observation", "done"): next_done, + }, + batch_size=num_frames, + ) + + if episode_id == 0: + # hack to initialize tensordict data structure to store episodes + td_data = ( + episode[0] + .expand(total_frames) + .memmap_like(self.root / self.dataset_id) + ) + + td_data[idx0:idx1] = episode + + episode_id += 1 + idx0 = idx1 + + return TensorStorage(td_data.lock_()) diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 9f491f71..66ea46e7 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -1,17 +1,26 @@ from torchrl.envs.transforms import StepCounter, TransformedEnv +from lerobot.common.envs.pusht import PushtEnv from lerobot.common.envs.simxarm import SimxarmEnv def make_env(cfg): - assert cfg.env == "simxarm" - env 
= SimxarmEnv( - task=cfg.task, - frame_skip=cfg.action_repeat, - from_pixels=cfg.from_pixels, - pixels_only=cfg.pixels_only, - image_size=cfg.image_size, - ) + kwargs = { + "frame_skip": cfg.action_repeat, + "from_pixels": cfg.from_pixels, + "pixels_only": cfg.pixels_only, + "image_size": cfg.image_size, + } + + if cfg.env == "simxarm": + kwargs["task"] = cfg.task + clsfunc = SimxarmEnv + elif cfg.env == "pusht": + clsfunc = PushtEnv + else: + raise ValueError(cfg.env) + + env = clsfunc(**kwargs) # limit rollout to max_steps env = TransformedEnv(env, StepCounter(max_steps=cfg.episode_length)) diff --git a/lerobot/common/envs/pusht.py b/lerobot/common/envs/pusht.py new file mode 100644 index 00000000..1f505c64 --- /dev/null +++ b/lerobot/common/envs/pusht.py @@ -0,0 +1,193 @@ +import importlib +from typing import Optional + +import numpy as np +import torch +from tensordict import TensorDict +from torchrl.data.tensor_specs import ( + BoundedTensorSpec, + CompositeSpec, + DiscreteTensorSpec, + UnboundedContinuousTensorSpec, +) +from torchrl.envs import EnvBase +from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform + +from lerobot.common.utils import set_seed + +_has_gym = importlib.util.find_spec("gym") is not None +_has_diffpolicy = importlib.util.find_spec("diffusion_policy") is not None and _has_gym + + +class PushtEnv(EnvBase): + + def __init__( + self, + frame_skip: int = 1, + from_pixels: bool = False, + pixels_only: bool = False, + image_size=None, + seed=1337, + device="cpu", + max_episode_length=25, # TODO: verify + ): + super().__init__(device=device, batch_size=[]) + self.frame_skip = frame_skip + self.from_pixels = from_pixels + self.pixels_only = pixels_only + self.image_size = image_size + self.max_episode_length = max_episode_length + + if pixels_only: + assert from_pixels + if from_pixels: + assert image_size + + if not _has_diffpolicy: + raise ImportError("Cannot import diffusion_policy.") + if not _has_gym: + raise ImportError("Cannot import gym.") + + import gym + from diffusion_policy.env.pusht.pusht_env import PushTEnv + from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv + from gym.wrappers import TimeLimit + + self._env = PushTImageEnv(render_size=self.image_size) + self._env = TimeLimit(self._env, self.max_episode_length) + + # MAX_NUM_ACTIONS = 4 + # num_actions = len(TASKS[self.task]["action_space"]) + # self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,)) + # self._action_padding = np.zeros( + # (MAX_NUM_ACTIONS - num_actions), dtype=np.float32 + # ) + # if "w" not in TASKS[self.task]["action_space"]: + # self._action_padding[-1] = 1.0 + + self._make_spec() + self.set_seed(seed) + + def render(self, mode="rgb_array", width=384, height=384): + if width != height: + raise NotImplementedError() + tmp = self._env.render_size + self._env.render_size = width + out = self._env.render(mode) + self._env.render_size = tmp + return out + + def _format_raw_obs(self, raw_obs): + if self.from_pixels: + obs = {"image": torch.from_numpy(raw_obs["image"])} + + if not self.pixels_only: + obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type( + torch.float32 + ) + else: + # TODO: + obs = { + "state": torch.from_numpy(raw_obs["observation"]).type(torch.float32) + } + + obs = TensorDict(obs, batch_size=[]) + return obs + + def _reset(self, tensordict: Optional[TensorDict] = None): + td = tensordict + if td is None or td.is_empty(): + raw_obs = self._env.reset() + + td = TensorDict( + { + "observation": 
self._format_raw_obs(raw_obs), + "done": torch.tensor([False], dtype=torch.bool), + }, + batch_size=[], + ) + else: + raise NotImplementedError() + return td + + def _step(self, tensordict: TensorDict): + td = tensordict + action = td["action"].numpy() + # step expects shape=(4,) so we pad if necessary + # TODO(rcadene): add info["is_success"] and info["success"] ? + sum_reward = 0 + for t in range(self.frame_skip): + raw_obs, reward, done, info = self._env.step(action) + sum_reward += reward + + td = TensorDict( + { + "observation": self._format_raw_obs(raw_obs), + "reward": torch.tensor([sum_reward], dtype=torch.float32), + # succes and done are true when coverage > self.success_threshold in env + "done": torch.tensor([done], dtype=torch.bool), + "success": torch.tensor([done], dtype=torch.bool), + }, + batch_size=[], + ) + return td + + def _make_spec(self): + obs = {} + if self.from_pixels: + obs["image"] = BoundedTensorSpec( + low=0, + high=1, + shape=(3, self.image_size, self.image_size), + dtype=torch.float32, + device=self.device, + ) + if not self.pixels_only: + obs["state"] = BoundedTensorSpec( + low=0, + high=512, + shape=self._env.observation_space["agent_pos"].shape, + dtype=torch.float32, + device=self.device, + ) + else: + # TODO(rcadene): add observation_space achieved_goal and desired_goal? + obs["state"] = UnboundedContinuousTensorSpec( + # TODO: + shape=self._env.observation_space["observation"].shape, + dtype=torch.float32, + device=self.device, + ) + self.observation_spec = CompositeSpec({"observation": obs}) + + self.action_spec = _gym_to_torchrl_spec_transform( + self._env.action_space, + device=self.device, + ) + + self.reward_spec = UnboundedContinuousTensorSpec( + shape=(1,), + dtype=torch.float32, + device=self.device, + ) + + self.done_spec = CompositeSpec( + { + "done": DiscreteTensorSpec( + 2, + shape=(1,), + dtype=torch.bool, + device=self.device, + ), + "success": DiscreteTensorSpec( + 2, + shape=(1,), + dtype=torch.bool, + device=self.device, + ), + } + ) + + def _set_seed(self, seed: Optional[int]): + set_seed(seed) + self._env.seed(seed) diff --git a/lerobot/common/envs/simxarm.py b/lerobot/common/envs/simxarm.py index 8d955072..fac82ae6 100644 --- a/lerobot/common/envs/simxarm.py +++ b/lerobot/common/envs/simxarm.py @@ -167,18 +167,21 @@ class SimxarmEnv(EnvBase): device=self.device, ) - self.done_spec = DiscreteTensorSpec( - 2, - shape=(1,), - dtype=torch.bool, - device=self.device, - ) - - self.success_spec = DiscreteTensorSpec( - 2, - shape=(1,), - dtype=torch.bool, - device=self.device, + self.done_spec = CompositeSpec( + { + "done": DiscreteTensorSpec( + 2, + shape=(1,), + dtype=torch.bool, + device=self.device, + ), + "success": DiscreteTensorSpec( + 2, + shape=(1,), + dtype=torch.bool, + device=self.device, + ), + } ) def _set_seed(self, seed: Optional[int]): diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index f1a014aa..52b3a551 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -29,7 +29,7 @@ train_steps: 50000 # pixels frame_stack: 1 num_channels: 32 -img_size: 84 +img_size: ${image_size} # TDMPC @@ -82,6 +82,8 @@ A_scaling: 3.0 # offline->online offline_steps: 25000 # ${train_steps}/2 pretrained_model_path: "" +# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt" +# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt" balanced_sampling: true demo_schedule: 0.5 diff --git a/lerobot/configs/pusht.yaml 
b/lerobot/configs/pusht.yaml new file mode 100644 index 00000000..cbd7c63a --- /dev/null +++ b/lerobot/configs/pusht.yaml @@ -0,0 +1,12 @@ +defaults: + - default + +hydra: + job: + name: pusht + +# env +env: pusht +image_size: 96 +frame_skip: 1 + diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index e4c5b9b1..130316d2 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -71,24 +71,35 @@ def eval(cfg: dict): print(colored("Log dir:", "yellow", attrs=["bold"]), cfg.log_dir) env = make_env(cfg) - policy = TDMPC(cfg) - # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt" - ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt" - policy.load(ckpt_path) - policy = TensorDictModule( - policy, - in_keys=["observation", "step_count"], - out_keys=["action"], - ) + if cfg.pretrained_model_path: + policy = TDMPC(cfg) + ckpt_path = ( + "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt" + ) + if "offline" in cfg.pretrained_model_path: + policy.step = 25000 + elif "final" in cfg.pretrained_model_path: + policy.step = 100000 + else: + raise NotImplementedError() + policy.load(ckpt_path) + + policy = TensorDictModule( + policy, + in_keys=["observation", "step_count"], + out_keys=["action"], + ) + else: + # when policy is None, rollout a random policy + policy = None - # policy can be None to rollout a random policy metrics = eval_policy( env, policy=policy, num_episodes=20, - save_video=False, - video_dir=Path("tmp/2023_01_29_xarm_lift_final"), + save_video=True, + video_dir=Path("tmp/2023_02_19_pusht"), ) print(metrics) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 2d33d7b0..391a8e4d 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -10,6 +10,7 @@ from torchrl.data.datasets.d4rl import D4RLExperienceReplay from torchrl.data.datasets.openx import OpenXExperienceReplay from torchrl.data.replay_buffers import PrioritizedSliceSampler +from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.datasets.simxarm import SimxarmExperienceReplay from lerobot.common.envs.factory import make_env from lerobot.common.logger import Logger @@ -26,11 +27,17 @@ def train(cfg: dict): env = make_env(cfg) policy = TDMPC(cfg) - # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt" - # policy.step = 25000 - # # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt" - # # policy.step = 100000 - # policy.load(ckpt_path) + if cfg.pretrained_model_path: + ckpt_path = ( + "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt" + ) + if "offline" in cfg.pretrained_model_path: + policy.step = 25000 + elif "final" in cfg.pretrained_model_path: + policy.step = 100000 + else: + raise NotImplementedError() + policy.load(ckpt_path) td_policy = TensorDictModule( policy, @@ -40,32 +47,7 @@ def train(cfg: dict): # initialize offline dataset - dataset_id = f"xarm_{cfg.task}_medium" - - num_traj_per_batch = cfg.batch_size # // cfg.horizon - # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size. - # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size. 
- sampler = PrioritizedSliceSampler( - max_capacity=100_000, - alpha=cfg.per_alpha, - beta=cfg.per_beta, - num_slices=num_traj_per_batch, - strict_length=False, - ) - - # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here - offline_buffer = SimxarmExperienceReplay( - dataset_id, - # download="force", - download=True, - streaming=False, - root="data", - sampler=sampler, - ) - - num_steps = len(offline_buffer) - index = torch.arange(0, num_steps, 1) - sampler.extend(index) + offline_buffer = make_offline_buffer(cfg) if cfg.balanced_sampling: online_sampler = PrioritizedSliceSampler( diff --git a/test/test_envs.py b/test/test_envs.py index 49f547b6..e9fffef1 100644 --- a/test/test_envs.py +++ b/test/test_envs.py @@ -2,7 +2,35 @@ import pytest from tensordict import TensorDict from torchrl.envs.utils import check_env_specs, step_mdp -from lerobot.common.envs import SimxarmEnv +from lerobot.common.envs.pusht import PushtEnv +from lerobot.common.envs.simxarm import SimxarmEnv + + +def print_spec_rollout(env): + print("observation_spec:", env.observation_spec) + print("action_spec:", env.action_spec) + print("reward_spec:", env.reward_spec) + print("done_spec:", env.done_spec) + + td = env.reset() + print("reset tensordict", td) + + td = env.rand_step(td) + print("random step tensordict", td) + + def simple_rollout(steps=100): + # preallocate: + data = TensorDict({}, [steps]) + # reset + _data = env.reset() + for i in range(steps): + _data["action"] = env.action_spec.rand() + _data = env.step(_data) + data[i] = _data + _data = step_mdp(_data, keep_other=True) + return data + + print("data from rollout:", simple_rollout(100)) @pytest.mark.parametrize( @@ -26,30 +54,21 @@ def test_simxarm(task, from_pixels, pixels_only): pixels_only=pixels_only, image_size=84 if from_pixels else None, ) + print_spec_rollout(env) check_env_specs(env) - print("observation_spec:", env.observation_spec) - print("action_spec:", env.action_spec) - print("reward_spec:", env.reward_spec) - print("done_spec:", env.done_spec) - print("success_spec:", env.success_spec) - td = env.reset() - print("reset tensordict", td) - - td = env.rand_step(td) - print("random step tensordict", td) - - def simple_rollout(steps=100): - # preallocate: - data = TensorDict({}, [steps]) - # reset - _data = env.reset() - for i in range(steps): - _data["action"] = env.action_spec.rand() - _data = env.step(_data) - data[i] = _data - _data = step_mdp(_data, keep_other=True) - return data - - print("data from rollout:", simple_rollout(100)) +@pytest.mark.parametrize( + "from_pixels,pixels_only", + [ + (True, False), + ], +) +def test_pusht(from_pixels, pixels_only): + env = PushtEnv( + from_pixels=from_pixels, + pixels_only=pixels_only, + image_size=96 if from_pixels else None, + ) + print_spec_rollout(env) + check_env_specs(env)
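
Below is a minimal usage sketch (not part of the patch) showing how the pieces added above fit together. It assumes the new modules are importable from this branch, that the `gym` and `diffusion_policy` dependencies required by `PushtEnv` are installed, and that config fields inherited from `default.yaml` keep their usual values; the `episode_length=300` and the use of OmegaConf to stand in for the Hydra-composed config are only illustrative assumptions.

# Usage sketch, mirroring the flow of the new test/test_envs.py::test_pusht.
from torchrl.envs.utils import check_env_specs

from lerobot.common.envs.pusht import PushtEnv

env = PushtEnv(from_pixels=True, pixels_only=False, image_size=96)
check_env_specs(env)  # verify observation/action/reward/done specs against a short rollout

td = env.reset()                       # TensorDict with ("observation", "image"/"state") and "done"
td["action"] = env.action_spec.rand()  # random action sampled from the gym-derived action spec
td = env.step(td)                      # returns the td with a ("next", ...) entry holding observation, reward, done, success

# The same environment can also be built through the factory with a config shaped like
# configs/pusht.yaml; OmegaConf is used here only to fake the Hydra-composed cfg.
from omegaconf import OmegaConf

from lerobot.common.envs.factory import make_env

cfg = OmegaConf.create(
    {
        "env": "pusht",
        "action_repeat": 1,
        "from_pixels": True,
        "pixels_only": False,
        "image_size": 96,
        "episode_length": 300,  # assumed value; the real one comes from default.yaml
    }
)
env = make_env(cfg)  # TransformedEnv(PushtEnv(...), StepCounter(max_steps=cfg.episode_length))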