From a8da4a347e2660e897e5c454791f9eb337e8172a Mon Sep 17 00:00:00 2001
From: AdilZouitine
Date: Thu, 24 Apr 2025 17:22:54 +0200
Subject: [PATCH] Clean the code

---
 .../robot_devices/robots/manipulator.py  | 16 -------
 lerobot/scripts/server/actor_server.py   | 46 ++++++++++++++-----
 lerobot/scripts/server/buffer.py         |  7 ++-
 lerobot/scripts/server/learner_server.py | 26 ++++++-----
 4 files changed, 56 insertions(+), 39 deletions(-)

diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py
index e14a5264..9aec3c4a 100644
--- a/lerobot/common/robot_devices/robots/manipulator.py
+++ b/lerobot/common/robot_devices/robots/manipulator.py
@@ -474,14 +474,6 @@ class ManipulatorRobot:
             before_fwrite_t = time.perf_counter()
             goal_pos = leader_pos[name]
 
-            # If specified, clip the goal positions within predefined bounds specified in the config of the robot
-            # if self.config.joint_position_relative_bounds is not None:
-            #     goal_pos = torch.clamp(
-            #         goal_pos,
-            #         self.config.joint_position_relative_bounds["min"],
-            #         self.config.joint_position_relative_bounds["max"],
-            #     )
-
             # Cap goal position when too far away from present position.
             # Slower fps expected due to reading from the follower.
             if self.config.max_relative_target is not None:
@@ -603,14 +595,6 @@ class ManipulatorRobot:
             goal_pos = action[from_idx:to_idx]
             from_idx = to_idx
 
-            # If specified, clip the goal positions within predefined bounds specified in the config of the robot
-            # if self.config.joint_position_relative_bounds is not None:
-            #     goal_pos = torch.clamp(
-            #         goal_pos,
-            #         self.config.joint_position_relative_bounds["min"],
-            #         self.config.joint_position_relative_bounds["max"],
-            #     )
-
             # Cap goal position when too far away from present position.
             # Slower fps expected due to reading from the follower.
             if self.config.max_relative_target is not None:
diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py
index af7bf11f..3bf30b26 100644
--- a/lerobot/scripts/server/actor_server.py
+++ b/lerobot/scripts/server/actor_server.py
@@ -356,10 +356,19 @@ def act_with_policy(
 
 
 def establish_learner_connection(
-    stub,
-    shutdown_event: any,  # Event,
-    attempts=30,
+    stub: hilserl_pb2_grpc.LearnerServiceStub,
+    shutdown_event: Event,  # type: ignore
+    attempts: int = 30,
 ):
+    """Establish a connection with the learner.
+
+    Args:
+        stub (hilserl_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
+        shutdown_event (Event): The event used to abort connection attempts when a shutdown is requested.
+        attempts (int): The number of attempts to establish the connection.
+    Returns:
+        bool: True if the connection is established, False otherwise.
+    """
     for _ in range(attempts):
         if shutdown_event.is_set():
             logging.info("[ACTOR] Shutting down establish_learner_connection")
@@ -378,7 +387,8 @@ def establish_learner_connection(
 
 @lru_cache(maxsize=1)
 def learner_service_client(
-    host="127.0.0.1", port=50051
+    host: str = "127.0.0.1",
+    port: int = 50051,
 ) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
     import json
 
@@ -426,12 +436,18 @@ def learner_service_client(
 def receive_policy(
     cfg: TrainPipelineConfig,
     parameters_queue: Queue,
-    shutdown_event: any,  # Event,
+    shutdown_event: Event,  # type: ignore
     learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
     grpc_channel: grpc.Channel | None = None,
 ):
-    logging.info("[ACTOR] Start receiving parameters from the Learner")
-
+    """Receive parameters from the learner.
+
+    Args:
+        cfg (TrainPipelineConfig): The configuration for the actor.
+        parameters_queue (Queue): The queue to receive the parameters.
+        shutdown_event (Event): The event to check if the process should shut down.
+    """
+    logging.info("[ACTOR] Start receiving parameters from the Learner")
     if not use_threads(cfg):
         # Create a process-specific log file
         log_dir = os.path.join(cfg.output_dir, "logs")
@@ -481,7 +497,7 @@ def send_transitions(
 
     This function continuously retrieves messages from the queue and processes:
 
-    - **Transition Data:**
+    - Transition Data:
       - A batch of transitions (observation, action, reward, next observation) is collected.
       - Transitions are moved to the CPU and serialized using PyTorch.
       - The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
@@ -522,7 +538,7 @@ def send_interactions(
 def send_interactions(
     cfg: TrainPipelineConfig,
     interactions_queue: Queue,
-    shutdown_event: any,  # Event,
+    shutdown_event: Event,  # type: ignore
     learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
     grpc_channel: grpc.Channel | None = None,
 ) -> hilserl_pb2.Empty:
@@ -531,7 +547,7 @@ def send_interactions(
 
     This function continuously retrieves messages from the queue and processes:
 
-    - **Interaction Messages:**
+    - Interaction Messages:
       - Contains useful statistics about episodic rewards and policy timings.
      - The message is serialized using `pickle` and sent to the learner.
     """
@@ -568,7 +584,7 @@ def send_interactions(
         logging.info("[ACTOR] Interactions process stopped")
 
 
-def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty:
+def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty:  # type: ignore
     while not shutdown_event.is_set():
         try:
             message = transitions_queue.get(block=True, timeout=5)
@@ -584,7 +600,7 @@ def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilse
 
 
 def interactions_stream(
-    shutdown_event: any,  # Event,
+    shutdown_event: Event,  # type: ignore
     interactions_queue: Queue,
 ) -> hilserl_pb2.Empty:
     while not shutdown_event.is_set():
@@ -643,6 +659,14 @@ def push_transitions_to_transport_queue(transitions: list, transitions_queue):
 
 
 def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
+    """Get the frequency (FPS) statistics of the policy.
+
+    Args:
+        list_policy_time (list[float]): Time taken by each policy inference step, in seconds.
+
+    Returns:
+        dict[str, float]: The frequency statistics of the policy.
+    """
     stats = {}
     list_policy_fps = [1.0 / t for t in list_policy_time]
     if len(list_policy_fps) > 1:
diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index 7e216ed6..a54016eb 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import functools
 import io
 import pickle  # nosec B403: Safe usage of pickle
@@ -194,6 +195,10 @@ class ReplayBuffer:
         optimize_memory: bool = False,
     ):
         """
+        Replay buffer for storing transitions.
+        Tensors are allocated on the specified device when the first transition is added.
+        NOTE: If you encounter memory issues, you can try the `optimize_memory` flag to save memory,
+        or use the `storage_device` flag to store the buffer on a different device.
         Args:
             capacity (int): Maximum number of transitions to store in the buffer.
             device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
@@ -368,7 +373,7 @@ class ReplayBuffer:
                 all_images.append(batch_state[key])
                 all_images.append(batch_next_state[key])
 
-            # Batch all images and apply augmentation once
+            # Optimization: Batch all images and apply augmentation once
             all_images_tensor = torch.cat(all_images, dim=0)
             augmented_images = self.image_augmentation_function(all_images_tensor)
 
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 3f86b2b3..b9247fa8 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -256,7 +256,8 @@ def add_actor_information_and_train(
         interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
         parameters_queue (Queue): Queue for sending policy parameters to the actor.
     """
-    # Extract all configuration variables at the beginning
+    # Extract all configuration variables at the beginning; this improves speed
+    # by about 7%
     device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
     storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
     clip_grad_norm_value = cfg.policy.grad_clip_norm
@@ -283,11 +284,11 @@ def add_actor_information_and_train(
 
     policy: SACPolicy = make_policy(
         cfg=cfg.policy,
-        # ds_meta=cfg.dataset,
         env_cfg=cfg.env,
     )
 
     assert isinstance(policy, nn.Module)
+
     policy.train()
 
     push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
@@ -295,6 +296,8 @@ def add_actor_information_and_train(
     last_time_policy_pushed = time.time()
 
     optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
+
+    # If we are resuming, we need to load the training state
     resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
 
     log_training_info(cfg=cfg, policy=policy)
@@ -330,6 +333,7 @@ def add_actor_information_and_train(
     # Initialize iterators
     online_iterator = None
     offline_iterator = None
+
     # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
     while True:
         # Exit the training loop if shutdown is requested
@@ -337,7 +341,7 @@ def add_actor_information_and_train(
             logging.info("[LEARNER] Shutdown signal received. Exiting...")
             break
 
-        # Process all available transitions
+        # Process all available transitions to the replay buffer, sent by the actor server
         logging.debug("[LEARNER] Waiting for transitions")
         process_transitions(
             transition_queue=transition_queue,
@@ -349,7 +353,7 @@ def add_actor_information_and_train(
         )
         logging.debug("[LEARNER] Received transitions")
 
-        # Process all available interaction messages
+        # Process all available interaction messages sent by the actor server
        logging.debug("[LEARNER] Waiting for interactions")
         interaction_message = process_interaction_messages(
             interaction_message_queue=interaction_message_queue,
@@ -359,7 +363,7 @@ def add_actor_information_and_train(
         )
         logging.debug("[LEARNER] Received interactions")
 
-        # Wait until the replay buffer has enough samples
+        # Wait until the replay buffer has enough samples to start training
         if len(replay_buffer) < online_step_before_learning:
             continue
 
@@ -410,7 +414,7 @@ def add_actor_information_and_train(
                 "complementary_info": batch["complementary_info"],
             }
 
-            # Use the forward method for critic loss (includes both main critic and discrete critic)
+            # Use the forward method for critic loss
             critic_output = policy.forward(forward_batch, model="critic")
 
             # Main critic optimization
@@ -433,7 +437,7 @@ def add_actor_information_and_train(
                 )
                 optimizers["discrete_critic"].step()
 
-            # Update target networks
+            # Update target networks (main and discrete)
             policy.update_target_networks()
 
             # Sample for the last update in the UTD ratio
@@ -468,10 +472,8 @@ def add_actor_information_and_train(
             "next_observation_feature": next_observation_features,
         }
 
-        # Use the forward method for critic loss (includes both main critic and discrete critic)
         critic_output = policy.forward(forward_batch, model="critic")
 
-        # Main critic optimization
         loss_critic = critic_output["loss_critic"]
         optimizers["critic"].zero_grad()
         loss_critic.backward()
@@ -541,7 +543,7 @@ def add_actor_information_and_train(
             push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
             last_time_policy_pushed = time.time()
 
-        # Update target networks
+        # Update target networks (main and discrete)
         policy.update_target_networks()
 
         # Log training metrics at specified intervals
@@ -601,6 +603,8 @@ def start_learner_server(
 ):
     """
     Start the learner server for training.
+    It will receive transitions and interaction messages from the actor server,
+    and send policy parameters to the actor server.
 
     Args:
         parameters_queue: Queue for sending policy parameters to the actor
@@ -756,7 +760,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
 
     It also initializes a learning rate scheduler, though currently, it is set to `None`.
 
-    **NOTE:**
+    NOTE:
     - If the encoder is shared, its parameters are excluded from the actor's optimization process.
     - The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
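
For context on the last hunk, the optimizer layout described in the `make_optimizers_and_scheduler` docstring can be sketched roughly as below. This is a minimal illustration only, not the patch's actual implementation; the attribute names (`policy.actor`, `policy.critic_ensemble`, `policy.log_alpha`, `shared_encoder`) and the learning rate are assumptions, not the exact lerobot API:

# Hypothetical sketch of the optimizer setup described in the docstring above.
# The policy attribute names used here (actor, critic_ensemble, log_alpha,
# shared_encoder) are illustrative assumptions, not the exact lerobot API.
import torch
from torch import nn


def make_optimizers_sketch(policy: nn.Module, lr: float = 3e-4):
    # When the encoder is shared with the critic, exclude encoder parameters
    # from the actor optimizer so that only the critic updates the encoder.
    actor_params = [
        param
        for name, param in policy.actor.named_parameters()
        if not (getattr(policy, "shared_encoder", False) and name.startswith("encoder"))
    ]

    optimizers = {
        "actor": torch.optim.Adam(actor_params, lr=lr),
        "critic": torch.optim.Adam(policy.critic_ensemble.parameters(), lr=lr),
        # log_alpha is a standalone tensor, so it is wrapped in a list to form
        # a valid parameter group for the optimizer.
        "temperature": torch.optim.Adam([policy.log_alpha], lr=lr),
    }
    lr_scheduler = None  # the docstring notes the scheduler is currently None
    return optimizers, lr_scheduler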