Merge remote-tracking branch 'Cadene/user/rcadene/2024_03_31_remove_torchrl' into refactor_act_remove_torchrl

This commit is contained in:
Alexander Soare
2024-04-08 09:25:45 +01:00
19 changed files with 253 additions and 242 deletions

View File

@@ -164,19 +164,11 @@ def make_dataset(
]
)
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
# TODO(rcadene): implement delta_timestamps in config
delta_timestamps = {
"observation.image": [-0.1, 0],
"observation.state": [-0.1, 0],
"action": [-0.1] + [i / clsfunc.fps for i in range(15)],
}
else:
delta_timestamps = {
"observation.images.top": [0],
"observation.state": [0],
"action": [i / clsfunc.fps for i in range(cfg.policy.horizon)],
}
delta_timestamps = cfg.policy.get("delta_timestamps")
if delta_timestamps is not None:
for key in delta_timestamps:
if isinstance(delta_timestamps[key], str):
delta_timestamps[key] = eval(delta_timestamps[key])
dataset = clsfunc(
dataset_id=cfg.dataset_id,

View File

@@ -6,9 +6,9 @@ import pygame
import pymunk
import torch
import tqdm
from gym_pusht.envs.pusht import pymunk_to_shapely
from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
# as defined in env

View File

@@ -4,6 +4,9 @@ register(
id="gym_aloha/AlohaInsertion-v0",
entry_point="lerobot.common.envs.aloha.env:AlohaEnv",
max_episode_steps=300,
# Even after seeding, the rendered observations are slightly different,
# so we set `nondeterministic=True` to pass `check_env` tests
nondeterministic=True,
kwargs={"obs_type": "state", "task": "insertion"},
)
@@ -11,5 +14,8 @@ register(
id="gym_aloha/AlohaTransferCube-v0",
entry_point="lerobot.common.envs.aloha.env:AlohaEnv",
max_episode_steps=300,
# Even after seeding, the rendered observations are slightly different,
# so we set `nondeterministic=True` to pass `check_env` tests
nondeterministic=True,
kwargs={"obs_type": "state", "task": "transfer_cube"},
)

View File

@@ -16,7 +16,6 @@ from lerobot.common.envs.aloha.tasks.sim_end_effector import (
TransferCubeEndEffectorTask,
)
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
from lerobot.common.utils import set_global_seed
class AlohaEnv(gym.Env):
@@ -49,21 +48,33 @@ class AlohaEnv(gym.Env):
dtype=np.float64,
)
elif self.obs_type == "pixels":
self.observation_space = spaces.Box(
low=0, high=255, shape=(self.observation_height, self.observation_width, 3), dtype=np.uint8
)
elif self.obs_type == "pixels_agent_pos":
self.observation_space = spaces.Dict(
{
"pixels": spaces.Box(
"top": spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
)
}
)
elif self.obs_type == "pixels_agent_pos":
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(
{
"top": spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
)
}
),
"agent_pos": spaces.Box(
low=np.array([-1] * len(JOINTS)), # ???
high=np.array([1] * len(JOINTS)), # ???
low=-np.inf,
high=np.inf,
shape=(len(JOINTS),),
dtype=np.float64,
),
}
@@ -89,21 +100,21 @@ class AlohaEnv(gym.Env):
if "transfer_cube" in task_name:
xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml"
physics = mujoco.Physics.from_xml_path(str(xml_path))
task = TransferCubeTask(random=False)
task = TransferCubeTask()
elif "insertion" in task_name:
xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml"
physics = mujoco.Physics.from_xml_path(str(xml_path))
task = InsertionTask(random=False)
task = InsertionTask()
elif "end_effector_transfer_cube" in task_name:
raise NotImplementedError()
xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml"
physics = mujoco.Physics.from_xml_path(str(xml_path))
task = TransferCubeEndEffectorTask(random=False)
task = TransferCubeEndEffectorTask()
elif "end_effector_insertion" in task_name:
raise NotImplementedError()
xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml"
physics = mujoco.Physics.from_xml_path(str(xml_path))
task = InsertionEndEffectorTask(random=False)
task = InsertionEndEffectorTask()
else:
raise NotImplementedError(task_name)
@@ -116,10 +127,10 @@ class AlohaEnv(gym.Env):
if self.obs_type == "state":
raise NotImplementedError()
elif self.obs_type == "pixels":
obs = raw_obs["images"]["top"].copy()
obs = {"top": raw_obs["images"]["top"].copy()}
elif self.obs_type == "pixels_agent_pos":
obs = {
"pixels": raw_obs["images"]["top"].copy(),
"pixels": {"top": raw_obs["images"]["top"].copy()},
"agent_pos": raw_obs["qpos"],
}
return obs
@@ -129,14 +140,14 @@ class AlohaEnv(gym.Env):
# TODO(rcadene): how to seed the env?
if seed is not None:
set_global_seed(seed)
self._env.task.random.seed(seed)
self._env.task._random = np.random.RandomState(seed)
# TODO(rcadene): do not use global variable for this
if "transfer_cube" in self.task:
BOX_POSE[0] = sample_box_pose() # used in sim reset
BOX_POSE[0] = sample_box_pose(seed) # used in sim reset
elif "insertion" in self.task:
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
BOX_POSE[0] = np.concatenate(sample_insertion_pose(seed)) # used in sim reset
else:
raise ValueError(self.task)

View File

@@ -1,26 +1,30 @@
import numpy as np
def sample_box_pose():
def sample_box_pose(seed=None):
x_range = [0.0, 0.2]
y_range = [0.4, 0.6]
z_range = [0.05, 0.05]
rng = np.random.RandomState(seed)
ranges = np.vstack([x_range, y_range, z_range])
cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
cube_position = rng.uniform(ranges[:, 0], ranges[:, 1])
cube_quat = np.array([1, 0, 0, 0])
return np.concatenate([cube_position, cube_quat])
def sample_insertion_pose():
def sample_insertion_pose(seed=None):
# Peg
x_range = [0.1, 0.2]
y_range = [0.4, 0.6]
z_range = [0.05, 0.05]
rng = np.random.RandomState(seed)
ranges = np.vstack([x_range, y_range, z_range])
peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
peg_position = rng.uniform(ranges[:, 0], ranges[:, 1])
peg_quat = np.array([1, 0, 0, 0])
peg_pose = np.concatenate([peg_position, peg_quat])
@@ -31,7 +35,7 @@ def sample_insertion_pose():
z_range = [0.05, 0.05]
ranges = np.vstack([x_range, y_range, z_range])
socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
socket_position = rng.uniform(ranges[:, 0], ranges[:, 1])
socket_quat = np.array([1, 0, 0, 0])
socket_pose = np.concatenate([socket_position, socket_quat])

View File

@@ -30,7 +30,7 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
**kwargs,
)
elif cfg.env.name == "aloha":
from lerobot.common.envs import aloha as gym_aloha # noqa: F401
import gym_aloha # noqa: F401
kwargs["task"] = cfg.env.task

View File

@@ -6,12 +6,20 @@ from lerobot.common.transforms import apply_inverse_transform
def preprocess_observation(observation, transform=None):
# map to expected inputs for the policy
obs = {
"observation.image": torch.from_numpy(observation["pixels"]).float(),
"observation.state": torch.from_numpy(observation["agent_pos"]).float(),
}
# convert to (b c h w) torch format
obs["observation.image"] = einops.rearrange(obs["observation.image"], "b h w c -> b c h w")
obs = {}
if isinstance(observation["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observation["pixels"].items()}
else:
imgs = {"observation.image": observation["pixels"]}
for imgkey, img in imgs.items():
img = torch.from_numpy(img).float()
# convert to (b c h w) torch format
img = einops.rearrange(img, "b h w c -> b c h w")
obs[imgkey] = img
obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
# apply same transforms as in training
if transform is not None:

View File

@@ -1,11 +1,10 @@
def make_policy(cfg):
if cfg.policy.name not in ["diffusion", "act"] and cfg.rollout_batch_size > 1:
raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
if cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
policy = TDMPCPolicy(cfg.policy, cfg.device)
policy = TDMPCPolicy(
cfg.policy, n_obs_steps=cfg.n_obs_steps, n_action_steps=cfg.n_action_steps, device=cfg.device
)
elif cfg.policy.name == "diffusion":
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
@@ -17,14 +16,18 @@ def make_policy(cfg):
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
n_obs_steps=cfg.n_obs_steps,
n_action_steps=cfg.n_action_steps,
**cfg.policy,
)
elif cfg.policy.name == "act":
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
policy = ActionChunkingTransformerPolicy(
cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
cfg.policy,
cfg.device,
n_obs_steps=cfg.n_obs_steps,
n_action_steps=cfg.n_action_steps,
)
else:
raise ValueError(cfg.policy.name)

View File

@@ -154,8 +154,14 @@ class TDMPCPolicy(nn.Module):
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
if self.n_obs_steps == 1:
# hack to remove the time dimension
for key in batch:
assert batch[key].shape[1] == 1
batch[key] = batch[key][:, 0]
actions = []
batch_size = batch["observation.image."].shape[0]
batch_size = batch["observation.image"].shape[0]
for i in range(batch_size):
obs = {
"rgb": batch["observation.image"][[i]],
@@ -166,6 +172,10 @@ class TDMPCPolicy(nn.Module):
actions.append(action)
action = torch.stack(actions)
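# self.act returns an action for 1 timestep only, so we copy it over `n_action_steps` times
# NOTE(review): the statement below is `if i in range(...)`, which appends the action at most once
# (using the leftover batch loop index `i`); a `for _ in range(...)` loop looks intended — confirm.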
if i in range(self.n_action_steps):
self._queues["action"].append(action)
action = self._queues["action"].popleft()
return action
@@ -410,22 +420,45 @@ class TDMPCPolicy(nn.Module):
# idxs = torch.cat([idxs, demo_idxs])
# weights = torch.cat([weights, demo_weights])
# TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
# instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
# batch size b = 256, time/horizon t = 5
# b t ... -> t b ...
for key in batch:
if batch[key].ndim > 1:
batch[key] = batch[key].transpose(1, 0)
action = batch["action"]
reward = batch["next.reward"][:, :, None] # add extra channel dimension
# idxs = batch["index"] # TODO(rcadene): use idxs to update sampling weights
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
weights = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
obses = {
"rgb": batch["observation.image"],
"state": batch["observation.state"],
}
shapes = {}
for k in obses:
shapes[k] = obses[k].shape
obses[k] = einops.rearrange(obses[k], "t b ... -> (t b) ... ")
# Apply augmentations
aug_tf = h.aug(self.cfg)
obs = aug_tf(obs)
obses = aug_tf(obses)
for k in next_obses:
next_obses[k] = einops.rearrange(next_obses[k], "h t ... -> (h t) ...")
next_obses = aug_tf(next_obses)
for k in next_obses:
next_obses[k] = einops.rearrange(
next_obses[k],
"(h t) ... -> h t ...",
h=self.cfg.horizon,
t=self.cfg.batch_size,
)
for k in obses:
t, b = shapes[k][:2]
obses[k] = einops.rearrange(obses[k], "(t b) ... -> t b ... ", b=b, t=t)
horizon = self.cfg.horizon
obs, next_obses = {}, {}
for k in obses:
obs[k] = obses[k][0]
next_obses[k] = obses[k][1:].clone()
horizon = next_obses["rgb"].shape[0]
loss_mask = torch.ones_like(mask, device=self.device)
for t in range(1, horizon):
loss_mask[t] = loss_mask[t - 1] * (~done[t - 1])
@@ -497,19 +530,19 @@ class TDMPCPolicy(nn.Module):
)
self.optim.step()
if self.cfg.per:
# Update priorities
priorities = priority_loss.clamp(max=1e4).detach()
has_nan = torch.isnan(priorities).any().item()
if has_nan:
print(f"priorities has nan: {priorities=}")
else:
replay_buffer.update_priority(
idxs[:num_slices],
priorities[:num_slices],
)
if demo_batch_size > 0:
demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
# if self.cfg.per:
# # Update priorities
# priorities = priority_loss.clamp(max=1e4).detach()
# has_nan = torch.isnan(priorities).any().item()
# if has_nan:
# print(f"priorities has nan: {priorities=}")
# else:
# replay_buffer.update_priority(
# idxs[:num_slices],
# priorities[:num_slices],
# )
# if demo_batch_size > 0:
# demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
# Update policy + target network
_, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
@@ -532,7 +565,7 @@ class TDMPCPolicy(nn.Module):
"data_s": data_s,
"update_s": time.time() - start_time,
}
info["demo_batch_size"] = demo_batch_size
# info["demo_batch_size"] = demo_batch_size
info["expectile"] = expectile
info.update(value_info)
info.update(pi_update_info)