Replace sub-threading with multiprocessing

Michel Aractingi
2025-02-20 17:21:55 +00:00
parent ff47c0b0d3
commit a9e912a05c
5 changed files with 870 additions and 36 deletions

View File

@@ -78,6 +78,7 @@ class SACPolicy(
# NOTE: For images the encoder should be shared between the actor and critic
if config.shared_encoder:
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
encoder_critic = torch.compile(encoder_critic)
encoder_actor: SACObservationEncoder = encoder_critic
else:
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
@@ -96,6 +97,7 @@ class SACPolicy(
),
output_normalization=self.normalize_targets,
)
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = CriticEnsemble(
encoder=encoder_critic,
@@ -110,6 +112,7 @@ class SACPolicy(
),
output_normalization=self.normalize_targets,
)
self.critic_target = torch.compile(self.critic_target)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
@@ -120,6 +123,9 @@ class SACPolicy(
encoder_is_shared=config.shared_encoder,
**config.policy_kwargs,
)
# self.actor = torch.compile(self.actor)
if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
@@ -148,7 +154,7 @@ class SACPolicy(
return actions
def critic_forward(
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False, image_features: Tensor | None = None
) -> Tensor:
"""Forward pass through a critic network ensemble
@@ -161,7 +167,7 @@ class SACPolicy(
Tensor of Q-values from all critics
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions)
q_values = critics(observations, actions, image_features=image_features)
return q_values
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
@@ -175,14 +181,14 @@ class SACPolicy(
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
def compute_loss_critic(self, observations, actions, rewards, next_observations, done, image_features: Tensor | None = None, next_image_features: Tensor | None = None) -> Tensor:
temperature = self.log_alpha.exp().item()
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations)
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_image_features)
# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations, actions=next_action_preds, use_target=True
observations=next_observations, actions=next_action_preds, use_target=True, image_features=next_image_features
)
# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
@@ -214,18 +220,18 @@ class SACPolicy(
).sum()
return critics_loss
def compute_loss_temperature(self, observations) -> Tensor:
def compute_loss_temperature(self, observations, image_features: Tensor | None = None) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations)
_, log_probs, _ = self.actor(observations, image_features)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
return temperature_loss
def compute_loss_actor(self, observations) -> Tensor:
def compute_loss_actor(self, observations, image_features: Tensor | None = None) -> Tensor:
temperature = self.log_alpha.exp().item()
actions_pi, log_probs, _ = self.actor(observations)
actions_pi, log_probs, _ = self.actor(observations, image_features)
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
min_q_preds = q_preds.min(dim=0)[0]
@@ -360,6 +366,7 @@ class CriticEnsemble(nn.Module):
self,
observations: dict[str, torch.Tensor],
actions: torch.Tensor,
image_features: torch.Tensor | None = None,
) -> torch.Tensor:
device = get_device_from_parameters(self)
# Move each tensor in observations to device
@@ -370,7 +377,7 @@ class CriticEnsemble(nn.Module):
actions = self.output_normalization(actions)["action"]
actions = actions.to(device)
obs_enc = observations if self.encoder is None else self.encoder(observations)
obs_enc = image_features if image_features is not None else (observations if self.encoder is None else self.encoder(observations))
inputs = torch.cat([obs_enc, actions], dim=-1)
list_q_values = []
@@ -435,9 +442,10 @@ class Policy(nn.Module):
def forward(
self,
observations: torch.Tensor,
image_features: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
obs_enc = observations if self.encoder is None else self.encoder(observations)
obs_enc = image_features if image_features is not None else (observations if self.encoder is None else self.encoder(observations))
# Get network outputs
outputs = self.network(obs_enc)
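For reference, a minimal self-contained sketch of the feature-bypass pattern these hunks introduce: when precomputed image_features are supplied, the module skips its encoder forward pass and reuses the cached features. The class and attribute names below are illustrative, not the repository's real modules.

import torch
from torch import nn

class EncoderBypassHead(nn.Module):
    # Toy module mirroring the obs_enc selection above: reuse precomputed features when given.
    def __init__(self, encoder: nn.Module | None, head: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.head = head

    def forward(self, observations: torch.Tensor, image_features: torch.Tensor | None = None) -> torch.Tensor:
        obs_enc = (
            image_features
            if image_features is not None
            else (observations if self.encoder is None else self.encoder(observations))
        )
        return self.head(obs_enc)

encoder = nn.Linear(8, 16)
module = EncoderBypassHead(encoder, nn.Linear(16, 4))
obs = torch.randn(2, 8)
with torch.no_grad():
    feats = encoder(obs)                          # computed once, e.g. by the frozen vision encoder
out_full = module(obs)                            # encoder runs inside the forward pass
out_bypass = module(obs, image_features=feats)    # encoder skipped, cached features reused
assert torch.allclose(out_full, out_bypass, atol=1e-6)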

View File

@@ -31,7 +31,7 @@ training:
online_env_seed: 10000
online_buffer_capacity: 1000000
online_buffer_seed_size: 0
online_step_before_learning: 5000
online_step_before_learning: 500
do_online_rollout_async: false
policy_update_freq: 1
@@ -52,10 +52,10 @@ policy:
n_action_steps: 1
shared_encoder: true
vision_encoder_name: null
# vision_encoder_name: "helper2424/resnet10"
# freeze_vision_encoder: true
freeze_vision_encoder: false
# vision_encoder_name: null
vision_encoder_name: "helper2424/resnet10"
freeze_vision_encoder: true
# freeze_vision_encoder: false
input_shapes:
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.state: ["${env.state_dim}"]

View File

@@ -19,6 +19,7 @@ from typing import Any, Callable, Optional, Sequence, TypedDict
import torch
import torch.nn.functional as F # noqa: N812
import multiprocessing
from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@@ -135,10 +136,11 @@ class ReplayBuffer:
self,
capacity: int,
device: str = "cuda:0",
state_keys: Optional[Sequence[str]] = None,
state_keys: Optional[list[str]] = None,
image_augmentation_function: Optional[Callable] = None,
use_drq: bool = True,
storage_device: str = "cpu",
use_shared_memory: bool = False,
):
"""
Args:
@@ -150,16 +152,17 @@ class ReplayBuffer:
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored when adding transitions to the buffer.
Using "cpu" can help save GPU memory.
use_shared_memory (bool): Whether to use shared memory for the buffer.
"""
self.capacity = capacity
self.device = device
self.storage_device = storage_device
self.memory: list[Transition] = []
self.memory: list[Transition] = torch.multiprocessing.Manager().list() if use_shared_memory else []
self.position = 0
# If no state_keys provided, default to an empty list
# (you can handle this differently if needed)
self.state_keys = state_keys if state_keys is not None else []
# Convert state_keys to a list for pickling
self.state_keys = list(state_keys) if state_keys is not None else []
if image_augmentation_function is None:
self.image_augmentation_function = functools.partial(random_shift, pad=4)
self.use_drq = use_drq
@@ -187,7 +190,7 @@ class ReplayBuffer:
# }
if len(self.memory) < self.capacity:
self.memory.append(None)
self.memory.append({}) # Need to append something first for Manager().list()
# Create and store the Transition
self.memory[self.position] = Transition(
@@ -210,6 +213,7 @@ class ReplayBuffer:
capacity: Optional[int] = None,
action_mask: Optional[Sequence[int]] = None,
action_delta: Optional[float] = None,
use_shared_memory: bool = False,
) -> "ReplayBuffer":
"""
Convert a LeRobotDataset into a ReplayBuffer.
@@ -233,7 +237,7 @@ class ReplayBuffer:
"The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
)
replay_buffer = cls(capacity=capacity, device=device, state_keys=state_keys)
replay_buffer = cls(capacity=capacity, device=device, state_keys=state_keys, use_shared_memory=use_shared_memory)
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
# Fill the replay buffer with the lerobot dataset transitions
for data in list_transition:
@@ -345,7 +349,19 @@ class ReplayBuffer:
def sample(self, batch_size: int) -> BatchTransition:
"""Sample a random batch of transitions and collate them into batched tensors."""
batch_size = min(batch_size, len(self.memory))
list_of_transitions = random.sample(self.memory, batch_size)
# Convert the (possibly shared Manager) list to a plain list before sampling so random.sample works in both cases
list_of_transitions = random.sample(list(self.memory), batch_size)
# if isinstance(self.memory, multiprocessing.managers.ListProxy):
# # For shared memory list, we need to be careful about thread safety
# with torch.multiprocessing.Lock():
# # Get indices first to minimize lock time
# indices = torch.randint(len(self.memory), size=(batch_size,)).tolist()
# # Convert to list to avoid multiple proxy accesses
# list_of_transitions = [self.memory[i] for i in indices]
# else:
# # For regular list, use faster random.sample
# list_of_transitions = random.sample(self.memory, batch_size)
# -- Build batched states --
batch_state = {}
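Below, a minimal sketch (separate from the diff) of the shared-memory list idea that `use_shared_memory` enables: a `torch.multiprocessing.Manager().list()` proxy can be filled by one process and sampled by another. Names and sizes here are illustrative; the real ReplayBuffer stores Transition dicts and guards access with locks.

import random
import torch
import torch.multiprocessing as mp

def producer(memory, n):
    # Items appended through the proxy become visible to every process holding the proxy.
    for i in range(n):
        memory.append({"state": torch.full((2,), float(i)), "reward": float(i)})

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    manager = mp.Manager()
    memory = manager.list()                      # shared list proxy, backed by the manager process
    writer = mp.Process(target=producer, args=(memory, 8))
    writer.start()
    writer.join()
    # Copy the proxy into a local list before sampling to avoid many per-item proxy round-trips.
    batch = random.sample(list(memory), 4)
    print(len(memory), [t["reward"] for t in batch])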

View File

@@ -36,6 +36,9 @@ from termcolor import colored
from torch import nn
from torch.optim.optimizer import Optimizer
# For profiling only
import datetime
from lerobot.common.datasets.factory import make_dataset
# TODO: Remove the import of maniskill
@@ -262,15 +265,15 @@ def learner_push_parameters(
while True:
with policy_lock:
params_dict = policy.actor.state_dict()
if policy.config.vision_encoder_name is not None:
if policy.config.freeze_vision_encoder:
params_dict: dict[str, torch.Tensor] = {
k: v for k, v in params_dict.items() if not k.startswith("encoder.")
}
else:
raise NotImplementedError(
"Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
)
# if policy.config.vision_encoder_name is not None:
# if policy.config.freeze_vision_encoder:
# params_dict: dict[str, torch.Tensor] = {
# k: v for k, v in params_dict.items() if not k.startswith("encoder.")
# }
# else:
# raise NotImplementedError(
# "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
# )
params_dict = move_state_dict_to_device(params_dict, device="cpu")
# Serialize
@@ -347,6 +350,7 @@ def add_actor_information_and_train(
interaction_message, transition = None, None
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
while True:
while not transition_queue.empty():
transition_list = transition_queue.get()
@@ -370,6 +374,7 @@ def add_actor_information_and_train(
# logging.info(f"Size of replay buffer: {len(replay_buffer)}")
# logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
image_features, next_image_features = None, None
time_for_one_optimization_step = time.time()
for _ in range(cfg.policy.utd_ratio - 1):
batch = replay_buffer.sample(batch_size)
@@ -385,6 +390,21 @@ def add_actor_information_and_train(
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
# Precompute encoder features from the frozen vision encoder if enabled
with record_function("encoder_forward"):
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
with torch.no_grad():
image_features = (
policy.actor.encoder(observations)
if policy.actor.encoder is not None
else None
)
next_image_features = (
policy.actor.encoder(next_observations)
if policy.actor.encoder is not None
else None
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
@@ -392,6 +412,8 @@ def add_actor_information_and_train(
rewards=rewards,
next_observations=next_observations,
done=done,
image_features=image_features,
next_image_features=next_image_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
@@ -413,6 +435,19 @@ def add_actor_information_and_train(
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
# Precompute encoder features from the frozen vision encoder if enabled
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
with torch.no_grad():
image_features = (
policy.actor.encoder(observations)
if policy.actor.encoder is not None
else None
)
next_image_features = (
policy.actor.encoder(next_observations)
if policy.actor.encoder is not None
else None
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
@@ -420,6 +455,8 @@ def add_actor_information_and_train(
rewards=rewards,
next_observations=next_observations,
done=done,
image_features=image_features,
next_image_features=next_image_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
@@ -431,7 +468,7 @@ def add_actor_information_and_train(
if optimization_step % cfg.training.policy_update_freq == 0:
for _ in range(cfg.training.policy_update_freq):
with policy_lock:
loss_actor = policy.compute_loss_actor(observations=observations)
loss_actor = policy.compute_loss_actor(observations=observations, image_features=image_features)
optimizers["actor"].zero_grad()
loss_actor.backward()
@@ -439,7 +476,7 @@ def add_actor_information_and_train(
training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(observations=observations)
loss_temperature = policy.compute_loss_temperature(observations=observations, image_features=image_features)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
@@ -503,6 +540,12 @@ def add_actor_information_and_train(
logging.info("Resume training")
profiler.step()
if optimization_step >= 50: # Stop profiling after the first 50 optimization steps
profiler.stop()
break
def make_optimizers_and_scheduler(cfg, policy: nn.Module):
"""
@@ -583,7 +626,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)
# compile policy
policy = torch.compile(policy)
# policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
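The `record_function("encoder_forward")`, `profiler.step()`, and `profiler.stop()` calls added in this file assume a `torch.profiler` session created elsewhere in the training loop (not shown in the diff). A sketch of the kind of setup they imply; the schedule values and trace directory are illustrative.

import torch
from torch.profiler import ProfilerActivity, profile, record_function, schedule, tensorboard_trace_handler

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

profiler = profile(
    activities=activities,
    schedule=schedule(wait=1, warmup=1, active=10, repeat=1),
    on_trace_ready=tensorboard_trace_handler("./profiler_logs"),
)
profiler.start()
for step in range(50):
    with record_function("encoder_forward"):       # named range, visible in the trace
        _ = torch.randn(64, 512) @ torch.randn(512, 256)
    profiler.step()                                # advance the profiling schedule
profiler.stop()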

View File

@@ -0,0 +1,767 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import pickle
import queue
import shutil
import time
from pprint import pformat
from multiprocessing import Process, Event
from torch.multiprocessing import Queue, Lock, set_start_method
import logging.handlers
from pathlib import Path
import grpc
# Import generated stubs
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore
import hydra
import torch
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import nn
from torch.optim.optimizer import Optimizer
from lerobot.common.datasets.factory import make_dataset
# TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.utils.utils import (
format_big_number,
get_global_random_state,
get_safe_torch_device,
init_hydra_config,
init_logging,
set_global_random_state,
set_global_seed,
)
from lerobot.scripts.server.buffer import (
ReplayBuffer,
concatenate_batch_transitions,
move_state_dict_to_device,
move_transition_to_device,
)
logging.basicConfig(level=logging.INFO)
# Initialize these in the main process
# transition_queue = Queue(maxsize=1_000_000) # Set a maximum size
# interaction_message_queue = Queue(maxsize=1_000_000) # Set a maximum size
policy_lock = Lock()
replay_buffer_lock = Lock()
offline_replay_buffer_lock = Lock()
# logging_queue = Queue(maxsize=1_000_000) # Set a maximum size
def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
if not cfg.resume:
if Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError(
f"Output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. "
"Use `resume=true` to resume training."
)
return cfg
# if resume == True
checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
if not checkpoint_dir.exists():
raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
logging.info(
colored(
"Resume=True detected, resuming previous run",
color="yellow",
attrs=["bold"],
)
)
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
if len(diff) > 0:
logging.warning(
f"Differences between the checkpoint config and the provided config detected: \n{pformat(diff)}\n"
"Checkpoint configuration takes precedence."
)
checkpoint_cfg.resume = True
return checkpoint_cfg
def load_training_state(
cfg: DictConfig,
logger: Logger,
optimizers: Optimizer | dict,
):
if not cfg.resume:
return None, None
training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
if isinstance(training_state["optimizer"], dict):
assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
for k, v in training_state["optimizer"].items():
optimizers[k].load_state_dict(v)
else:
optimizers.load_state_dict(training_state["optimizer"])
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
return training_state["step"], training_state["interaction_step"]
def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(out_dir)
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.training.online_steps=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> ReplayBuffer:
if not cfg.resume:
return ReplayBuffer(
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
storage_device=device,
use_shared_memory=True
)
dataset = LeRobotDataset(
repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
)
return ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=dataset,
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
use_shared_memory=True
)
def start_learner_threads(
cfg: DictConfig,
device: str,
replay_buffer: ReplayBuffer,
offline_replay_buffer: ReplayBuffer,
batch_size: int,
optimizers: dict,
policy: SACPolicy,
log_dir: Path,
transition_queue: Queue,
interaction_message_queue: Queue,
logging_queue: Queue,
resume_optimization_step: int | None = None,
resume_interaction_step: int | None = None,
) -> tuple[Process, Process, Process, Process]:
actor_ip = cfg.actor_learner_config.actor_ip
port = cfg.actor_learner_config.port
# Move policy to shared memory
policy.share_memory()
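# share_memory() moves the policy's parameter and buffer storage into shared memory, so the
# training process's in-place updates are visible to the parameter-push process without copies.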
server_process = Process(
target=stream_transitions_from_actor,
args=(
transition_queue,
interaction_message_queue,
actor_ip,
port,
),
daemon=True,
)
transition_process = Process(
target=train_offpolicy_rl,
daemon=True,
args=(
cfg,
replay_buffer,
offline_replay_buffer,
batch_size,
optimizers,
policy,
log_dir,
logging_queue,
resume_optimization_step,
),
)
param_push_process = Process(
target=learner_push_parameters,
args=(
policy,
actor_ip,
port,
15,  # seconds between parameter pushes
),
daemon=True,
)
fill_replay_buffers_process = Process(
target=fill_replay_buffers,
args=(
replay_buffer,
offline_replay_buffer,
transition_queue,
interaction_message_queue,
logging_queue,
resume_interaction_step,
device,
)
)
return server_process, transition_process, param_push_process, fill_replay_buffers_process
def stream_transitions_from_actor(
transition_queue: Queue,
interaction_message_queue: Queue,
host: str,
port: int,
):
"""
Runs a gRPC client that listens for transition and interaction messages from an Actor service.
This function establishes a gRPC connection with the given `host` and `port`, then continuously
streams transition data from the `ActorServiceStub`. The received transition data is deserialized
and stored in a queue (`transition_queue`). Similarly, interaction messages are also deserialized
and stored in a separate queue (`interaction_message_queue`).
Args:
transition_queue (Queue): Queue that receives deserialized transitions streamed from the actor.
interaction_message_queue (Queue): Queue that receives deserialized interaction messages.
host (str): The IP address or hostname of the gRPC server.
port (int): The port number on which the gRPC server is listening.
"""
# NOTE: This is waiting for the handshake to be done
# In the future we will do it in a canonical way with a proper handshake
time.sleep(10)
channel = grpc.insecure_channel(
f"{host}:{port}",
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
)
stub = hilserl_pb2_grpc.ActorServiceStub(channel)
while True:
try:
for response in stub.StreamTransition(hilserl_pb2.Empty()):
if response.HasField("transition"):
buffer = io.BytesIO(response.transition.transition_bytes)
transition = torch.load(buffer)
transition_queue.put(transition)
if response.HasField("interaction_message"):
content = pickle.loads(response.interaction_message.interaction_message_bytes)
interaction_message_queue.put(content)
except grpc.RpcError:
time.sleep(2) # Retry connection
continue
def learner_push_parameters(
policy: nn.Module,
actor_host="127.0.0.1",
actor_port=50052,
seconds_between_pushes=5
):
"""
As a client, connect to the Actor's gRPC server (ActorService)
and periodically push new parameters.
"""
time.sleep(10)
channel = grpc.insecure_channel(
f"{actor_host}:{actor_port}",
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
)
actor_stub = hilserl_pb2_grpc.ActorServiceStub(channel)
while True:
with policy_lock:
params_dict = policy.actor.state_dict()
# if policy.config.vision_encoder_name is not None:
# if policy.config.freeze_vision_encoder:
# params_dict: dict[str, torch.Tensor] = {
# k: v for k, v in params_dict.items() if not k.startswith("encoder.")
# }
# else:
# raise NotImplementedError(
# "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
# )
params_dict = move_state_dict_to_device(params_dict, device="cpu")
# Serialize
buf = io.BytesIO()
torch.save(params_dict, buf)
params_bytes = buf.getvalue()
# Push them to the Actor's "SendParameters" method
logging.info("[LEARNER] Publishing parameters to the Actor")
response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes)) # noqa: F841
time.sleep(seconds_between_pushes)
def fill_replay_buffers(
replay_buffer: ReplayBuffer,
offline_replay_buffer: ReplayBuffer,
transition_queue: Queue,
interaction_message_queue: Queue,
logger_queue: Queue,
resume_interaction_step: int | None,
device: str,
):
while True:
while not transition_queue.empty():
transition_list = transition_queue.get() # Increase timeout
for transition in transition_list:
transition = move_transition_to_device(transition, device=device)
with replay_buffer_lock:
replay_buffer.add(**transition)
if transition.get("complementary_info", {}).get("is_intervention"):
with offline_replay_buffer_lock:
offline_replay_buffer.add(**transition)
while not interaction_message_queue.empty():
interaction_message = interaction_message_queue.get()
# If resuming, shift the interaction step by the last checkpointed step so the logging curve stays continuous
if resume_interaction_step is not None:
interaction_message["Interaction step"] += resume_interaction_step
logger_queue.put({
'info': interaction_message,
'step_key': "Interaction step"
})
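For reference, the two message shapes the worker processes put on the logging/metrics queue (see `logger_queue.put` above and `logging_queue.put` in `train_offpolicy_rl` below); a sketch using TypedDicts purely for documentation, the code itself passes plain dicts.

from typing import TypedDict

class MetricsMessage(TypedDict):
    info: dict        # e.g. {"Interaction step": 123} or {"loss_critic": 0.7, "Optimization step": 10}
    step_key: str     # "Interaction step" or "Optimization step"

class CheckpointMessage(TypedDict):
    checkpoint: dict  # {"step": int, "identifier": str}, handled by logger.save_checkpoint in train()

# Example of the metric form, matching fill_replay_buffers above:
# logger_queue.put({"info": {"Interaction step": 123}, "step_key": "Interaction step"})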
def check_nan_in_transition(observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor):
for k in observations:
if torch.isnan(observations[k]).any():
logging.error(f"observations[{k}] contains NaN values")
for k in next_state:
if torch.isnan(next_state[k]).any():
logging.error(f"next_state[{k}] contains NaN values")
if torch.isnan(actions).any():
logging.error("actions contains NaN values")
def train_offpolicy_rl(
cfg,
replay_buffer: ReplayBuffer,
offline_replay_buffer: ReplayBuffer,
batch_size: int,
optimizers: dict[str, torch.optim.Optimizer],
policy: nn.Module,
log_dir: Path,
logging_queue: Queue,
resume_optimization_step: int | None = None,
):
"""
Manages training updates and logs training progress for the learner in an online
reinforcement learning setup. Transitions are transferred from the actor into the replay
buffers by the separate `fill_replay_buffers` process.
This function continuously:
- Ensures training begins only when the replay buffer has a sufficient number of transitions.
- Samples batches from the replay buffer and performs multiple critic updates.
- Periodically updates the actor, critic, and temperature optimizers.
- Logs training statistics, including loss values and optimization frequency.
**NOTE:**
- This function performs multiple responsibilities (training, checkpointing, and metric reporting).
It should ideally be split into smaller functions in the future.
- Due to Python's **Global Interpreter Lock (GIL)**, splitting these tasks across threads
significantly reduces performance, so the whole loop runs inside a single dedicated process.
Args:
cfg: Configuration object containing hyperparameters.
replay_buffer (ReplayBuffer): The primary replay buffer storing online transitions.
offline_replay_buffer (ReplayBuffer): An additional buffer for offline transitions.
batch_size (int): The number of transitions to sample per training step.
optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`).
policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters.
log_dir (Path): The directory to save the log files.
logging_queue (Queue): Queue used to send training metrics and checkpoint requests to the main process.
resume_optimization_step (int | None): When resuming training, start from the last optimization step reached.
"""
# NOTE: This function doesn't have a single responsibility; it should be split into multiple functions
# in the future. The reason we did not is Python's GIL: with threads the performance is divided by
# roughly 200, so a single loop does all the work.
time.time()
logging.info("Starting learner thread")
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
# Wait for stream process to be ready
while True:
with replay_buffer_lock:
logging.info(f"Size of replay buffer: {len(replay_buffer)}")
if len(replay_buffer) < cfg.training.online_step_before_learning:
time.sleep(1)
continue
# logging.info(f"Size of replay buffer: {len(replay_buffer)}")
# logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
image_features, next_image_features = None, None
time_for_one_optimization_step = time.time()
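# UTD (update-to-data) ratio: run cfg.policy.utd_ratio critic updates per optimization step;
# the loop below performs the first utd_ratio - 1 critic-only updates, and the pass after it
# also updates the actor and temperature.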
for _ in range(cfg.policy.utd_ratio - 1):
with replay_buffer_lock:
batch = replay_buffer.sample(batch_size)
if cfg.dataset_repo_id is not None:
with offline_replay_buffer_lock:
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(batch, batch_offline)
actions = batch["action"]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
# Precompute encoder features from the frozen vision encoder if enabled
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
with torch.no_grad():
image_features = (
policy.actor.encoder(observations)
if policy.actor.encoder is not None
else None
)
next_image_features = (
policy.actor.encoder(next_observations)
if policy.actor.encoder is not None
else None
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
image_features=image_features,
next_image_features=next_image_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
optimizers["critic"].step()
with replay_buffer_lock:
batch = replay_buffer.sample(batch_size)
if cfg.dataset_repo_id is not None:
with offline_replay_buffer_lock:
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch["action"]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
# Precompute encoder features from the frozen vision encoder if enabled
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
with torch.no_grad():
image_features = (
policy.actor.encoder(observations)
if policy.actor.encoder is not None
else None
)
next_image_features = (
policy.actor.encoder(next_observations)
if policy.actor.encoder is not None
else None
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
image_features=image_features,
next_image_features=next_image_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
optimizers["critic"].step()
training_infos = {}
training_infos["loss_critic"] = loss_critic.item()
if optimization_step % cfg.training.policy_update_freq == 0:
for _ in range(cfg.training.policy_update_freq):
with policy_lock:
loss_actor = policy.compute_loss_actor(observations=observations, image_features=image_features)
optimizers["actor"].zero_grad()
loss_actor.backward()
optimizers["actor"].step()
training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(observations=observations, image_features=image_features)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
training_infos["loss_temperature"] = loss_temperature.item()
policy.update_target_networks()
if optimization_step % cfg.training.log_freq == 0:
training_infos["Optimization step"] = optimization_step
logging_queue.put({
'info': training_infos,
'step_key': "Optimization step"
})
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
optimization_step += 1
if optimization_step % cfg.training.log_freq == 0:
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
if cfg.training.save_checkpoint and (
optimization_step % cfg.training.save_freq == 0 or optimization_step == cfg.training.online_steps
):
logging.info(f"Checkpoint policy after step {optimization_step}")
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
# needed (choose 6 as a minimum for consistency without being overkill).
_num_digits = max(6, len(str(cfg.training.online_steps)))
step_identifier = f"{optimization_step:0{_num_digits}d}"
logging_queue.put({
'checkpoint': {
'step': optimization_step,
'identifier': step_identifier,
}
})
# TODO: temporarily save the replay buffer here; remove later when running on the robot
# We want to control this with the keyboard inputs
dataset_dir = log_dir / "dataset"
if dataset_dir.exists() and dataset_dir.is_dir():
shutil.rmtree(
dataset_dir,
)
with replay_buffer_lock:
replay_buffer.to_lerobot_dataset(
cfg.dataset_repo_id, fps=cfg.fps, root=dataset_dir
)
logging.info("Resume training")
def make_optimizers_and_scheduler(cfg, policy: nn.Module):
"""
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
This function sets up Adam optimizers for:
- The **actor network**, ensuring that only relevant parameters are optimized.
- The **critic ensemble**, which evaluates the value function.
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
It also initializes a learning rate scheduler, though currently, it is set to `None`.
**NOTE:**
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
Args:
cfg: Configuration object containing hyperparameters.
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
Returns:
Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
A tuple containing:
- `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
- `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
"""
optimizer_actor = torch.optim.Adam(
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
params=policy.actor.parameters_to_optimize,
lr=policy.config.actor_lr,
)
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
"critic": optimizer_critic,
"temperature": optimizer_temperature,
}
return optimizers, lr_scheduler
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
# Initialize multiprocessing with spawn method for better compatibility
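# NOTE: "spawn" (rather than "fork") is required because child processes receive CUDA
# modules/tensors; forking a process after the CUDA context is initialized is unsafe.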
set_start_method('spawn', force=True)
if out_dir is None:
raise NotImplementedError()
if job_name is None:
raise NotImplementedError()
init_logging()
logging.info(pformat(OmegaConf.to_container(cfg)))
# Create our logger instance in the main process
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
cfg = handle_resume_logic(cfg, out_dir)
set_global_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("make_policy")
### Instantiate the policy in both the actor and learner processes.
### To avoid sending a SACPolicy object over the network, we create a policy instance
### on both sides; the learner sends updated parameters every n steps to refresh the actor's copy.
# TODO: At some point we should just need make sac policy
policy: SACPolicy = make_policy(
hydra_cfg=cfg,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: for online training we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)
# compile policy
# policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)
log_training_info(cfg, out_dir, policy)
replay_buffer = initialize_replay_buffer(cfg, logger, device)
batch_size = cfg.training.batch_size
offline_replay_buffer = None
if cfg.dataset_repo_id is not None:
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
logging.info("Convertion to a offline replay buffer")
active_action_dims = [i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask]
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action,
use_shared_memory=True
)
batch_size = batch_size // 2
transition_queue = Queue(maxsize=1_000_000) # Set a maximum size
interaction_message_queue = Queue(maxsize=1_000_000) # Set a maximum size
logging_queue = Queue(maxsize=1_000_000) # Set a maximum size
processes = start_learner_threads(
cfg,
device,
replay_buffer,
offline_replay_buffer,
batch_size,
optimizers,
policy,
logger.log_dir,
transition_queue,
interaction_message_queue,
logging_queue,
resume_optimization_step,
resume_interaction_step,
)
# Consume log messages from the logging_queue in the main process
for p in processes:
p.start()
latest_interaction_step = resume_interaction_step if resume_interaction_step is not None else 0
while True:
try:
message = logging_queue.get(timeout=1)
if 'checkpoint' in message:
ckpt = message['checkpoint']
logger.save_checkpoint(
ckpt['step'],
policy,
optimizers,
scheduler=None,
identifier=ckpt['identifier'],
interaction_step=latest_interaction_step,
)
else:
if 'Interaction step' in message['info']:
latest_interaction_step = message['info']['Interaction step']
logger.log_dict(
message['info'],
mode="train",
custom_step_key=message['step_key']
)
except queue.Empty:
continue
except KeyboardInterrupt:
# Cleanup any remaining processes (if you want to terminate them here)
for p in processes:
if p.is_alive():
p.terminate()
p.join()
break
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def train_cli(cfg: dict):
train(
cfg,
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
)
if __name__ == "__main__":
train_cli()