Stable version of rlpd + drq

This commit is contained in:
AdilZouitine
2025-01-22 09:00:16 +00:00
committed by Michel Aractingi
parent 1fb03d4cf2
commit d75b44f89f
6 changed files with 467 additions and 174 deletions


@@ -14,34 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path
import functools
from pprint import pformat
import random
from typing import Optional, Sequence, TypedDict
from typing import Optional, Sequence, TypedDict, Callable
import hydra
import numpy as np
import torch
from deepdiff import DeepDiff
from omegaconf import DictConfig, ListConfig, OmegaConf
from termcolor import colored
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import GradScaler
from tqdm import tqdm
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset, LeRobotDataset
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
# TODO: Remove the import of maniskill
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env, make_maniskill_env
from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
@@ -56,7 +49,8 @@ from lerobot.scripts.eval import eval_policy
def make_optimizers_and_scheduler(cfg, policy):
optimizer_actor = torch.optim.Adam(
params=policy.actor.parameters(),
# NOTE: Handle the shared-encoder case, where the encoder weights are not optimized with the actor gradient
params=policy.actor.parameters_to_optimize,
lr=policy.config.actor_lr,
)
optimizer_critic = torch.optim.Adam(
@@ -73,11 +67,6 @@ def make_optimizers_and_scheduler(cfg, policy):
return optimizers, lr_scheduler
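For readers of this diff: `parameters_to_optimize` exists so that, when the image encoder is shared with the critic, the actor optimizer does not also step the encoder weights (those are trained only through the critic loss). A minimal sketch of what such a property could look like, assuming a hypothetical Actor with separate `encoder` and `head` submodules (illustrative names, not the actual SACPolicy API):

from itertools import chain

from torch import nn


class Actor(nn.Module):
    def __init__(self, encoder: nn.Module, head: nn.Module, shared_encoder: bool):
        super().__init__()
        self.encoder = encoder
        self.head = head
        self.shared_encoder = shared_encoder

    @property
    def parameters_to_optimize(self):
        # With a shared encoder, the encoder is updated only by the critic loss,
        # so the actor optimizer should only see the policy head.
        if self.shared_encoder:
            return list(self.head.parameters())
        return list(chain(self.encoder.parameters(), self.head.parameters()))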
# def update_policy(policy, batch, optimizers, grad_clip_norm):
# NOTE: This is temporary; neither the online buffer nor querying the LeRobot dataset is performant enough yet
class Transition(TypedDict):
state: dict[str, torch.Tensor]
action: torch.Tensor
@@ -95,13 +84,62 @@ class BatchTransition(TypedDict):
done: torch.Tensor
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
"""
Perform a per-image random crop over a batch of images in a vectorized way.
Each image in the batch gets its own random crop offsets.
"""
B, C, H, W = images.shape
crop_h, crop_w = output_size
if crop_h > H or crop_w > W:
raise ValueError(
f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
)
tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)
rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)
rows = rows.unsqueeze(2).expand(-1, -1, crop_w) # (B, crop_h, crop_w)
cols = cols.unsqueeze(1).expand(-1, crop_h, -1) # (B, crop_h, crop_w)
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
# Gather pixels
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
# cropped_hwcn => (B, crop_h, crop_w, C)
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
return cropped
def random_shift(images: torch.Tensor, pad: int = 4):
"""Vectorized random shift, imgs: (B,C,H,W), pad: #pixels"""
_, _, h, w = images.shape
images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
return random_crop_vectorized(images=images, output_size=(h, w))
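The two helpers above implement the DrQ-style shift augmentation: each image is padded by `pad` pixels with replicate padding, then randomly cropped back to its original size, which amounts to a small per-image random translation. A minimal usage sketch, assuming 84x84 observations and the `torch` import from the file header:

images = torch.rand(8, 3, 84, 84)      # (B, C, H, W) batch of image observations
shifted = random_shift(images, pad=4)  # replicate-pad by 4 px, then random-crop back to 84x84
assert shifted.shape == images.shape   # spatial size preserved; content shifted by up to ±4 px per image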
class ReplayBuffer:
def __init__(self, capacity: int, device: str = "cuda:0", state_keys: Optional[Sequence[str]] = None):
def __init__(
self,
capacity: int,
device: str = "cuda:0",
state_keys: Optional[Sequence[str]] = None,
image_augmentation_function: Optional[Callable] = None,
use_drq: bool = True,
):
"""
Args:
capacity (int): Maximum number of transitions to store in the buffer.
device (str): The device where the tensors will be moved ("cuda:0" or "cpu").
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
and returns a batch of augmented images. If None, a default augmentation function is used.
use_drq (bool): Whether to apply DrQ-style image augmentation when sampling from the buffer.
"""
self.capacity = capacity
self.device = device
@@ -111,6 +149,9 @@ class ReplayBuffer:
# If no state_keys provided, default to an empty list
# (you can handle this differently if needed)
self.state_keys = state_keys if state_keys is not None else []
if image_augmentation_function is None:
image_augmentation_function = functools.partial(random_shift, pad=4)
self.image_augmentation_function = image_augmentation_function
self.use_drq = use_drq
def add(
self,
@@ -134,7 +175,7 @@ class ReplayBuffer:
done=done,
complementary_info=complementary_info,
)
self.position = (self.position + 1) % self.capacity
self.position: int = (self.position + 1) % self.capacity
@classmethod
def from_lerobot_dataset(
@@ -143,6 +184,18 @@ class ReplayBuffer:
device: str = "cuda:0",
state_keys: Optional[Sequence[str]] = None,
) -> "ReplayBuffer":
"""
Convert a LeRobotDataset into a ReplayBuffer.
Args:
lerobot_dataset (LeRobotDataset): The dataset to convert.
device (str): The device where the tensors will be stored. Defaults to "cuda:0".
state_keys (Optional[Sequence[str]], optional): The list of keys that appear in `state` and `next_state`.
Defaults to None.
Returns:
ReplayBuffer: The replay buffer with offline dataset transitions.
"""
# We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
# a replay buffer than from a lerobot dataset.
replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
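As a usage sketch (the repo id and state keys below are placeholders, not taken from this commit), seeding an offline buffer for RLPD-style training could look like:

offline_dataset = LeRobotDataset(repo_id="lerobot/your_offline_dataset")  # placeholder repo id
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
    lerobot_dataset=offline_dataset,
    device="cuda:0",
    state_keys=["observation.state", "observation.image"],  # example keys
)
offline_batch = offline_replay_buffer.sample(256)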
@@ -248,6 +301,8 @@ class ReplayBuffer:
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
if key.startswith("observation.image") and self.use_drq:
batch_state[key] = self.image_augmentation_function(batch_state[key])
# -- Build batched actions --
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
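Note that the augmentation is applied inside `sample()`, not when transitions are stored: every gradient step therefore sees freshly shifted copies of the same stored frames, and the `state` and `next_state` images draw independent shifts, which is the core DrQ regularization. A small sketch of that pattern (hypothetical helper, not part of the class):

def augment_image_entries(batch: dict, augment_fn) -> dict:
    # Re-augment image keys on every call; non-image entries pass through unchanged.
    return {
        key: augment_fn(value) if key.startswith("observation.image") else value
        for key, value in batch.items()
    }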
@@ -263,6 +318,8 @@ class ReplayBuffer:
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
if key.startswith("observation.image") and self.use_drq:
batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
# -- Build batched dones --
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
@@ -285,7 +342,7 @@ class ReplayBuffer:
def concatenate_batch_transitions(
left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
) -> BatchTransition:
"""Be careful it change the left_batch_transitions in place"""
"""NOTE: Be careful it change the left_batch_transitions in place"""
left_batch_transitions["state"] = {
key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
for key in left_batch_transitions["state"]
@@ -321,11 +378,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
# NOTE: Off-policy algorithms are efficient enough to use a single environment
logging.info("make_env online")
online_env = make_env(cfg, n_envs=1)
# online_env = make_env(cfg, n_envs=1)
# TODO: Remove the maniskill import and unify with make_env
online_env = make_maniskill_env(cfg, n_envs=1)
if cfg.training.eval_freq > 0:
logging.info("make_env eval")
eval_env = make_env(cfg, n_envs=1)
# eval_env = make_env(cfg, n_envs=1)
# TODO: Remove the maniskill import and unify with make_env
eval_env = make_maniskill_env(cfg, n_envs=1)
# TODO: Add a way to resume training
@@ -348,6 +408,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# Hack: If we do online training, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
device=device,
)
assert isinstance(policy, nn.Module)
@@ -360,17 +421,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
log_output_dir(out_dir)
logging.info(f"{cfg.env.task=}")
# TODO: Handle offline steps
# logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
logging.info(f"{cfg.training.online_steps=}")
# logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
# logging.info(f"{offline_dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
obs, info = online_env.reset()
obs = preprocess_observation(obs)
# HACK for maniskill
# obs = preprocess_observation(obs)
obs = preprocess_maniskill_observation(obs)
obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
replay_buffer = ReplayBuffer(
@@ -378,8 +437,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
)
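With the constructor extended above, the online buffer can be built with DrQ augmentation enabled by default, or with a custom augmentation function. A hedged sketch (capacity and state keys are illustrative, not from this commit):

replay_buffer = ReplayBuffer(
    capacity=100_000,
    device=device,
    state_keys=["observation.state", "observation.image"],  # example keys
    use_drq=True,  # apply the default pad-4 random shift to image keys at sample time
)
# Or, with a different shift magnitude passed through image_augmentation_function:
custom_buffer = ReplayBuffer(
    capacity=100_000,
    device=device,
    state_keys=["observation.state", "observation.image"],
    image_augmentation_function=functools.partial(random_shift, pad=6),
)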
batch_size = cfg.training.batch_size
# if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
# raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
if cfg.dataset_repo_id is not None:
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
@@ -404,7 +462,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# HACK
action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
next_obs = preprocess_observation(next_obs)
# HACK: For maniskill
# next_obs = preprocess_observation(next_obs)
next_obs = preprocess_maniskill_observation(next_obs)
next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in obs}
sum_reward_episode += float(reward[0])
# Because we are using a single environment
@@ -413,16 +473,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}")
logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step)
sum_reward_episode = 0
if "final_info" in info:
if "is_success" in info["final_info"][0]:
logging.info(
f"Global step {interaction_step}: Episode success: {info['final_info'][0]['is_success']}"
)
if "coverage" in info["final_info"][0]:
logging.info(
f"Global step {interaction_step}: Episode final coverage: {info['final_info'][0]['coverage']} \n"
)
logger.log_dict({"Final coverage": info["final_info"][0]["coverage"]}, interaction_step)
# HACK: This is for maniskill
logging.info(
f"global step {interaction_step}: episode success: {info['success'].float().item()} \n"
)
logger.log_dict({"Episode success": info["success"].float().item()}, interaction_step)
replay_buffer.add(
state=obs,
@@ -433,38 +488,13 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
)
obs = next_obs
if interaction_step >= cfg.training.online_step_before_learning:
for _ in range(cfg.policy.utd_ratio - 1):
batch = replay_buffer.sample(batch_size)
if cfg.dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(batch, batch_offline)
actions = batch["action"]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
optimizers["critic"].step()
if interaction_step < cfg.training.online_step_before_learning:
continue
for _ in range(cfg.policy.utd_ratio - 1):
batch = replay_buffer.sample(batch_size)
if cfg.dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
# NOTE: We have to handle the normalization for the batch
# batch = policy.normalize_inputs(batch)
batch = concatenate_batch_transitions(batch, batch_offline)
actions = batch["action"]
rewards = batch["reward"]
@@ -483,31 +513,55 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
loss_critic.backward()
optimizers["critic"].step()
training_infos = {}
training_infos["loss_critic"] = loss_critic.item()
batch = replay_buffer.sample(batch_size)
if cfg.dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
if interaction_step % cfg.training.policy_update_freq == 0:
# TD3 Trick
for _ in range(cfg.training.policy_update_freq):
loss_actor = policy.compute_loss_actor(observations=observations)
actions = batch["action"]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
optimizers["actor"].zero_grad()
loss_actor.backward()
optimizers["actor"].step()
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
optimizers["critic"].step()
training_infos["loss_actor"] = loss_actor.item()
training_infos = {}
training_infos["loss_critic"] = loss_critic.item()
loss_temperature = policy.compute_loss_temperature(observations=observations)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
if interaction_step % cfg.training.policy_update_freq == 0:
# TD3 Trick
for _ in range(cfg.training.policy_update_freq):
loss_actor = policy.compute_loss_actor(observations=observations)
training_infos["loss_temperature"] = loss_temperature.item()
optimizers["actor"].zero_grad()
loss_actor.backward()
optimizers["actor"].step()
if interaction_step % cfg.training.log_freq == 0:
logger.log_dict(training_infos, interaction_step, mode="train")
training_infos["loss_actor"] = loss_actor.item()
policy.update_target_networks()
loss_temperature = policy.compute_loss_temperature(observations=observations)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
training_infos["loss_temperature"] = loss_temperature.item()
if interaction_step % cfg.training.log_freq == 0:
logger.log_dict(training_infos, interaction_step, mode="train")
policy.update_target_networks()
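Condensed, the learner portion of this loop implements an RLPD-style update: every interaction step is stored in the online buffer, and once `online_step_before_learning` steps have elapsed, each step performs `utd_ratio` critic updates on batches that mix online and (if configured) offline data, followed by delayed actor and temperature updates and a target-network update. A hedged sketch of one such learner step, using the functions defined in this file but omitting the `policy_update_freq` gating and logging:

def rlpd_learner_step(policy, optimizers, replay_buffer, offline_replay_buffer, batch_size, utd_ratio):
    for _ in range(utd_ratio):
        batch = replay_buffer.sample(batch_size)
        if offline_replay_buffer is not None:
            # RLPD: concatenate an equally sized offline batch into every gradient step
            batch = concatenate_batch_transitions(batch, offline_replay_buffer.sample(batch_size))

        loss_critic = policy.compute_loss_critic(
            observations=batch["state"],
            actions=batch["action"],
            rewards=batch["reward"],
            next_observations=batch["next_state"],
            done=batch["done"],
        )
        optimizers["critic"].zero_grad()
        loss_critic.backward()
        optimizers["critic"].step()

    # Delayed updates on the last sampled batch (in the script, only every policy_update_freq steps)
    loss_actor = policy.compute_loss_actor(observations=batch["state"])
    optimizers["actor"].zero_grad()
    loss_actor.backward()
    optimizers["actor"].step()

    loss_temperature = policy.compute_loss_temperature(observations=batch["state"])
    optimizers["temperature"].zero_grad()
    loss_temperature.backward()
    optimizers["temperature"].step()

    policy.update_target_networks()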
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")