Compare commits
2 Commits
user/miche
...
tdmpc23
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
14490148f3 | ||
|
|
16edbbdeee |
@@ -14,11 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from collections import deque
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
@@ -33,10 +30,6 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
||||
if cfg.env.name == "real_world":
|
||||
return
|
||||
|
||||
if "maniskill" in cfg.env.name:
|
||||
env = make_maniskill_env(cfg, n_envs if n_envs is not None else cfg.eval.batch_size)
|
||||
return env
|
||||
|
||||
package_name = f"gym_{cfg.env.name}"
|
||||
|
||||
try:
|
||||
@@ -63,58 +56,3 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
||||
)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
|
||||
"""Make ManiSkill3 gym environment"""
|
||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||
|
||||
env = gym.make(
|
||||
cfg.env.task,
|
||||
obs_mode=cfg.env.obs,
|
||||
control_mode=cfg.env.control_mode,
|
||||
render_mode=cfg.env.render_mode,
|
||||
sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
|
||||
num_envs=n_envs,
|
||||
)
|
||||
# cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
|
||||
env = ManiSkillVectorEnv(env, ignore_terminations=True)
|
||||
# env = PixelWrapper(cfg, env, n_envs)
|
||||
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
|
||||
env.unwrapped.metadata["render_fps"] = 20
|
||||
|
||||
return env
|
||||
|
||||
|
||||
class PixelWrapper(gym.Wrapper):
|
||||
"""
|
||||
Wrapper for pixel observations. Works with Maniskill vectorized environments
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, env, num_envs, num_frames=3):
|
||||
super().__init__(env)
|
||||
self.cfg = cfg
|
||||
self.env = env
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
self._frames = deque([], maxlen=num_frames)
|
||||
self._render_size = cfg.env.render_size
|
||||
|
||||
def _get_obs(self, obs):
|
||||
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
|
||||
self._frames.append(frame)
|
||||
return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}
|
||||
|
||||
def reset(self, seed):
|
||||
obs, info = self.env.reset() # (seed=seed)
|
||||
for _ in range(self._frames.maxlen):
|
||||
obs_frames = self._get_obs(obs)
|
||||
return obs_frames, info
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
return self._get_obs(obs), reward, terminated, truncated, info
|
||||
|
||||
@@ -19,7 +19,7 @@ from dataclasses import dataclass, field
|
||||
|
||||
@dataclass
|
||||
class TDMPC2Config:
|
||||
"""Configuration class for TDMPCPolicy.
|
||||
"""Configuration class for TDMPC2Policy.
|
||||
|
||||
Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
|
||||
camera observations.
|
||||
@@ -77,18 +77,9 @@ class TDMPC2Config:
|
||||
image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
|
||||
is applied. Note that the input images are assumed to be square for this augmentation.
|
||||
reward_coeff: Loss weighting coefficient for the reward regression loss.
|
||||
expectile_weight: Weighting (τ) used in expectile regression for the state value function (V).
|
||||
v_pred < v_target is weighted by τ and v_pred >= v_target is weighted by (1-τ). τ is expected to
|
||||
be in [0, 1]. Setting τ closer to 1 results in a more "optimistic" V. This is sensible to do
|
||||
because v_target is obtained by evaluating the learned state-action value functions (Q) with
|
||||
in-sample actions that may not be always optimal.
|
||||
value_coeff: Loss weighting coefficient for both the state-action value (Q) TD loss, and the state
|
||||
value (V) expectile regression loss.
|
||||
consistency_coeff: Loss weighting coefficient for the consistency loss.
|
||||
advantage_scaling: A factor by which the advantages are scaled prior to exponentiation for advantage
|
||||
weighted regression of the policy (π) estimator parameters. Note that the exponentiated advantages
|
||||
are clamped at 100.0.
|
||||
pi_coeff: Loss weighting coefficient for the action regression loss.
|
||||
temporal_decay_coeff: Exponential decay coefficient for decaying the loss coefficient for future time-
|
||||
steps. Hint: each loss computation involves `horizon` steps worth of actions starting from the
|
||||
current time step.
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2024 Nicklas Hansen and The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,11 +14,11 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Implementation of Finetuning Offline World Models in the Real World.
|
||||
"""Implementation of TD-MPC2: Scalable, Robust World Models for Continuous Control
|
||||
|
||||
The comments in this code may sometimes refer to these references:
|
||||
TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955)
|
||||
FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029)
|
||||
We refer to the main paper and codebase:
|
||||
TD-MPC2 paper: (https://arxiv.org/abs/2310.16828)
|
||||
TD-MPC2 code: (https://github.com/nicklashansen/tdmpc2)
|
||||
"""
|
||||
|
||||
# ruff: noqa: N806
|
||||
@@ -56,22 +56,7 @@ class TDMPC2Policy(
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "tdmpc2"],
|
||||
):
|
||||
"""Implementation of TD-MPC2 learning + inference.
|
||||
|
||||
Please note several warnings for this policy.
|
||||
- Evaluation of pretrained weights created with the original FOWM code
|
||||
(https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a
|
||||
model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across
|
||||
to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter-
|
||||
process communication to use the xarm environment from FOWM. This is because our xarm
|
||||
environment uses newer dependencies and does not match the environment in FOWM. See
|
||||
https://github.com/huggingface/lerobot/pull/103 for implementation details.
|
||||
- We have NOT checked that training on LeRobot reproduces the results from FOWM.
|
||||
- Nevertheless, we have verified that we can train TD-MPC for PushT. See
|
||||
`lerobot/configs/policy/tdmpc2_pusht_keypoints.yaml`.
|
||||
- Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
|
||||
match our xarm environment.
|
||||
"""
|
||||
"""Implementation of TD-MPC2 learning + inference."""
|
||||
|
||||
name = "tdmpc2"
|
||||
|
||||
@@ -404,7 +389,7 @@ class TDMPC2Policy(
|
||||
reward_loss = (
|
||||
(
|
||||
temporal_loss_coeffs
|
||||
* soft_cross_entropy(reward_preds, reward, self.config)
|
||||
* soft_cross_entropy(reward_preds, reward, self.config).mean(1)
|
||||
* ~batch["next.reward_is_pad"]
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
@@ -412,10 +397,11 @@ class TDMPC2Policy(
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
ce_value_loss = 0.0
|
||||
for i in range(self.config.q_ensemble_size):
|
||||
ce_value_loss += soft_cross_entropy(q_preds_ensemble[i], td_targets, self.config)
|
||||
ce_value_loss += soft_cross_entropy(q_preds_ensemble[i], td_targets, self.config).mean(1)
|
||||
|
||||
q_value_loss = (
|
||||
(
|
||||
@@ -435,7 +421,6 @@ class TDMPC2Policy(
|
||||
# Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
|
||||
# We won't need these gradients again so detach.
|
||||
z_preds = z_preds.detach()
|
||||
self.model.change_q_grad(mode=False)
|
||||
action_preds, _, log_pis, _ = self.model.pi(z_preds[:-1])
|
||||
|
||||
with torch.no_grad():
|
||||
@@ -445,14 +430,9 @@ class TDMPC2Policy(
|
||||
self.scale.update(qs[0])
|
||||
qs = self.scale(qs)
|
||||
|
||||
rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze(
|
||||
-1
|
||||
)
|
||||
|
||||
pi_loss = (
|
||||
(self.config.entropy_coef * log_pis - qs).mean(dim=(1, 2))
|
||||
* rho
|
||||
# * temporal_loss_coeffs
|
||||
(self.config.entropy_coef * log_pis - qs).mean(dim=2)
|
||||
* temporal_loss_coeffs
|
||||
# `action_preds` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
@@ -462,7 +442,7 @@ class TDMPC2Policy(
|
||||
self.config.consistency_coeff * consistency_loss
|
||||
+ self.config.reward_coeff * reward_loss
|
||||
+ self.config.value_coeff * q_value_loss
|
||||
+ self.config.pi_coeff * pi_loss
|
||||
+ pi_loss
|
||||
)
|
||||
|
||||
info.update(
|
||||
|
||||
@@ -75,9 +75,6 @@ def soft_cross_entropy(pred, target, cfg):
|
||||
"""Computes the cross entropy loss between predictions and soft targets."""
|
||||
pred = F.log_softmax(pred, dim=-1)
|
||||
target = two_hot(target, cfg)
|
||||
import pudb
|
||||
|
||||
pudb.set_trace()
|
||||
return -(target * pred).sum(-1, keepdim=True)
|
||||
|
||||
|
||||
@@ -137,16 +134,20 @@ def symexp(x):
|
||||
|
||||
def two_hot(x, cfg):
|
||||
"""Converts a batch of scalars to soft two-hot encoded targets for discrete regression."""
|
||||
|
||||
# x shape [horizon, num_features]
|
||||
if cfg.num_bins == 0:
|
||||
return x
|
||||
elif cfg.num_bins == 1:
|
||||
return symlog(x)
|
||||
x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax)
|
||||
bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()
|
||||
bin_offset = (x - cfg.vmin) / cfg.bin_size - bin_idx.float()
|
||||
soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device)
|
||||
soft_two_hot.scatter_(1, bin_idx, 1 - bin_offset)
|
||||
soft_two_hot.scatter_(1, (bin_idx + 1) % cfg.num_bins, bin_offset)
|
||||
bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long() # shape [num_features]
|
||||
bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1) # shape [num_features , 1]
|
||||
soft_two_hot = torch.zeros(
|
||||
*x.shape, cfg.num_bins, device=x.device
|
||||
) # shape [horizon, num_features, num_bins]
|
||||
soft_two_hot.scatter_(2, bin_idx.unsqueeze(-1), 1 - bin_offset)
|
||||
soft_two_hot.scatter_(2, (bin_idx.unsqueeze(-1) + 1) % cfg.num_bins, bin_offset)
|
||||
return soft_two_hot
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user