forked from tangger/lerobot
Added server directory in lerobot/scripts that contains scripts and the protobuf message types to split training into two processes, acting and learning. The actor rollouts the policy and collects interaction data while the learner recieves the data, trains the policy and sends the updated parameters to the actor. The two scripts are ran simultaneously
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
committed by
AdilZouitine
parent
83dc00683c
commit
ef64ba91d9
394
lerobot/scripts/server/learner_server.py
Normal file
394
lerobot/scripts/server/learner_server.py
Normal file
@@ -0,0 +1,394 @@
|
||||
import grpc
|
||||
from concurrent import futures
|
||||
import functools
|
||||
import logging
|
||||
import queue
|
||||
import pickle
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import io
|
||||
import time
|
||||
|
||||
from pprint import pformat
|
||||
import random
|
||||
from typing import Optional, Sequence, TypedDict, Callable
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from threading import Thread, Lock
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# TODO: Remove the import of maniskill
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.scripts.server.buffer import ReplayBuffer, move_transition_to_device, concatenate_batch_transitions, move_state_dict_to_device, Transition
|
||||
|
||||
# Import generated stubs
|
||||
import hilserl_pb2
|
||||
import hilserl_pb2_grpc
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
|
||||
# TODO: Implement it in cleaner way maybe
|
||||
transition_queue = queue.Queue()
|
||||
interaction_message_queue = queue.Queue()
|
||||
|
||||
|
||||
# 1) Implement the LearnerService so the Actor can send transitions here.
|
||||
class LearnerServiceServicer(hilserl_pb2_grpc.LearnerServiceServicer):
|
||||
# def SendTransition(self, request, context):
|
||||
# """
|
||||
# Actor calls this method to push a Transition -> Learner.
|
||||
# """
|
||||
# buffer = io.BytesIO(request.transition_bytes)
|
||||
# transition = torch.load(buffer)
|
||||
# transition_queue.put(transition)
|
||||
# return hilserl_pb2.Empty()
|
||||
def SendInteractionMessage(self, request, context):
|
||||
"""
|
||||
Actor calls this method to push a Transition -> Learner.
|
||||
"""
|
||||
content = pickle.loads(request.interaction_message_bytes)
|
||||
interaction_message_queue.put(content)
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
|
||||
|
||||
def stream_transitions_from_actor(port=50051):
|
||||
"""
|
||||
Runs a gRPC server listening for transitions from the Actor.
|
||||
"""
|
||||
time.sleep(10)
|
||||
channel = grpc.insecure_channel(f'127.0.0.1:{port}',
|
||||
options=[('grpc.max_send_message_length', -1),
|
||||
('grpc.max_receive_message_length', -1)])
|
||||
stub = hilserl_pb2_grpc.ActorServiceStub(channel)
|
||||
for response in stub.StreamTransition(hilserl_pb2.Empty()):
|
||||
if response.HasField('transition'):
|
||||
buffer = io.BytesIO(response.transition.transition_bytes)
|
||||
transition = torch.load(buffer)
|
||||
transition_queue.put(transition)
|
||||
if response.HasField('interaction_message'):
|
||||
content = pickle.loads(response.interaction_message.interaction_message_bytes)
|
||||
interaction_message_queue.put(content)
|
||||
# NOTE: Cool down the CPU, if you comment this line you will make a huge bottleneck
|
||||
time.sleep(0.001)
|
||||
|
||||
def learner_push_parameters(
|
||||
policy: nn.Module, policy_lock: Lock, actor_host="127.0.0.1", actor_port=50052, seconds_between_pushes=5
|
||||
):
|
||||
"""
|
||||
As a client, connect to the Actor's gRPC server (ActorService)
|
||||
and periodically push new parameters.
|
||||
"""
|
||||
time.sleep(10)
|
||||
# The Actor's server is presumably listening on a different port, e.g. 50052
|
||||
channel = grpc.insecure_channel(f"{actor_host}:{actor_port}",
|
||||
options=[('grpc.max_send_message_length', -1),
|
||||
('grpc.max_receive_message_length', -1)])
|
||||
actor_stub = hilserl_pb2_grpc.ActorServiceStub(channel)
|
||||
|
||||
while True:
|
||||
with policy_lock:
|
||||
params_dict = policy.actor.state_dict()
|
||||
params_dict = move_state_dict_to_device(params_dict, device="cpu")
|
||||
# Serialize
|
||||
buf = io.BytesIO()
|
||||
torch.save(params_dict, buf)
|
||||
params_bytes = buf.getvalue()
|
||||
|
||||
# Push them to the Actor’s "SendParameters" method
|
||||
response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes))
|
||||
time.sleep(seconds_between_pushes)
|
||||
|
||||
|
||||
# Checked
|
||||
def add_actor_information(
|
||||
cfg,
|
||||
device,
|
||||
replay_buffer: ReplayBuffer,
|
||||
offline_replay_buffer: ReplayBuffer,
|
||||
batch_size: int,
|
||||
optimizers,
|
||||
policy,
|
||||
policy_lock: Lock,
|
||||
buffer_lock: Lock,
|
||||
offline_buffer_lock: Lock,
|
||||
logger_lock: Lock,
|
||||
logger: Logger,
|
||||
):
|
||||
"""
|
||||
In a real application, you might run your training loop here,
|
||||
reading from the transition queue and doing gradient updates.
|
||||
"""
|
||||
# NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
|
||||
# in the future. The reason why we did that is the GIL in Python. It's super slow the performance
|
||||
# are divided by 200. So we need to have a single thread that does all the work.
|
||||
start = time.time()
|
||||
optimization_step = 0
|
||||
|
||||
while True:
|
||||
time_for_adding_transitions = time.time()
|
||||
while not transition_queue.empty():
|
||||
|
||||
transition_list = transition_queue.get()
|
||||
for transition in transition_list:
|
||||
transition = move_transition_to_device(transition, device=device)
|
||||
replay_buffer.add(**transition)
|
||||
logging.info(f"[LEARNER] size of replay buffer: {len(replay_buffer)}")
|
||||
logging.info(f"[LEARNER] size of transition queues: {transition_queue.qsize()}")
|
||||
|
||||
|
||||
while not interaction_message_queue.empty():
|
||||
interaction_message = interaction_message_queue.get()
|
||||
logger.log_dict(interaction_message,mode="train",custom_step_key="interaction_step")
|
||||
logging.info(f"[LEARNER] size of interaction message queue: {interaction_message_queue.qsize()}")
|
||||
|
||||
# if len(replay_buffer.memory) < cfg.training.online_step_before_learning:
|
||||
# continue
|
||||
|
||||
# for _ in range(cfg.policy.utd_ratio - 1):
|
||||
|
||||
# batch = replay_buffer.sample(batch_size)
|
||||
# if cfg.dataset_repo_id is not None:
|
||||
# batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
# batch = concatenate_batch_transitions(batch, batch_offline)
|
||||
|
||||
# actions = batch["action"]
|
||||
# rewards = batch["reward"]
|
||||
# observations = batch["state"]
|
||||
# next_observations = batch["next_state"]
|
||||
# done = batch["done"]
|
||||
|
||||
# with policy_lock:
|
||||
# loss_critic = policy.compute_loss_critic(
|
||||
# observations=observations,
|
||||
# actions=actions,
|
||||
# rewards=rewards,
|
||||
# next_observations=next_observations,
|
||||
# done=done,
|
||||
# )
|
||||
# optimizers["critic"].zero_grad()
|
||||
# loss_critic.backward()
|
||||
# optimizers["critic"].step()
|
||||
|
||||
# batch = replay_buffer.sample(batch_size)
|
||||
|
||||
# if cfg.dataset_repo_id is not None:
|
||||
# batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
# batch = concatenate_batch_transitions(
|
||||
# left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
# )
|
||||
|
||||
# actions = batch["action"]
|
||||
# rewards = batch["reward"]
|
||||
# observations = batch["state"]
|
||||
# next_observations = batch["next_state"]
|
||||
# done = batch["done"]
|
||||
|
||||
# with policy_lock:
|
||||
# loss_critic = policy.compute_loss_critic(
|
||||
# observations=observations,
|
||||
# actions=actions,
|
||||
# rewards=rewards,
|
||||
# next_observations=next_observations,
|
||||
# done=done,
|
||||
# )
|
||||
# optimizers["critic"].zero_grad()
|
||||
# loss_critic.backward()
|
||||
# optimizers["critic"].step()
|
||||
|
||||
# training_infos = {}
|
||||
# training_infos["loss_critic"] = loss_critic.item()
|
||||
|
||||
# if optimization_step % cfg.training.policy_update_freq == 0:
|
||||
# for _ in range(cfg.training.policy_update_freq):
|
||||
# with policy_lock:
|
||||
# loss_actor = policy.compute_loss_actor(observations=observations)
|
||||
|
||||
# optimizers["actor"].zero_grad()
|
||||
# loss_actor.backward()
|
||||
# optimizers["actor"].step()
|
||||
|
||||
# training_infos["loss_actor"] = loss_actor.item()
|
||||
|
||||
# loss_temperature = policy.compute_loss_temperature(observations=observations)
|
||||
# optimizers["temperature"].zero_grad()
|
||||
# loss_temperature.backward()
|
||||
# optimizers["temperature"].step()
|
||||
|
||||
# training_infos["loss_temperature"] = loss_temperature.item()
|
||||
|
||||
# if optimization_step % cfg.training.log_freq == 0:
|
||||
# logger.log_dict(training_infos, step=optimization_step, mode="train")
|
||||
|
||||
# policy.update_target_networks()
|
||||
# optimization_step += 1
|
||||
# time_for_one_optimization_step = time.time() - time_for_one_optimization_step
|
||||
|
||||
# logger.log_dict({"[LEARNER] Time optimization step":time_for_one_optimization_step}, step=optimization_step, mode="train")
|
||||
# time_for_one_optimization_step = time.time()
|
||||
|
||||
|
||||
def make_optimizers_and_scheduler(cfg, policy):
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
|
||||
params=policy.actor.parameters_to_optimize,
|
||||
lr=policy.config.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
|
||||
)
|
||||
# We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
|
||||
lr_scheduler = None
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
return optimizers, lr_scheduler
|
||||
|
||||
|
||||
|
||||
|
||||
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
if job_name is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
init_logging()
|
||||
logging.info(pformat(OmegaConf.to_container(cfg)))
|
||||
|
||||
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
|
||||
logger_lock = Lock()
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("make_policy")
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy intance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
# TODO: At some point we should just need make sac policy
|
||||
policy_lock = Lock()
|
||||
with logger_lock:
|
||||
policy: SACPolicy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
# Hack: But if we do online traning, we do not need dataset_stats
|
||||
dataset_stats=None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
||||
device=device,
|
||||
)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
|
||||
|
||||
# TODO: Handle resume
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
log_output_dir(out_dir)
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.training.online_steps=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
buffer_lock = Lock()
|
||||
replay_buffer = ReplayBuffer(
|
||||
capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys()
|
||||
)
|
||||
|
||||
batch_size = cfg.training.batch_size
|
||||
offline_buffer_lock = None
|
||||
offline_replay_buffer = None
|
||||
if cfg.dataset_repo_id is not None:
|
||||
logging.info("make_dataset offline buffer")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
logging.info("Convertion to a offline replay buffer")
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
|
||||
)
|
||||
offline_buffer_lock = Lock()
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
|
||||
server_thread = Thread(target=stream_transitions_from_actor, args=(50051,), daemon=True)
|
||||
server_thread.start()
|
||||
|
||||
|
||||
# Start a background thread to process transitions from the queue
|
||||
transition_thread = Thread(
|
||||
target=add_actor_information,
|
||||
daemon=True,
|
||||
args=(cfg,
|
||||
device,
|
||||
replay_buffer,
|
||||
offline_replay_buffer,
|
||||
batch_size,
|
||||
optimizers,
|
||||
policy,
|
||||
policy_lock,
|
||||
buffer_lock,
|
||||
offline_buffer_lock,
|
||||
logger_lock,
|
||||
logger),
|
||||
)
|
||||
transition_thread.start()
|
||||
|
||||
# param_push_thread = Thread(
|
||||
# target=learner_push_parameters,
|
||||
# args=(policy, policy_lock, "127.0.0.1", 50052, 15),
|
||||
# # args=("127.0.0.1", 50052),
|
||||
# daemon=True,
|
||||
# )
|
||||
# param_push_thread.start()
|
||||
|
||||
# interaction_thread = Thread(
|
||||
# target=add_message_interaction_to_wandb,
|
||||
# daemon=True,
|
||||
# args=(cfg, logger, logger_lock),
|
||||
# )
|
||||
# interaction_thread.start()
|
||||
|
||||
transition_thread.join()
|
||||
# param_push_thread.join()
|
||||
server_thread.join()
|
||||
# interaction_thread.join()
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
|
||||
def train_cli(cfg: dict):
|
||||
train(
|
||||
cfg,
|
||||
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
|
||||
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_cli()
|
||||
Reference in New Issue
Block a user