- Added logging information in wandb around the timings of the policy loop and the optimization loop.

- Optimized the critic design, improving the performance of the learner loop by a factor of 2.
- Cleaned up the code and fixed style issues.

- Completed the config with an actor_learner_config field that contains the host IP and port elements required by the actor-learner servers (see the sketch below).
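
A minimal sketch of what this new section could look like once loaded with OmegaConf; only the port key is confirmed by this change (it is read as cfg.actor_learner_config.port in actor_cli), while the host key name and the values shown are assumptions:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "actor_learner_config": {
                "learner_host": "127.0.0.1",  # host IP of the learner server (assumed key name)
                "port": 50051,  # gRPC port for the actor service (key name confirmed by the change)
            }
        }
    )
    assert cfg.actor_learner_config.port == 51051 - 1000  # -> 50051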

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Michel Aractingi
2025-01-29 15:50:46 +00:00
parent 2ae657f568
commit 8cd44ae163
6 changed files with 461 additions and 313 deletions


@@ -13,117 +13,123 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import functools
from pprint import pformat
import random
from typing import Optional, Sequence, TypedDict, Callable
import pickle
import queue
import time
from concurrent import futures
from statistics import mean, quantiles
import hydra
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
# TODO: Remove the import of maniskill
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env, make_maniskill_env
from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
init_hydra_config,
init_logging,
set_global_seed,
)
# from lerobot.scripts.eval import eval_policy
from threading import Thread
import queue
import grpc
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
import io
import time
import logging
from concurrent import futures
from threading import Thread
from lerobot.scripts.server.buffer import move_state_dict_to_device, move_transition_to_device, Transition
import hydra
import torch
from omegaconf import DictConfig
from torch import nn
import faulthandler
import signal
# TODO: Remove the import of maniskill
from lerobot.common.envs.factory import make_maniskill_env
from lerobot.common.envs.utils import preprocess_maniskill_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.utils.utils import (
get_safe_torch_device,
set_global_seed,
)
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
from lerobot.scripts.server.buffer import Transition, move_state_dict_to_device, move_transition_to_device
logging.basicConfig(level=logging.INFO)
parameters_queue = queue.Queue(maxsize=1)
message_queue = queue.Queue(maxsize=1_000_000)
class ActorInformation:
"""
This helper class is used to differentiate between two types of messages that are placed in the same queue during streaming:
- **Transition Data:** Contains experience tuples (observation, action, reward, next observation) collected during interaction.
- **Interaction Messages:** Encapsulates statistics related to the interaction process.
Attributes:
transition (Optional): Transition data to be sent to the learner.
interaction_message (Optional): Interaction message providing additional statistics for logging.
"""
def __init__(self, transition=None, interaction_message=None):
self.transition = transition
self.interaction_message = interaction_message
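# A minimal usage sketch, assuming the ActorInformation class above: both message kinds share
# the same queue, and the consumer branches on whichever attribute is set. The payloads below
# are placeholders rather than real transitions or interaction statistics.
def _actor_information_demo() -> None:
    demo_queue = queue.Queue(maxsize=10)
    demo_queue.put(ActorInformation(transition=[{"reward": 1.0}]))
    demo_queue.put(ActorInformation(interaction_message={"Episodic reward": 1.0}))
    first = demo_queue.get()
    # The first message carries transition data, so interaction_message stays None.
    assert first.transition is not None and first.interaction_message is None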
# 1) Implement ActorService so the Learner can send parameters to this Actor.
class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
def StreamTransition(self, request, context):
"""
gRPC service for actor-learner communication in reinforcement learning.
This service is responsible for:
1. Streaming batches of transition data and statistical metrics from the actor to the learner.
2. Receiving updated network parameters from the learner.
"""
def StreamTransition(self, request, context): # noqa: N802
"""
Streams data from the actor to the learner.
This function continuously retrieves messages from the queue and processes them based on their type:
- **Transition Data:**
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
- **Interaction Messages:**
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
Yields:
hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
"""
while True:
# logging.info(f"[ACTOR] before message.empty()")
# logging.info(f"[ACTOR] size transition queue {message_queue.qsize()}")
# time.sleep(0.01)
# if message_queue.empty():
# continue
# logging.info(f"[ACTOR] after message.empty()")
start = time.time()
message = message_queue.get(block=True)
# logging.info(f"[ACTOR] Message queue get time {time.time() - start}")
if message.transition is not None:
# transition_to_send_to_learner = move_transition_to_device(message.transition, device="cpu")
transition_to_send_to_learner = [move_transition_to_device(T, device="cpu") for T in message.transition]
# logging.info(f"[ACTOR] Message queue get time {time.time() - start}")
transition_to_send_to_learner = [
move_transition_to_device(T, device="cpu") for T in message.transition
]
# Serialize it
buf = io.BytesIO()
torch.save(transition_to_send_to_learner, buf)
transition_bytes = buf.getvalue()
transition_message = hilserl_pb2.Transition(
transition_bytes=transition_bytes
)
response = hilserl_pb2.ActorInformation(
transition=transition_message
)
logging.info(f"[ACTOR] time to yield transition response {time.time() - start}")
logging.info(f"[ACTOR] size transition queue {message_queue.qsize()}")
transition_message = hilserl_pb2.Transition(transition_bytes=transition_bytes)
response = hilserl_pb2.ActorInformation(transition=transition_message)
elif message.interaction_message is not None:
# Serialize it and send it to the Learner's server
content = hilserl_pb2.InteractionMessage(
interaction_message_bytes=pickle.dumps(message.interaction_message)
)
response = hilserl_pb2.ActorInformation(
interaction_message=content
)
response = hilserl_pb2.ActorInformation(interaction_message=content)
# logging.info(f"[ACTOR] yield response before")
yield response
# logging.info(f"[ACTOR] response yielded after")
def SendParameters(self, request, context):
def SendParameters(self, request, context): # noqa: N802
"""
Learner calls this with updated Parameters -> Actor
Receives updated parameters from the learner and updates the actor.
The learner calls this method to send new model parameters. The received parameters are deserialized
and placed in a queue to be consumed by the actor.
Args:
request (hilserl_pb2.ParameterUpdate): The request containing serialized network parameters.
context (grpc.ServicerContext): The gRPC context.
Returns:
hilserl_pb2.Empty: An empty response to acknowledge receipt.
"""
# logging.info("[ACTOR] Received parameters from Learner.")
buffer = io.BytesIO(request.parameter_bytes)
params = torch.load(buffer)
parameters_queue.put(params)
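# A hedged sketch of the learner-side call that this handler serves. ParameterUpdate and its
# parameter_bytes field are taken from the handler above; the stub name ActorServiceStub follows
# the usual grpc codegen convention, and the default host/port values are assumptions.
def push_actor_parameters(state_dict: dict, host: str = "127.0.0.1", port: int = 50052) -> None:
    # Serialize the (CPU) state dict the same way the actor deserializes it: torch.save -> bytes.
    buf = io.BytesIO()
    torch.save(state_dict, buf)
    with grpc.insecure_channel(f"{host}:{port}") as channel:
        stub = hilserl_pb2_grpc.ActorServiceStub(channel)
        stub.SendParameters(hilserl_pb2.ParameterUpdate(parameter_bytes=buf.getvalue()))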
@@ -132,38 +138,38 @@ class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
def serve_actor_service(port=50052):
"""
Runs a gRPC server so that the Learner can push parameters to the Actor.
Runs a gRPC server that streams data from the actor to the learner.
Through this server the learner can also push parameters to the Actor.
"""
server = grpc.server(futures.ThreadPoolExecutor(max_workers=20),
options=[('grpc.max_send_message_length', -1),
('grpc.max_receive_message_length', -1)])
hilserl_pb2_grpc.add_ActorServiceServicer_to_server(
ActorServiceServicer(), server
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=20),
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
)
server.add_insecure_port(f'[::]:{port}')
hilserl_pb2_grpc.add_ActorServiceServicer_to_server(ActorServiceServicer(), server)
server.add_insecure_port(f"[::]:{port}")
server.start()
logging.info(f"[ACTOR] gRPC server listening on port {port}")
server.wait_for_termination()
def act_with_policy(cfg: DictConfig,
out_dir: str | None = None,
job_name: str | None = None):
if out_dir is None:
raise NotImplementedError()
if job_name is None:
raise NotImplementedError()
def act_with_policy(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
"""
Executes policy interaction within the environment.
This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.
Args:
cfg (DictConfig): Configuration settings for the interaction process.
out_dir (Optional[str]): Directory to store output logs or results. Defaults to None.
job_name (Optional[str]): Name of the job for logging or tracking purposes. Defaults to None.
"""
logging.info("make_env online")
# online_env = make_env(cfg, n_envs=1)
# TODO: Remove the import of maniskill and unify with make_env
online_env = make_maniskill_env(cfg, n_envs=1)
if cfg.training.eval_freq > 0:
logging.info("make_env eval")
# eval_env = make_env(cfg, n_envs=1)
# TODO: Remove the import of maniskill and unify with make_env
eval_env = make_maniskill_env(cfg, n_envs=1)
set_global_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True)
@@ -172,8 +178,7 @@ def act_with_policy(cfg: DictConfig,
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("make_policy")
### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy instance
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
@@ -181,7 +186,7 @@ def act_with_policy(cfg: DictConfig,
policy: SACPolicy = make_policy(
hydra_cfg=cfg,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online traning, we do not need dataset_stats
# Hack: But if we do online training, we do not need dataset_stats
dataset_stats=None,
# TODO: Handle resume training
pretrained_policy_name_or_path=None,
@@ -195,17 +200,22 @@ def act_with_policy(cfg: DictConfig,
# obs = preprocess_observation(obs)
obs = preprocess_maniskill_observation(obs)
obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
### ACTOR ==================
# NOTE: For the moment we will solely handle the case of a single environment
sum_reward_episode = 0
list_transition_to_send_to_learner = []
list_policy_fps = []
for interaction_step in range(cfg.training.online_steps):
# NOTE: At some point we should use a wrapper to handle the observation
# start = time.time()
if interaction_step >= cfg.training.online_step_before_learning:
start = time.perf_counter()
action = policy.select_action(batch=obs)
list_policy_fps.append(1.0 / (time.perf_counter() - start + 1e-9))
if list_policy_fps[-1] < cfg.fps:
logging.warning(
f"[ACTOR] policy frame rate {list_policy_fps[-1]} during interaction step {interaction_step} is below the required control frame rate {cfg.fps}"
)
next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
else:
action = online_env.action_space.sample()
@@ -213,70 +223,88 @@ def act_with_policy(cfg: DictConfig,
# HACK
action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
# logging.info(f"[ACTOR] Time for env step {time.time() - start}")
# HACK: For maniskill
# next_obs = preprocess_observation(next_obs)
next_obs = preprocess_maniskill_observation(next_obs)
next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in obs}
sum_reward_episode += float(reward[0])
# Because we are using a single environment
# we can safely assume that the episode is done
# Because we are using a single environment we can index at zero
if done[0].item() or truncated[0].item():
# TODO: Handle logging for episode information
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
if not parameters_queue.empty():
logging.info("[ACTOR] Load new parameters from Learner.")
# Load new parameters from Learner
logging.debug("[ACTOR] Load new parameters from Learner.")
state_dict = parameters_queue.get()
state_dict = move_state_dict_to_device(state_dict, device=device)
policy.actor.load_state_dict(state_dict)
if len(list_transition_to_send_to_learner) > 0:
logging.info(f"[ACTOR] Sending {len(list_transition_to_send_to_learner)} transitions to Learner.")
logging.debug(
f"[ACTOR] Sending {len(list_transition_to_send_to_learner)} transitions to Learner."
)
message_queue.put(ActorInformation(transition=list_transition_to_send_to_learner))
list_transition_to_send_to_learner = []
stats = {}
if len(list_policy_fps) > 0:
policy_fps = mean(list_policy_fps)
quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90}
list_policy_fps = []
# Send episodic reward to the learner
message_queue.put(ActorInformation(interaction_message={"episodic_reward": sum_reward_episode,"interaction_step": interaction_step}))
message_queue.put(
ActorInformation(
interaction_message={
"Episodic reward": sum_reward_episode,
"Interaction step": interaction_step,
**stats,
}
)
)
sum_reward_episode = 0.0
# ============================
# Prepare transition to send
# ============================
# Label the reward
# TODO (michel-aractingi): Label the reward
# if config.label_reward_on_actor:
# reward = reward_classifier(obs)
list_transition_to_send_to_learner.append(Transition(
# transition_to_send_to_learner = Transition(
state=obs,
action=action,
reward=reward,
next_state=next_obs,
done=done,
complementary_info=None,
)
list_transition_to_send_to_learner.append(
Transition(
state=obs,
action=action,
reward=reward,
next_state=next_obs,
done=done,
complementary_info=None,
)
)
# message_queue.put(ActorInformation(transition=transition_to_send_to_learner))
# assign obs to the next obs and continue the rollout
obs = next_obs
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def actor_cli(cfg: dict):
server_thread = Thread(target=serve_actor_service, args=(50051,), daemon=True)
server_thread.start()
policy_thread = Thread(target=act_with_policy,
daemon=True,
args=(cfg,hydra.core.hydra_config.HydraConfig.get().run.dir, hydra.core.hydra_config.HydraConfig.get().job.name))
policy_thread.start()
policy_thread.join()
server_thread.join()
port = cfg.actor_learner_config.port
server_thread = Thread(target=serve_actor_service, args=(port,), daemon=True)
server_thread.start()
policy_thread = Thread(
target=act_with_policy,
daemon=True,
args=(
cfg,
hydra.core.hydra_config.HydraConfig.get().run.dir,
hydra.core.hydra_config.HydraConfig.get().job.name,
),
)
policy_thread.start()
policy_thread.join()
server_thread.join()
if __name__ == "__main__":
with open("traceback.log", "w") as f:
faulthandler.register(signal.SIGUSR1, file=f)
actor_cli()
actor_cli()