Added support for checkpointing the policy. We can save and load the policy state dict, optimizers state, optimization step and interaction step

Added functions for converting the replay buffer from and to LeRobotDataset. When we want to save the replay buffer, we convert it first to LeRobotDataset format and save it locally and vice-versa.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi
2025-01-30 17:39:41 +00:00
committed by AdilZouitine
parent 9afd093030
commit 2023289ce8
5 changed files with 215 additions and 91 deletions

View File

@@ -18,6 +18,7 @@ import io
import logging
import pickle
import queue
import shutil
import time
from pprint import pformat
from threading import Lock, Thread
@@ -29,18 +30,25 @@ import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore
import hydra
import torch
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import nn
# TODO: Remove the import of maniskill
from lerobot.common.datasets.factory import make_dataset
# TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.utils.utils import (
format_big_number,
get_global_random_state,
get_safe_torch_device,
init_hydra_config,
init_logging,
set_global_random_state,
set_global_seed,
)
from lerobot.scripts.server.buffer import (
@@ -127,10 +135,9 @@ def add_actor_information_and_train(
optimizers: dict[str, torch.optim.Optimizer],
policy: nn.Module,
policy_lock: Lock,
buffer_lock: Lock,
offline_buffer_lock: Lock,
logger_lock: Lock,
logger: Logger,
resume_optimization_step: int | None = None,
resume_interaction_step: int | None = None,
):
"""
Handles data transfer from the actor to the learner, manages training updates,
@@ -159,16 +166,17 @@ def add_actor_information_and_train(
optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`).
policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters.
policy_lock (Lock): A threading lock to ensure safe policy updates.
buffer_lock (Lock): A threading lock to safely access the online replay buffer.
offline_buffer_lock (Lock): A threading lock to safely access the offline replay buffer.
logger_lock (Lock): A threading lock to safely log training metrics.
logger (Logger): Logger instance for tracking training progress.
resume_optimization_step (int | None): In the case of resume training, start from the last optimization step reached.
resume_interaction_step (int | None): In the case of resume training, shift the interaction step with the last saved step in order to not break logging.
"""
# NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
# in the future. The reason why we did that is the GIL in Python. It's super slow the performance
# are divided by 200. So we need to have a single thread that does all the work.
time.time()
optimization_step = 0
interaction_message, transition = None, None
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
while True:
while not transition_queue.empty():
transition_list = transition_queue.get()
@@ -178,6 +186,8 @@ def add_actor_information_and_train(
while not interaction_message_queue.empty():
interaction_message = interaction_message_queue.get()
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
interaction_message["Interaction step"] += interaction_step_shift
logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
if len(replay_buffer) < cfg.training.online_step_before_learning:
@@ -186,9 +196,9 @@ def add_actor_information_and_train(
for _ in range(cfg.policy.utd_ratio - 1):
batch = replay_buffer.sample(batch_size)
if cfg.dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(batch, batch_offline)
# if cfg.offline_dataset_repo_id is not None:
# batch_offline = offline_replay_buffer.sample(batch_size)
# batch = concatenate_batch_transitions(batch, batch_offline)
actions = batch["action"]
rewards = batch["reward"]
@@ -210,11 +220,11 @@ def add_actor_information_and_train(
batch = replay_buffer.sample(batch_size)
if cfg.dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
# if cfg.offline_dataset_repo_id is not None:
# batch_offline = offline_replay_buffer.sample(batch_size)
# batch = concatenate_batch_transitions(
# left_batch_transitions=batch, right_batch_transition=batch_offline
# )
actions = batch["action"]
rewards = batch["reward"]
@@ -274,6 +284,39 @@ def add_actor_information_and_train(
if optimization_step % cfg.training.log_freq == 0:
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
if cfg.training.save_checkpoint and (
optimization_step % cfg.training.save_freq == 0 or optimization_step == cfg.training.online_steps
):
logging.info(f"Checkpoint policy after step {optimization_step}")
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
# needed (choose 6 as a minimum for consistency without being overkill).
_num_digits = max(6, len(str(cfg.training.online_steps)))
step_identifier = f"{optimization_step:0{_num_digits}d}"
interaction_step = (
interaction_message["Interaction step"] if interaction_message is not None else 0
)
logger.save_checkpoint(
optimization_step,
policy,
optimizers,
scheduler=None,
identifier=step_identifier,
interaction_step=interaction_step,
)
# TODO : temporarly save replay buffer here, remove later when on the robot
# We want to control this with the keyboard inputs
dataset_dir = logger.log_dir / "dataset"
if dataset_dir.exists() and dataset_dir.is_dir():
shutil.rmtree(
dataset_dir,
)
replay_buffer.to_lerobot_dataset(
cfg.dataset_repo_id, fps=cfg.fps, root=logger.log_dir / "dataset"
)
logging.info("Resume training")
def make_optimizers_and_scheduler(cfg, policy: nn.Module):
"""
@@ -330,7 +373,49 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info(pformat(OmegaConf.to_container(cfg)))
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
logger_lock = Lock()
## Handle resume by reloading the state of the policy and optimization
# If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
# to check for any differences between the provided config and the checkpoint's config.
if cfg.resume:
if not Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError(
"You have set resume=True, but there is no model checkpoint in "
f"{Logger.get_last_checkpoint_dir(out_dir)}"
)
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
logging.info(
colored(
"You have set resume=True, indicating that you wish to resume a run",
color="yellow",
attrs=["bold"],
)
)
# Get the configuration file from the last checkpoint.
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
# Check for differences between the checkpoint configuration and provided configuration.
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
# Ignore the `resume` and parameters.
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
# Log a warning about differences between the checkpoint configuration and the provided
# configuration.
if len(diff) > 0:
logging.warning(
"At least one difference was detected between the checkpoint configuration and "
f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
"takes precedence.",
)
# Use the checkpoint config instead of the provided config (but keep `resume` parameter).
cfg = checkpoint_cfg
cfg.resume = True
elif Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError(
f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. If "
"you meant to resume training, please use `resume=true` in your command or yaml configuration."
)
# ===========================
set_global_seed(cfg.seed)
@@ -346,20 +431,38 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
# TODO: At some point we should just need make sac policy
policy_lock = Lock()
with logger_lock:
policy: SACPolicy = make_policy(
hydra_cfg=cfg,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online traning, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
device=device,
)
policy: SACPolicy = make_policy(
hydra_cfg=cfg,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online traning, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)
# device=device,
# )
assert isinstance(policy, nn.Module)
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
# load last training state
# We can't use the logger function in `lerobot/common/logger.py`
# because it only loads the optimization step and not the interaction one
# to avoid altering that code, we will just load the optimization state manually
resume_interaction_step, resume_optimization_step = None, None
if cfg.resume:
training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
if type(training_state["optimizer"]) is dict:
assert set(training_state["optimizer"].keys()) == set(optimizers.keys()), (
"Optimizer dictionaries do not have the same keys during resume!"
)
for k, v in training_state["optimizer"].items():
optimizers[k].load_state_dict(v)
else:
optimizers.load_state_dict(training_state["optimizer"])
# Small hack to get the expected keys: use `get_global_random_state`.
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
resume_optimization_step = training_state["step"]
resume_interaction_step = training_state["interaction_step"]
# TODO: Handle resume
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
@@ -369,24 +472,34 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
buffer_lock = Lock()
replay_buffer = ReplayBuffer(
capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys()
)
if not cfg.resume:
replay_buffer = ReplayBuffer(
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
)
else:
# Reload replay buffer
dataset = LeRobotDataset(
repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
)
replay_buffer = ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=dataset,
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
)
batch_size = cfg.training.batch_size
offline_buffer_lock = None
offline_replay_buffer = None
if cfg.dataset_repo_id is not None:
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
logging.info("Convertion to a offline replay buffer")
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
)
offline_buffer_lock = Lock()
batch_size: int = batch_size // 2 # We will sample from both replay buffer
# if cfg.dataset_repo_id is not None:
# logging.info("make_dataset offline buffer")
# offline_dataset = make_dataset(cfg)
# logging.info("Convertion to a offline replay buffer")
# offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
# offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
# )
# batch_size: int = batch_size // 2 # We will sample from both replay buffer
actor_ip = cfg.actor_learner_config.actor_ip
port = cfg.actor_learner_config.port
@@ -413,10 +526,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
optimizers,
policy,
policy_lock,
buffer_lock,
offline_buffer_lock,
logger_lock,
logger,
resume_optimization_step,
resume_interaction_step,
),
)
transition_thread.start()