lerobot/lerobot/scripts/server/learner_server.py
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from pprint import pformat
import draccus
import grpc
# Import generated stubs
import hilserl_pb2_grpc # type: ignore
import torch
from termcolor import colored
from torch import nn
from torch.multiprocessing import Queue
from torch.optim.optimizer import Optimizer
from lerobot.common.constants import (
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR,
TRAINING_STATE_DIR,
TRAINING_STEP,
)
from lerobot.common.datasets.factory import make_dataset
# TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACConfig, SACPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
save_checkpoint,
save_training_state,
update_last_checkpoint,
)
from lerobot.common.utils.train_utils import (
load_training_state as utils_load_training_state,
)
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
init_logging,
)
from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.server import learner_service
from lerobot.scripts.server.buffer import (
ReplayBuffer,
bytes_to_python_object,
bytes_to_transitions,
concatenate_batch_transitions,
move_state_dict_to_device,
move_transition_to_device,
state_to_bytes,
)
from lerobot.scripts.server.utils import setup_process_handlers
logging.basicConfig(level=logging.INFO)
def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig:
"""
Handle the resume logic for training.
If resume is True:
- Verifies that a checkpoint exists
- Loads the checkpoint configuration
- Logs resumption details
- Returns the checkpoint configuration
If resume is False:
- Checks if an output directory exists (to prevent accidental overwriting)
- Returns the original configuration
Args:
cfg (TrainPipelineConfig): The training configuration
Returns:
TrainPipelineConfig: The updated configuration
Raises:
RuntimeError: If resume is True but no checkpoint found, or if resume is False but directory exists
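    Expected checkpoint layout (illustrative; directory names come from the constants
    imported at the top of this file):
        {output_dir}/{CHECKPOINTS_DIR}/{LAST_CHECKPOINT_LINK}/
            {PRETRAINED_MODEL_DIR}/train_config.json
            {TRAINING_STATE_DIR}/training_state.pt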
"""
out_dir = cfg.output_dir
# Case 1: Not resuming, but need to check if directory exists to prevent overwrites
if not cfg.resume:
checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
if os.path.exists(checkpoint_dir):
raise RuntimeError(
f"Output directory {checkpoint_dir} already exists. Use `resume=true` to resume training."
)
return cfg
# Case 2: Resuming training
checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
if not os.path.exists(checkpoint_dir):
raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
# Log that we found a valid checkpoint and are resuming
logging.info(
colored(
"Valid checkpoint found: resume=True detected, resuming previous run",
color="yellow",
attrs=["bold"],
)
)
# Load config using Draccus
checkpoint_cfg_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR, "train_config.json")
checkpoint_cfg = TrainPipelineConfig.from_pretrained(checkpoint_cfg_path)
# Ensure resume flag is set in returned config
checkpoint_cfg.resume = True
return checkpoint_cfg
def load_training_state(
cfg: TrainPipelineConfig,
optimizers: Optimizer | dict[str, Optimizer],
):
"""
Loads the training state (optimizers, step count, etc.) from a checkpoint.
Args:
cfg (TrainPipelineConfig): Training configuration
optimizers (Optimizer | dict): Optimizers to load state into
Returns:
tuple: (optimization_step, interaction_step) or (None, None) if not resuming
"""
if not cfg.resume:
return None, None
# Construct path to the last checkpoint directory
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
logging.info(f"Loading training state from {checkpoint_dir}")
try:
# Use the utility function from train_utils which loads the optimizer state
step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
# Load interaction step separately from training_state.pt
training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
interaction_step = 0
if os.path.exists(training_state_path):
training_state = torch.load(training_state_path, weights_only=False)
interaction_step = training_state.get("interaction_step", 0)
logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
return step, interaction_step
except Exception as e:
logging.error(f"Failed to load training state: {e}")
return None, None
def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
"""
Log information about the training process.
Args:
cfg (TrainPipelineConfig): Training configuration
policy (nn.Module): Policy model
"""
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.policy.online_steps=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
def initialize_replay_buffer(cfg: TrainPipelineConfig, device: str, storage_device: str) -> ReplayBuffer:
"""
Initialize a replay buffer, either empty or from a dataset if resuming.
Args:
cfg (TrainPipelineConfig): Training configuration
device (str): Device to store tensors on
storage_device (str): Device for storage optimization
Returns:
ReplayBuffer: Initialized replay buffer
"""
if not cfg.resume:
return ReplayBuffer(
capacity=cfg.policy.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_features.keys(),
storage_device=storage_device,
optimize_memory=True,
)
logging.info("Resume training load the online dataset")
dataset_path = os.path.join(cfg.output_dir, "dataset")
    # NOTE: In RL it is possible to train without an offline dataset.
repo_id = None
if cfg.dataset is not None:
repo_id = cfg.dataset.repo_id
dataset = LeRobotDataset(
repo_id=repo_id,
root=dataset_path,
)
return ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=dataset,
capacity=cfg.policy.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_features.keys(),
optimize_memory=True,
)
def initialize_offline_replay_buffer(
cfg: TrainPipelineConfig,
device: str,
storage_device: str,
active_action_dims: list[int] | None = None,
) -> ReplayBuffer:
"""
Initialize an offline replay buffer from a dataset.
Args:
cfg (TrainPipelineConfig): Training configuration
device (str): Device to store tensors on
storage_device (str): Device for storage optimization
active_action_dims (list[int] | None): Active action dimensions for masking
Returns:
ReplayBuffer: Initialized offline replay buffer
"""
if not cfg.resume:
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
else:
logging.info("load offline dataset")
dataset_offline_path = os.path.join(cfg.output_dir, "dataset_offline")
offline_dataset = LeRobotDataset(
repo_id=cfg.dataset.repo_id,
root=dataset_offline_path,
)
logging.info("Convert to a offline replay buffer")
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset,
device=device,
state_keys=cfg.policy.input_features.keys(),
action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action,
storage_device=storage_device,
optimize_memory=True,
capacity=cfg.policy.offline_buffer_capacity,
)
return offline_replay_buffer
def get_observation_features(
    policy: SACPolicy, observations: dict[str, torch.Tensor], next_observations: dict[str, torch.Tensor]
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
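    """
    Precompute visual encoder features when the vision encoder is frozen.

    With a frozen pretrained encoder, the image features do not change during the optimization
    step, so they are computed once here under `torch.no_grad()` and reused by the critic,
    actor, and temperature losses. Returns `(None, None)` when the policy has no vision
    encoder or when the encoder is not frozen.
    """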
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
return None, None
with torch.no_grad():
observation_features = (
policy.actor.encoder(observations) if policy.actor.encoder is not None else None
)
next_observation_features = (
policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None
)
return observation_features, next_observation_features
def use_threads(cfg: TrainPipelineConfig) -> bool:
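    """Return True when the learner communication loop should run in a thread instead of a separate process."""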
return cfg.policy.concurrency.learner == "threads"
def start_learner_threads(
cfg: TrainPipelineConfig,
wandb_logger: WandBLogger | None,
shutdown_event: any, # Event,
) -> None:
"""
Start the learner threads for training.
Args:
cfg (TrainPipelineConfig): Training configuration
wandb_logger (WandBLogger | None): Logger for metrics
shutdown_event: Event to signal shutdown
"""
# Create multiprocessing queues
transition_queue = Queue()
interaction_message_queue = Queue()
parameters_queue = Queue()
concurrency_entity = None
if use_threads(cfg):
from threading import Thread
concurrency_entity = Thread
else:
from torch.multiprocessing import Process
concurrency_entity = Process
communication_process = concurrency_entity(
target=start_learner_server,
args=(
parameters_queue,
transition_queue,
interaction_message_queue,
shutdown_event,
cfg,
),
daemon=True,
)
communication_process.start()
add_actor_information_and_train(
cfg=cfg,
wandb_logger=wandb_logger,
shutdown_event=shutdown_event,
transition_queue=transition_queue,
interaction_message_queue=interaction_message_queue,
parameters_queue=parameters_queue,
)
logging.info("[LEARNER] Training process stopped")
logging.info("[LEARNER] Closing queues")
transition_queue.close()
interaction_message_queue.close()
parameters_queue.close()
communication_process.join()
logging.info("[LEARNER] Communication process joined")
logging.info("[LEARNER] join queues")
transition_queue.cancel_join_thread()
interaction_message_queue.cancel_join_thread()
parameters_queue.cancel_join_thread()
logging.info("[LEARNER] queues closed")
def start_learner_server(
parameters_queue: Queue,
transition_queue: Queue,
interaction_message_queue: Queue,
shutdown_event: any, # Event,
cfg: TrainPipelineConfig,
):
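    """
    Run the gRPC server exposing the LearnerService to actor clients.

    The service receives transitions and interaction messages from actors (forwarded through
    the provided queues) and periodically streams updated policy parameters back to them.
    The server blocks until `shutdown_event` is set, then shuts down gracefully.
    """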
if not use_threads(cfg):
# Create a process-specific log file
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"learner_server_process_{os.getpid()}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file)
logging.info(f"Learner server process logging initialized")
    # Set up process handlers to react to the shutdown signal,
    # but rely on the shutdown event owned by the main process
    # (relevant when running as a separate multiprocessing process).
setup_process_handlers(False)
service = learner_service.LearnerService(
shutdown_event=shutdown_event,
parameters_queue=parameters_queue,
seconds_between_pushes=cfg.policy.actor_learner_config.policy_parameters_push_frequency,
transition_queue=transition_queue,
interaction_message_queue=interaction_message_queue,
)
server = grpc.server(
ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
options=[
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
],
)
hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(
service,
server,
)
host = cfg.policy.actor_learner_config.learner_host
port = cfg.policy.actor_learner_config.learner_port
server.add_insecure_port(f"{host}:{port}")
server.start()
logging.info("[LEARNER] gRPC server started")
shutdown_event.wait()
logging.info("[LEARNER] Stopping gRPC server...")
server.stop(learner_service.STUTDOWN_TIMEOUT)
logging.info("[LEARNER] gRPC server stopped")
def check_nan_in_transition(
    observations: dict[str, torch.Tensor],
    actions: torch.Tensor,
    next_state: dict[str, torch.Tensor],
    raise_error: bool = False,
) -> bool:
"""
Check for NaN values in transition data.
Args:
observations: Dictionary of observation tensors
actions: Action tensor
next_state: Dictionary of next state tensors
raise_error: If True, raises ValueError when NaN is detected
Returns:
bool: True if NaN values were detected, False otherwise
"""
nan_detected = False
# Check observations
for key, tensor in observations.items():
if torch.isnan(tensor).any():
logging.error(f"observations[{key}] contains NaN values")
nan_detected = True
if raise_error:
raise ValueError(f"NaN detected in observations[{key}]")
# Check next state
for key, tensor in next_state.items():
if torch.isnan(tensor).any():
logging.error(f"next_state[{key}] contains NaN values")
nan_detected = True
if raise_error:
raise ValueError(f"NaN detected in next_state[{key}]")
# Check actions
if torch.isnan(actions).any():
logging.error("actions contains NaN values")
nan_detected = True
if raise_error:
raise ValueError("NaN detected in actions")
return nan_detected
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
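    """Move the actor's state dict to CPU, serialize it to bytes, and enqueue it for the gRPC service to send to actors."""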
logging.debug("[LEARNER] Pushing actor policy to the queue")
state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu")
state_bytes = state_to_bytes(state_dict)
parameters_queue.put(state_bytes)
def add_actor_information_and_train(
cfg: TrainPipelineConfig,
wandb_logger: WandBLogger | None,
shutdown_event: any, # Event,
transition_queue: Queue,
interaction_message_queue: Queue,
parameters_queue: Queue,
):
"""
Handles data transfer from the actor to the learner, manages training updates,
and logs training progress in an online reinforcement learning setup.
This function continuously:
- Transfers transitions from the actor to the replay buffer.
- Logs received interaction messages.
- Ensures training begins only when the replay buffer has a sufficient number of transitions.
- Samples batches from the replay buffer and performs multiple critic updates.
- Periodically updates the actor, critic, and temperature optimizers.
- Logs training statistics, including loss values and optimization frequency.
Args:
cfg (TrainPipelineConfig): Configuration object containing hyperparameters.
wandb_logger (WandBLogger | None): Logger for tracking training progress.
shutdown_event (Event): Event to signal shutdown.
transition_queue (Queue): Queue for receiving transitions from the actor.
interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
parameters_queue (Queue): Queue for sending policy parameters to the actor.
"""
# Initialize logging for multiprocessing
if not use_threads(cfg):
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"learner_train_process_{os.getpid()}.log")
init_logging(log_file=log_file)
logging.info(f"Initialized logging for actor information and training process")
device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
logging.info("Initializing policy")
# Get checkpoint dir for resuming
checkpoint_dir = (
os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
)
pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None
policy: SACPolicy = make_policy(
cfg=cfg.policy,
# ds_meta=cfg.dataset,
env_cfg=cfg.env,
)
    # Gradient clipping threshold (from the policy config) applied to every optimizer update below
clip_grad_norm_value: float = cfg.policy.grad_clip_norm
# compile policy
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
policy.train()
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
last_time_policy_pushed = time.time()
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
log_training_info(cfg=cfg, policy=policy)
replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
batch_size = cfg.batch_size
offline_replay_buffer = None
if cfg.dataset is not None:
active_action_dims = None
# TODO: FIX THIS
if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [
i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
]
offline_replay_buffer = initialize_offline_replay_buffer(
cfg=cfg,
device=device,
storage_device=storage_device,
active_action_dims=active_action_dims,
)
        batch_size: int = batch_size // 2  # We will sample from both the online and offline replay buffers
    # NOTE: This function currently has more than one responsibility and should be split up
    # in the future. Because of Python's GIL, splitting the work across multiple threads
    # slows the loop down dramatically (roughly 200x), so everything runs in a single thread here.
logging.info("Starting learner thread")
interaction_message, transition = None, None
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
# Extract variables from cfg
online_step_before_learning = cfg.policy.online_step_before_learning
utd_ratio = cfg.policy.utd_ratio
dataset_repo_id = None
if cfg.dataset is not None:
dataset_repo_id = cfg.dataset.repo_id
fps = cfg.env.fps
log_freq = cfg.log_freq
save_freq = cfg.save_freq
device = cfg.policy.device
storage_device = cfg.policy.storage_device
policy_update_freq = cfg.policy.policy_update_freq
policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
saving_checkpoint = cfg.save_checkpoint
online_steps = cfg.policy.online_steps
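    # Main learner loop: ingest transitions and interaction messages from the actor,
    # then perform one optimization step per iteration once enough data is available.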
while True:
if shutdown_event is not None and shutdown_event.is_set():
logging.info("[LEARNER] Shutdown signal received. Exiting...")
break
logging.debug("[LEARNER] Waiting for transitions")
while not transition_queue.empty() and not shutdown_event.is_set():
transition_list = transition_queue.get()
transition_list = bytes_to_transitions(transition_list)
for transition in transition_list:
transition = move_transition_to_device(transition, device=device)
if check_nan_in_transition(
transition["state"], transition["action"], transition["next_state"]
):
logging.warning("NaN detected in transition, skipping")
continue
replay_buffer.add(**transition)
if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
"is_intervention"
):
offline_replay_buffer.add(**transition)
logging.debug("[LEARNER] Received transitions")
logging.debug("[LEARNER] Waiting for interactions")
while not interaction_message_queue.empty() and not shutdown_event.is_set():
interaction_message = interaction_message_queue.get()
interaction_message = bytes_to_python_object(interaction_message)
            # When resuming, offset the interaction step by the last checkpointed value so the logging curves stay continuous
interaction_message["Interaction step"] += interaction_step_shift
# Log interaction messages with WandB if available
if wandb_logger:
wandb_logger.log_dict(d=interaction_message, mode="train", custom_step_key="Interaction step")
logging.debug("[LEARNER] Received interactions")
if len(replay_buffer) < online_step_before_learning:
continue
logging.debug("[LEARNER] Starting optimization loop")
time_for_one_optimization_step = time.time()
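        # Update-to-data (UTD) ratio: run `utd_ratio - 1` critic-only updates here; the final
        # update of this optimization step (below) also handles the actor, temperature, and logging.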
for _ in range(utd_ratio - 1):
batch = replay_buffer.sample(batch_size=batch_size)
if dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch["action"]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observation_features, next_observation_features = get_observation_features(
policy=policy, observations=observations, next_observations=next_observations
)
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
)
optimizers["critic"].step()
batch = replay_buffer.sample(batch_size=batch_size)
if dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch["action"]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observation_features, next_observation_features = get_observation_features(
policy=policy, observations=observations, next_observations=next_observations
)
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["critic"].step()
training_infos = {}
training_infos["loss_critic"] = loss_critic.item()
training_infos["critic_grad_norm"] = critic_grad_norm
if optimization_step % policy_update_freq == 0:
for _ in range(policy_update_freq):
loss_actor = policy.compute_loss_actor(
observations=observations,
observation_features=observation_features,
)
optimizers["actor"].zero_grad()
loss_actor.backward()
# clip gradients
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
).item()
optimizers["actor"].step()
training_infos["loss_actor"] = loss_actor.item()
training_infos["actor_grad_norm"] = actor_grad_norm
# Temperature optimization
loss_temperature = policy.compute_loss_temperature(
observations=observations,
observation_features=observation_features,
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
# clip gradients
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
).item()
optimizers["temperature"].step()
training_infos["loss_temperature"] = loss_temperature.item()
training_infos["temperature_grad_norm"] = temp_grad_norm
training_infos["temperature"] = policy.temperature
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
last_time_policy_pushed = time.time()
policy.update_target_networks()
if optimization_step % log_freq == 0:
training_infos["replay_buffer_size"] = len(replay_buffer)
if offline_replay_buffer is not None:
training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer)
training_infos["Optimization step"] = optimization_step
# Log training metrics
if wandb_logger:
wandb_logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
# Log optimization frequency
if wandb_logger:
wandb_logger.log_dict(
{
"Optimization frequency loop [Hz]": frequency_for_one_optimization_step,
"Optimization step": optimization_step,
},
mode="train",
custom_step_key="Optimization step",
)
optimization_step += 1
if optimization_step % log_freq == 0:
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
if saving_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
save_training_checkpoint(
cfg=cfg,
optimization_step=optimization_step,
online_steps=online_steps,
interaction_message=interaction_message,
policy=policy,
optimizers=optimizers,
replay_buffer=replay_buffer,
offline_replay_buffer=offline_replay_buffer,
dataset_repo_id=dataset_repo_id,
fps=fps,
)
def save_training_checkpoint(
cfg: TrainPipelineConfig,
optimization_step: int,
online_steps: int,
interaction_message: dict | None,
policy: nn.Module,
optimizers: dict[str, Optimizer],
replay_buffer: ReplayBuffer,
offline_replay_buffer: ReplayBuffer | None = None,
dataset_repo_id: str | None = None,
fps: int = 30,
) -> None:
"""
Save training checkpoint and associated data.
Args:
cfg: Training configuration
optimization_step: Current optimization step
online_steps: Total number of online steps
interaction_message: Dictionary containing interaction information
policy: Policy model to save
optimizers: Dictionary of optimizers
replay_buffer: Replay buffer to save as dataset
offline_replay_buffer: Optional offline replay buffer to save
dataset_repo_id: Repository ID for dataset
fps: Frames per second for dataset
"""
logging.info(f"Checkpoint policy after step {optimization_step}")
_num_digits = max(6, len(str(online_steps)))
step_identifier = f"{optimization_step:0{_num_digits}d}"
interaction_step = (
interaction_message["Interaction step"] if interaction_message is not None else 0
)
# Create checkpoint directory
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
# Save checkpoint
save_checkpoint(
checkpoint_dir=checkpoint_dir,
step=optimization_step,
cfg=cfg,
policy=policy,
optimizer=optimizers,
scheduler=None
)
# Save interaction step manually
training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
os.makedirs(training_state_dir, exist_ok=True)
training_state = {
"step": optimization_step,
"interaction_step": interaction_step
}
torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
# Update the "last" symlink
update_last_checkpoint(checkpoint_dir)
    # TODO: Temporarily save the replay buffer here; remove this once running on the robot,
    # where we want to control saving with keyboard inputs instead.
dataset_dir = os.path.join(cfg.output_dir, "dataset")
if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
shutil.rmtree(dataset_dir)
# Save dataset
# NOTE: Handle the case where the dataset repo id is not specified in the config
    # e.g. RL training without demonstration data
repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
replay_buffer.to_lerobot_dataset(
repo_id=repo_id_buffer_save,
fps=fps,
root=dataset_dir
)
if offline_replay_buffer is not None:
dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
shutil.rmtree(dataset_offline_dir)
offline_replay_buffer.to_lerobot_dataset(
cfg.dataset.repo_id,
fps=fps,
root=dataset_offline_dir,
)
logging.info("Resume training")
def make_optimizers_and_scheduler(cfg, policy: nn.Module):
"""
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
This function sets up Adam optimizers for:
- The **actor network**, ensuring that only relevant parameters are optimized.
- The **critic ensemble**, which evaluates the value function.
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
    It also initializes a learning rate scheduler, though it is currently set to `None`.
**NOTE:**
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
Args:
cfg: Configuration object containing hyperparameters.
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
Returns:
Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
A tuple containing:
- `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
- `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
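    Example:
        optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
        optimizers["critic"].zero_grad()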
"""
optimizer_actor = torch.optim.Adam(
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
params=policy.actor.parameters_to_optimize,
lr=cfg.policy.actor_lr,
)
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
"critic": optimizer_critic,
"temperature": optimizer_temperature,
}
return optimizers, lr_scheduler
def train(cfg: TrainPipelineConfig, job_name: str | None = None):
"""
Main training function that initializes and runs the training process.
Args:
cfg (TrainPipelineConfig): The training configuration
job_name (str | None, optional): Job name for logging. Defaults to None.
"""
cfg.validate()
# if cfg.output_dir is None:
# raise ValueError("Output directory must be specified in config")
if job_name is None:
job_name = cfg.job_name
if job_name is None:
raise ValueError("Job name must be specified either in config or as a parameter")
# Create logs directory to ensure it exists
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"learner_{job_name}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file)
logging.info(f"Learner logging initialized, writing to {log_file}")
logging.info(pformat(cfg.to_dict()))
# Setup WandB logging if enabled
if cfg.wandb.enable and cfg.wandb.project:
from lerobot.common.utils.wandb_utils import WandBLogger
wandb_logger = WandBLogger(cfg)
else:
wandb_logger = None
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
# Handle resume logic
cfg = handle_resume_logic(cfg)
set_seed(seed=cfg.seed)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
shutdown_event = setup_process_handlers(use_threads(cfg))
start_learner_threads(
cfg=cfg,
wandb_logger=wandb_logger,
shutdown_event=shutdown_event,
)
@parser.wrap()
def train_cli(cfg: TrainPipelineConfig):
if not use_threads(cfg):
import torch.multiprocessing as mp
mp.set_start_method("spawn")
# Use the job_name from the config
train(
cfg,
job_name=cfg.job_name,
)
logging.info("[LEARNER] train_cli finished")
if __name__ == "__main__":
train_cli()
logging.info("[LEARNER] main finished")