forked from tangger/lerobot
Several fixes to move the actor_server and learner_server code from the maniskill environment to the real robot environment.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
committed by AdilZouitine
parent 434d1e0614
commit f4f5b26a21

@@ -36,6 +36,8 @@ from termcolor import colored
from torch import nn
from torch.optim.optimizer import Optimizer

from lerobot.common.datasets.factory import make_dataset

# TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger, log_output_dir

@@ -52,6 +54,7 @@ from lerobot.common.utils.utils import (
)
from lerobot.scripts.server.buffer import (
    ReplayBuffer,
    concatenate_batch_transitions,
    move_state_dict_to_device,
    move_transition_to_device,
)

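Note: the helpers imported above live in lerobot/scripts/server/buffer.py and are not shown in this diff. As a rough illustration only (not the repository's implementation), moving a transition between devices could look like the sketch below, assuming a transition is a dict whose values are tensors or nested dicts of tensors.

import torch


def move_transition_to_device_sketch(transition: dict, device: str = "cpu") -> dict:
    """Recursively move every tensor in a transition dict to the target device (illustrative only)."""
    moved = {}
    for key, value in transition.items():
        if isinstance(value, torch.Tensor):
            moved[key] = value.to(device)
        elif isinstance(value, dict):
            moved[key] = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in value.items()}
        else:
            moved[key] = value
    return moved
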
@@ -259,8 +262,15 @@ def learner_push_parameters(
    while True:
        with policy_lock:
            params_dict = policy.actor.state_dict()
            if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
                params_dict = {k: v for k, v in params_dict if not k.startswith("encoder.")}
            if policy.config.vision_encoder_name is not None:
                if policy.config.freeze_vision_encoder:
                    params_dict: dict[str, torch.Tensor] = {
                        k: v for k, v in params_dict.items() if not k.startswith("encoder.")
                    }
                else:
                    raise NotImplementedError(
                        "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
                    )

        params_dict = move_state_dict_to_device(params_dict, device="cpu")
        # Serialize

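The "# Serialize" comment refers to code below this hunk that is not shown. As a hedged sketch of that step (the actual transport and helper names in learner_push_parameters may differ), a CPU state dict can be packed into bytes with torch.save before being pushed to the actor:

import io

import torch


def state_dict_to_bytes(params_dict: dict[str, torch.Tensor]) -> bytes:
    """Serialize an already-CPU state dict into a byte string for network transfer (illustrative only)."""
    buffer = io.BytesIO()
    torch.save(params_dict, buffer)
    return buffer.getvalue()

Filtering out the frozen "encoder." weights before this step keeps the payload small enough to send in one message, which is why the unfrozen case raises NotImplementedError above (the full model would need chunking).
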
@@ -322,6 +332,7 @@ def add_actor_information_and_train(
    # in the future. The reason why we did that is the GIL in Python. It's super slow the performance
    # are divided by 200. So we need to have a single thread that does all the work.
    time.time()
    logging.info("Starting learner thread")
    interaction_message, transition = None, None
    optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
    interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0

@@ -340,16 +351,21 @@ def add_actor_information_and_train(
            # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
            interaction_message["Interaction step"] += interaction_step_shift
            logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
            logging.info(f"Interaction message: {interaction_message}")

        if len(replay_buffer) < cfg.training.online_step_before_learning:
            continue

        # logging.info(f"Size of replay buffer: {len(replay_buffer)}")
        # logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")

        time_for_one_optimization_step = time.time()
        for _ in range(cfg.policy.utd_ratio - 1):
            batch = replay_buffer.sample(batch_size)

            # if cfg.offline_dataset_repo_id is not None:
            #     batch_offline = offline_replay_buffer.sample(batch_size)
            #     batch = concatenate_batch_transitions(batch, batch_offline)
            if cfg.dataset_repo_id is not None:
                batch_offline = offline_replay_buffer.sample(batch_size)
                batch = concatenate_batch_transitions(batch, batch_offline)

            actions = batch["action"]
            rewards = batch["reward"]

@@ -371,11 +387,11 @@ def add_actor_information_and_train(

        batch = replay_buffer.sample(batch_size)

        # if cfg.offline_dataset_repo_id is not None:
        #     batch_offline = offline_replay_buffer.sample(batch_size)
        #     batch = concatenate_batch_transitions(
        #         left_batch_transitions=batch, right_batch_transition=batch_offline
        #     )
        if cfg.dataset_repo_id is not None:
            batch_offline = offline_replay_buffer.sample(batch_size)
            batch = concatenate_batch_transitions(
                left_batch_transitions=batch, right_batch_transition=batch_offline
            )

        actions = batch["action"]
        rewards = batch["reward"]

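concatenate_batch_transitions itself is defined in lerobot/scripts/server/buffer.py and is not part of this diff. A minimal sketch of the idea, assuming batches are dicts of batched tensors (possibly nested for observations), is key-wise concatenation along the batch dimension:

import torch


def concat_batch_transitions_sketch(left_batch_transitions: dict, right_batch_transition: dict) -> dict:
    """Concatenate two transition batches key by key along dim 0 (illustrative only)."""
    merged = {}
    for key, left_value in left_batch_transitions.items():
        right_value = right_batch_transition[key]
        if isinstance(left_value, dict):
            merged[key] = {k: torch.cat([left_value[k], right_value[k]], dim=0) for k in left_value}
        else:
            merged[key] = torch.cat([left_value, right_value], dim=0)
    return merged
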
@@ -423,7 +439,7 @@ def add_actor_information_and_train(
        time_for_one_optimization_step = time.time() - time_for_one_optimization_step
        frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)

        logging.debug(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
        logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")

        logger.log_dict(
            {"Optimization frequency loop [Hz]": frequency_for_one_optimization_step},

@@ -560,14 +576,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
    batch_size = cfg.training.batch_size
    offline_replay_buffer = None

    # if cfg.dataset_repo_id is not None:
    #     logging.info("make_dataset offline buffer")
    #     offline_dataset = make_dataset(cfg)
    #     logging.info("Convertion to a offline replay buffer")
    #     offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
    #         offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
    #     )
    #     batch_size: int = batch_size // 2  # We will sample from both replay buffer
    if cfg.dataset_repo_id is not None:
        logging.info("make_dataset offline buffer")
        offline_dataset = make_dataset(cfg)
        logging.info("Convertion to a offline replay buffer")
        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
            offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
        )
        batch_size: int = batch_size // 2  # We will sample from both replay buffer

    start_learner_threads(
        cfg,

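Halving batch_size here keeps the effective training batch equal to cfg.training.batch_size once the learner starts mixing data: each optimization step samples the halved batch_size from the online buffer and again from the offline buffer, then concatenates them (see the hunks above). A small runnable sketch of that bookkeeping, with hypothetical names standing in for the real config and buffers:

# Illustrative only: how the halved batch size recombines into a full training batch.
def effective_batch_sizes(total_batch_size: int, has_offline_dataset: bool) -> tuple[int, int]:
    """Return (per-buffer sample size, resulting training batch size)."""
    per_buffer = total_batch_size // 2 if has_offline_dataset else total_batch_size
    buffers = 2 if has_offline_dataset else 1
    return per_buffer, per_buffer * buffers


assert effective_batch_sizes(512, has_offline_dataset=True) == (256, 512)
assert effective_batch_sizes(512, has_offline_dataset=False) == (512, 512)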