Ran pre-commit run --all-files

This commit is contained in:
Simon Alibert
2024-02-29 13:37:48 +01:00
parent 0e0e760e43
commit 7e024fdce6
16 changed files with 124 additions and 237 deletions

View File

@@ -3,16 +3,17 @@ import copy
import hydra
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusion_policy.model.common.lr_scheduler import get_scheduler
from diffusion_policy.model.vision.model_getter import get_resnet
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy
FIRST_ACTION = 0
class DiffusionPolicy(nn.Module):
def __init__(
self,
cfg,
@@ -105,7 +106,6 @@ class DiffusionPolicy(nn.Module):
out = self.diffusion.predict_action(obs_dict)
# TODO(rcadene): add possibility to return >1 timesteps
FIRST_ACTION = 0
action = out["action"].squeeze(0)[FIRST_ACTION]
return action
@@ -132,10 +132,7 @@ class DiffusionPolicy(nn.Module):
}
return out
if self.cfg.balanced_sampling:
batch = replay_buffer.sample(batch_size)
else:
batch = replay_buffer.sample()
batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
batch = process_batch(batch, self.cfg.horizon, num_slices)
loss = self.diffusion.compute_loss(batch)

View File

@@ -7,6 +7,8 @@ import torch.nn as nn
import lerobot.common.policies.tdmpc_helper as h
FIRST_FRAME = 0
class TOLD(nn.Module):
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
@@ -17,9 +19,7 @@ class TOLD(nn.Module):
self.cfg = cfg
self._encoder = h.enc(cfg)
self._dynamics = h.dynamics(
cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim
)
self._dynamics = h.dynamics(cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim)
self._reward = h.mlp(cfg.latent_dim + action_dim, cfg.mlp_dim, 1)
self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, action_dim)
self._Qs = nn.ModuleList([h.q(cfg) for _ in range(cfg.num_q)])
@@ -65,20 +65,20 @@ class TOLD(nn.Module):
return h.TruncatedNormal(mu, std).sample(clip=0.3)
return mu
def V(self, z):
def V(self, z): # noqa: N802
"""Predict state value (V)."""
return self._V(z)
def Q(self, z, a, return_type):
def Q(self, z, a, return_type): # noqa: N802
"""Predict state-action value (Q)."""
assert return_type in {"min", "avg", "all"}
x = torch.cat([z, a], dim=-1)
if return_type == "all":
return torch.stack(list(q(x) for q in self._Qs), dim=0)
return torch.stack([q(x) for q in self._Qs], dim=0)
idxs = np.random.choice(self.cfg.num_q, 2, replace=False)
Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)
Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x) # noqa: N806
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
@@ -146,25 +146,21 @@ class TDMPC(nn.Module):
@torch.no_grad()
def estimate_value(self, z, actions, horizon):
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
G, discount = 0, 1
G, discount = 0, 1 # noqa: N806
for t in range(horizon):
if self.cfg.uncertainty_cost > 0:
G -= (
G -= ( # noqa: N806
discount
* self.cfg.uncertainty_cost
* self.model.Q(z, actions[t], return_type="all").std(dim=0)
)
z, reward = self.model.next(z, actions[t])
G += discount * reward
G += discount * reward # noqa: N806
discount *= self.cfg.discount
pi = self.model.pi(z, self.cfg.min_std)
G += discount * self.model.Q(z, pi, return_type="min")
G += discount * self.model.Q(z, pi, return_type="min") # noqa: N806
if self.cfg.uncertainty_cost > 0:
G -= (
discount
* self.cfg.uncertainty_cost
* self.model.Q(z, pi, return_type="all").std(dim=0)
)
G -= discount * self.cfg.uncertainty_cost * self.model.Q(z, pi, return_type="all").std(dim=0) # noqa: N806
return G
@torch.no_grad()
@@ -180,19 +176,13 @@ class TDMPC(nn.Module):
assert step is not None
# Seed steps
if step < self.cfg.seed_steps and self.model.training:
return torch.empty(
self.action_dim, dtype=torch.float32, device=self.device
).uniform_(-1, 1)
return torch.empty(self.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1)
# Sample policy trajectories
horizon = int(
min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step))
)
horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step)))
num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
if num_pi_trajs > 0:
pi_actions = torch.empty(
horizon, num_pi_trajs, self.action_dim, device=self.device
)
pi_actions = torch.empty(horizon, num_pi_trajs, self.action_dim, device=self.device)
_z = z.repeat(num_pi_trajs, 1)
for t in range(horizon):
pi_actions[t] = self.model.pi(_z, self.cfg.min_std)
@@ -201,20 +191,16 @@ class TDMPC(nn.Module):
# Initialize state and parameters
z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1)
mean = torch.zeros(horizon, self.action_dim, device=self.device)
std = self.cfg.max_std * torch.ones(
horizon, self.action_dim, device=self.device
)
std = self.cfg.max_std * torch.ones(horizon, self.action_dim, device=self.device)
if not t0 and hasattr(self, "_prev_mean"):
mean[:-1] = self._prev_mean[1:]
# Iterate CEM
for i in range(self.cfg.iterations):
for _ in range(self.cfg.iterations):
actions = torch.clamp(
mean.unsqueeze(1)
+ std.unsqueeze(1)
* torch.randn(
horizon, self.cfg.num_samples, self.action_dim, device=std.device
),
* torch.randn(horizon, self.cfg.num_samples, self.action_dim, device=std.device),
-1,
1,
)
@@ -223,18 +209,14 @@ class TDMPC(nn.Module):
# Compute elite actions
value = self.estimate_value(z, actions, horizon).nan_to_num_(0)
elite_idxs = torch.topk(
value.squeeze(1), self.cfg.num_elites, dim=0
).indices
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
# Update parameters
max_value = elite_value.max(0)[0]
score = torch.exp(self.cfg.temperature * (elite_value - max_value))
score /= score.sum(0)
_mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (
score.sum(0) + 1e-9
)
_mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
_std = torch.sqrt(
torch.sum(
score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2,
@@ -331,7 +313,6 @@ class TDMPC(nn.Module):
batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
batch = batch.to(self.device)
FIRST_FRAME = 0
obs = {
"rgb": batch["observation", "image"][FIRST_FRAME].float(),
"state": batch["observation", "state"][FIRST_FRAME],
@@ -359,10 +340,7 @@ class TDMPC(nn.Module):
weights = batch["_weight"][FIRST_FRAME, :, None]
return obs, action, next_obses, reward, mask, done, idxs, weights
if self.cfg.balanced_sampling:
batch = replay_buffer.sample(batch_size)
else:
batch = replay_buffer.sample()
batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
batch, self.cfg.horizon, num_slices
@@ -384,10 +362,7 @@ class TDMPC(nn.Module):
if isinstance(obs, dict):
obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
next_obses = {
k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1)
for k in next_obses
}
next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
else:
obs = torch.cat([obs, demo_obs])
next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
@@ -429,9 +404,7 @@ class TDMPC(nn.Module):
td_targets = self._td_target(next_z, reward, mask)
# Latent rollout
zs = torch.empty(
horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device
)
zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device)
reward_preds = torch.empty_like(reward, device=self.device)
assert reward.shape[0] == horizon
z = self.model.encode(obs)
@@ -452,12 +425,10 @@ class TDMPC(nn.Module):
value_info["V"] = v.mean().item()
# Losses
rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(
-1, 1, 1
rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1)
consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum(
dim=0
)
consistency_loss = (
rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask
).sum(dim=0)
reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
q_value_loss, priority_loss = 0, 0
for q in range(self.cfg.num_q):
@@ -465,9 +436,7 @@ class TDMPC(nn.Module):
priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
expectile = h.linear_schedule(self.cfg.expectile, step)
v_value_loss = (
rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask
).sum(dim=0)
v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0)
total_loss = (
self.cfg.consistency_coef * consistency_loss

View File

@@ -5,11 +5,15 @@ import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.functional as F # noqa: N812
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal
__REDUCE__ = lambda b: "mean" if b else "none"
DEFAULT_ACT_FN = nn.Mish()
def __REDUCE__(b): # noqa: N802, N807
return "mean" if b else "none"
def l1(pred, target, reduce=False):
@@ -36,11 +40,7 @@ def l2_expectile(diff, expectile=0.7, reduce=False):
def _get_out_shape(in_shape, layers):
"""Utility function. Returns the output shape of a network for a given input shape."""
x = torch.randn(*in_shape).unsqueeze(0)
return (
(nn.Sequential(*layers) if isinstance(layers, list) else layers)(x)
.squeeze(0)
.shape
)
return (nn.Sequential(*layers) if isinstance(layers, list) else layers)(x).squeeze(0).shape
def gaussian_logprob(eps, log_std):
@@ -73,7 +73,7 @@ def orthogonal_init(m):
def ema(m, m_target, tau):
"""Update slow-moving average of online network (target network) at rate tau."""
with torch.no_grad():
for p, p_target in zip(m.parameters(), m_target.parameters()):
for p, p_target in zip(m.parameters(), m_target.parameters(), strict=False):
p_target.data.lerp_(p.data, tau)
@@ -86,6 +86,8 @@ def set_requires_grad(net, value):
class TruncatedNormal(pyd.Normal):
"""Utility class implementing the truncated normal distribution."""
default_sample_shape = torch.Size()
def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
super().__init__(loc, scale, validate_args=False)
self.low = low
@@ -97,7 +99,7 @@ class TruncatedNormal(pyd.Normal):
x = x - x.detach() + clamped_x.detach()
return x
def sample(self, clip=None, sample_shape=torch.Size()):
def sample(self, clip=None, sample_shape=default_sample_shape):
shape = self._extended_shape(sample_shape)
eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
eps *= self.scale
@@ -136,7 +138,7 @@ def enc(cfg):
"""Returns a TOLD encoder."""
pixels_enc_layers, state_enc_layers = None, None
if cfg.modality in {"pixels", "all"}:
C = int(3 * cfg.frame_stack)
C = int(3 * cfg.frame_stack) # noqa: N806
pixels_enc_layers = [
NormalizeImg(),
nn.Conv2d(C, cfg.num_channels, 7, stride=2),
@@ -184,7 +186,7 @@ def enc(cfg):
return Multiplexer(nn.ModuleDict(encoders))
def mlp(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
def mlp(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
"""Returns an MLP."""
if isinstance(mlp_dim, int):
mlp_dim = [mlp_dim, mlp_dim]
@@ -199,7 +201,7 @@ def mlp(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
)
def dynamics(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
def dynamics(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
"""Returns a dynamics network."""
return nn.Sequential(
mlp(in_dim, mlp_dim, out_dim, act_fn),
@@ -327,7 +329,7 @@ class RandomShiftsAug(nn.Module):
return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
class Episode(object):
class Episode:
"""Storage object for a single episode."""
def __init__(self, cfg, init_obs):
@@ -354,18 +356,10 @@ class Episode(object):
self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device)
else:
raise ValueError
self.actions = torch.empty(
(cfg.episode_length, action_dim), dtype=torch.float32, device=self.device
)
self.rewards = torch.empty(
(cfg.episode_length,), dtype=torch.float32, device=self.device
)
self.dones = torch.empty(
(cfg.episode_length,), dtype=torch.bool, device=self.device
)
self.masks = torch.empty(
(cfg.episode_length,), dtype=torch.float32, device=self.device
)
self.actions = torch.empty((cfg.episode_length, action_dim), dtype=torch.float32, device=self.device)
self.rewards = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
self.dones = torch.empty((cfg.episode_length,), dtype=torch.bool, device=self.device)
self.masks = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
self.cumulative_reward = 0
self.done = False
self.success = False
@@ -380,23 +374,17 @@ class Episode(object):
if cfg.modality in {"pixels", "state"}:
episode = cls(cfg, obses[0])
episode.obses[1:] = torch.tensor(
obses[1:], dtype=episode.obses.dtype, device=episode.device
)
episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device)
elif cfg.modality == "all":
episode = cls(cfg, {k: v[0] for k, v in obses.items()})
for k, v in obses.items():
for k in obses:
episode.obses[k][1:] = torch.tensor(
obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device
)
else:
raise NotImplementedError
episode.actions = torch.tensor(
actions, dtype=episode.actions.dtype, device=episode.device
)
episode.rewards = torch.tensor(
rewards, dtype=episode.rewards.dtype, device=episode.device
)
episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device)
episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device)
episode.dones = (
torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
if dones is not None
@@ -428,9 +416,7 @@ class Episode(object):
v, dtype=self.obses[k].dtype, device=self.obses[k].device
)
else:
self.obses[self._idx + 1] = torch.tensor(
obs, dtype=self.obses.dtype, device=self.obses.device
)
self.obses[self._idx + 1] = torch.tensor(obs, dtype=self.obses.dtype, device=self.obses.device)
self.actions[self._idx] = action
self.rewards[self._idx] = reward
self.dones[self._idx] = done
@@ -453,7 +439,7 @@ def get_dataset_dict(cfg, env, return_reward_normalizer=False):
]
if cfg.task.startswith("xarm"):
dataset_path = os.path.join(cfg.dataset_dir, f"buffer.pkl")
dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl")
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
@@ -461,7 +447,7 @@ def get_dataset_dict(cfg, env, return_reward_normalizer=False):
if k not in dataset_dict and k[:-1] in dataset_dict:
dataset_dict[k] = dataset_dict.pop(k[:-1])
elif cfg.task.startswith("legged"):
dataset_path = os.path.join(cfg.dataset_dir, f"buffer.pkl")
dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl")
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
@@ -475,10 +461,7 @@ def get_dataset_dict(cfg, env, return_reward_normalizer=False):
for i in range(len(dones) - 1):
if (
np.linalg.norm(
dataset_dict["observations"][i + 1]
- dataset_dict["next_observations"][i]
)
np.linalg.norm(dataset_dict["observations"][i + 1] - dataset_dict["next_observations"][i])
> 1e-6
or dataset_dict["terminals"][i] == 1.0
):
@@ -501,7 +484,7 @@ def get_dataset_dict(cfg, env, return_reward_normalizer=False):
dataset_dict["rewards"] = reward_normalizer(dataset_dict["rewards"])
for key in required_keys:
assert key in dataset_dict.keys(), f"Missing `{key}` in dataset."
assert key in dataset_dict, f"Missing `{key}` in dataset."
if return_reward_normalizer:
return dataset_dict, reward_normalizer
@@ -553,9 +536,7 @@ def get_reward_normalizer(cfg, dataset):
return lambda x: x - 1.0
elif cfg.task.split("-")[0] in ["hopper", "halfcheetah", "walker2d"]:
(_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset)
return (
lambda x: x / (np.max(episode_returns) - np.min(episode_returns)) * 1000.0
)
return lambda x: x / (np.max(episode_returns) - np.min(episode_returns)) * 1000.0
elif hasattr(cfg, "reward_scale"):
return lambda x: x * cfg.reward_scale
return lambda x: x
@@ -571,12 +552,12 @@ def linear_schedule(schdl, step):
except ValueError:
match = re.match(r"linear\((.+),(.+),(.+),(.+)\)", schdl)
if match:
init, final, start, end = [float(g) for g in match.groups()]
init, final, start, end = (float(g) for g in match.groups())
mix = np.clip((step - start) / (end - start), 0.0, 1.0)
return (1.0 - mix) * init + mix * final
match = re.match(r"linear\((.+),(.+),(.+)\)", schdl)
if match:
init, final, duration = [float(g) for g in match.groups()]
init, final, duration = (float(g) for g in match.groups())
mix = np.clip(step / duration, 0.0, 1.0)
return (1.0 - mix) * init + mix * final
raise NotImplementedError(schdl)