Several fixes to move the actor_server and learner_server code from the maniskill environment to the real robot environment.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Authored by Michel Aractingi on 2025-02-10 16:03:39 +01:00; committed by AdilZouitine
parent 434d1e0614
commit f4f5b26a21
10 changed files with 597 additions and 318 deletions


@@ -36,6 +36,8 @@ from termcolor import colored
from torch import nn
from torch.optim.optimizer import Optimizer
from lerobot.common.datasets.factory import make_dataset
# TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger, log_output_dir
@@ -52,6 +54,7 @@ from lerobot.common.utils.utils import (
)
from lerobot.scripts.server.buffer import (
ReplayBuffer,
concatenate_batch_transitions,
move_state_dict_to_device,
move_transition_to_device,
)
@@ -259,8 +262,15 @@ def learner_push_parameters(
while True:
with policy_lock:
params_dict = policy.actor.state_dict()
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
params_dict = {k: v for k, v in params_dict if not k.startswith("encoder.")}
if policy.config.vision_encoder_name is not None:
if policy.config.freeze_vision_encoder:
params_dict: dict[str, torch.Tensor] = {
k: v for k, v in params_dict.items() if not k.startswith("encoder.")
}
else:
raise NotImplementedError(
"Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
)
params_dict = move_state_dict_to_device(params_dict, device="cpu")
# Serialize
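
The new branch only strips the "encoder." entries when the vision encoder is frozen, since the actor keeps its own identical copy of those frozen weights. A minimal sketch of that filtering followed by the CPU move and serialization (the helper names and the torch.save-based serialization are illustrative assumptions, not the code in this commit):

import io
import torch

def filter_frozen_encoder(state_dict: dict[str, torch.Tensor], freeze_vision_encoder: bool) -> dict[str, torch.Tensor]:
    # With a frozen encoder, the actor already holds identical encoder weights,
    # so only the non-encoder parameters need to be pushed over the network.
    if freeze_vision_encoder:
        return {k: v for k, v in state_dict.items() if not k.startswith("encoder.")}
    return state_dict

def serialize_for_push(state_dict: dict[str, torch.Tensor]) -> bytes:
    # Move everything to CPU first so the payload can cross the process or
    # network boundary without pinning GPU memory.
    cpu_state = {k: v.detach().cpu() for k, v in state_dict.items()}
    buffer = io.BytesIO()
    torch.save(cpu_state, buffer)
    return buffer.getvalue()
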
@@ -322,6 +332,7 @@ def add_actor_information_and_train(
# in the future. The reason we did this is Python's GIL: with multiple threads it is extremely slow, with performance
# divided by roughly 200. So we need a single thread that does all the work.
time.time()
logging.info("Starting learner thread")
interaction_message, transition = None, None
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
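
As a rough illustration of the single-thread design described in the comment above, one learner thread can drain the incoming queues and run the optimization itself; the queue names and loop skeleton below are assumptions for illustration, not the exact code in this file:

import queue
import threading

transition_queue: queue.Queue = queue.Queue()
interaction_message_queue: queue.Queue = queue.Queue()

def learner_loop():
    # A single thread receives transitions, handles actor-side messages and runs
    # the gradient steps, avoiding GIL contention between separate worker threads.
    while True:
        while not transition_queue.empty():
            transition = transition_queue.get()  # push into the replay buffer here
        while not interaction_message_queue.empty():
            message = interaction_message_queue.get()  # log interaction metrics here
        # sample from the replay buffer and take optimization steps here

threading.Thread(target=learner_loop, daemon=True).start()
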
@@ -340,16 +351,21 @@ def add_actor_information_and_train(
# If cfg.resume, shift the interaction step by the last checkpointed step so that logging is not broken
interaction_message["Interaction step"] += interaction_step_shift
logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
logging.info(f"Interaction message: {interaction_message}")
if len(replay_buffer) < cfg.training.online_step_before_learning:
continue
# logging.info(f"Size of replay buffer: {len(replay_buffer)}")
# logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
time_for_one_optimization_step = time.time()
for _ in range(cfg.policy.utd_ratio - 1):
batch = replay_buffer.sample(batch_size)
# if cfg.offline_dataset_repo_id is not None:
# batch_offline = offline_replay_buffer.sample(batch_size)
# batch = concatenate_batch_transitions(batch, batch_offline)
if cfg.dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(batch, batch_offline)
actions = batch["action"]
rewards = batch["reward"]
@@ -371,11 +387,11 @@ def add_actor_information_and_train(
batch = replay_buffer.sample(batch_size)
# if cfg.offline_dataset_repo_id is not None:
# batch_offline = offline_replay_buffer.sample(batch_size)
# batch = concatenate_batch_transitions(
# left_batch_transitions=batch, right_batch_transition=batch_offline
# )
if cfg.dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch["action"]
rewards = batch["reward"]
@@ -423,7 +439,7 @@ def add_actor_information_and_train(
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
logging.debug(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
logger.log_dict(
{"Optimization frequency loop [Hz]": frequency_for_one_optimization_step},
@@ -560,14 +576,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
batch_size = cfg.training.batch_size
offline_replay_buffer = None
# if cfg.dataset_repo_id is not None:
# logging.info("make_dataset offline buffer")
# offline_dataset = make_dataset(cfg)
# logging.info("Convertion to a offline replay buffer")
# offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
# offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
# )
# batch_size: int = batch_size // 2 # We will sample from both replay buffer
if cfg.dataset_repo_id is not None:
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
logging.info("Convertion to a offline replay buffer")
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
)
batch_size: int = batch_size // 2  # We will sample from both replay buffers
start_learner_threads(
cfg,
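
For reference, the offline-buffer wiring enabled in this last hunk can be read as the following helper; the function wrapper is illustrative, while make_dataset, ReplayBuffer.from_lerobot_dataset and the config fields are taken from the diff itself:

from omegaconf import DictConfig

from lerobot.common.datasets.factory import make_dataset
from lerobot.scripts.server.buffer import ReplayBuffer

def build_offline_buffer(cfg: DictConfig, device: str, batch_size: int):
    # Returns (offline_replay_buffer, per-buffer batch size). When an offline
    # dataset is configured, the batch size is halved so each optimization step
    # samples half of its transitions from the online buffer and half from the
    # offline buffer before they are concatenated.
    offline_replay_buffer = None
    if cfg.dataset_repo_id is not None:
        offline_dataset = make_dataset(cfg)
        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
            offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
        )
        batch_size = batch_size // 2
    return offline_replay_buffer, batch_size
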