Clean the code

AdilZouitine
2025-04-24 17:22:54 +02:00
parent b8c2b0bb93
commit a8da4a347e
4 changed files with 56 additions and 39 deletions

View File

@@ -474,14 +474,6 @@ class ManipulatorRobot:
 before_fwrite_t = time.perf_counter()
 goal_pos = leader_pos[name]

-# If specified, clip the goal positions within predefined bounds specified in the config of the robot
-# if self.config.joint_position_relative_bounds is not None:
-#     goal_pos = torch.clamp(
-#         goal_pos,
-#         self.config.joint_position_relative_bounds["min"],
-#         self.config.joint_position_relative_bounds["max"],
-#     )
-
 # Cap goal position when too far away from present position.
 # Slower fps expected due to reading from the follower.
 if self.config.max_relative_target is not None:
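
For reference, the relative-target cap that remains after this cleanup can be sketched as follows; the function name, its arguments, and the tensor types are illustrative assumptions, not the exact ManipulatorRobot code:

import torch

def cap_relative_target(goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float) -> torch.Tensor:
    # Limit how far the goal may move away from the present position in one write.
    diff = torch.clamp(goal_pos - present_pos, -max_relative_target, max_relative_target)
    return present_pos + diff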
@@ -603,14 +595,6 @@ class ManipulatorRobot:
 goal_pos = action[from_idx:to_idx]
 from_idx = to_idx

-# If specified, clip the goal positions within predefined bounds specified in the config of the robot
-# if self.config.joint_position_relative_bounds is not None:
-#     goal_pos = torch.clamp(
-#         goal_pos,
-#         self.config.joint_position_relative_bounds["min"],
-#         self.config.joint_position_relative_bounds["max"],
-#     )
-
 # Cap goal position when too far away from present position.
 # Slower fps expected due to reading from the follower.
 if self.config.max_relative_target is not None:

View File

@@ -356,10 +356,19 @@ def act_with_policy(
 def establish_learner_connection(
-    stub,
-    shutdown_event: any,  # Event,
-    attempts=30,
+    stub: hilserl_pb2_grpc.LearnerServiceStub,
+    shutdown_event: Event,  # type: ignore
+    attempts: int = 30,
 ):
+    """Establish a connection with the learner.
+
+    Args:
+        stub (hilserl_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
+        shutdown_event (Event): Event signaling that the connection attempt should be abandoned.
+        attempts (int): The number of attempts to establish the connection.
+
+    Returns:
+        bool: True if the connection is established, False otherwise.
+    """
     for _ in range(attempts):
         if shutdown_event.is_set():
             logging.info("[ACTOR] Shutting down establish_learner_connection")
@@ -378,7 +387,8 @@ def establish_learner_connection(
 @lru_cache(maxsize=1)
 def learner_service_client(
-    host="127.0.0.1", port=50051
+    host: str = "127.0.0.1",
+    port: int = 50051,
 ) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
     import json
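
The `lru_cache(maxsize=1)` decorator means the channel and stub are created once per process and reused on later calls. A self-contained sketch of that pattern; the channel option shown is an assumption, and `hilserl_pb2_grpc` is the generated gRPC module already used throughout this file:

from functools import lru_cache

import grpc

@lru_cache(maxsize=1)
def make_learner_client(host: str = "127.0.0.1", port: int = 50051):
    # One channel per process; repeated calls return the cached stub/channel pair.
    channel = grpc.insecure_channel(
        f"{host}:{port}",
        options=[("grpc.max_receive_message_length", -1)],  # assumed option
    )
    stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
    return stub, channel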
@@ -426,12 +436,18 @@ def learner_service_client(
 def receive_policy(
     cfg: TrainPipelineConfig,
     parameters_queue: Queue,
-    shutdown_event: any,  # Event,
+    shutdown_event: Event,  # type: ignore
     learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
     grpc_channel: grpc.Channel | None = None,
 ):
+    """Receive parameters from the learner.
+
+    Args:
+        cfg (TrainPipelineConfig): The configuration for the actor.
+        parameters_queue (Queue): The queue to receive the parameters.
+        shutdown_event (Event): The event to check if the process should shut down.
+    """
     logging.info("[ACTOR] Start receiving parameters from the Learner")

     if not use_threads(cfg):
         # Create a process-specific log file
         log_dir = os.path.join(cfg.output_dir, "logs")
@@ -481,7 +497,7 @@ def send_transitions(
 This function continuously retrieves messages from the queue and processes:

-    - **Transition Data:**
+    - Transition Data:
         - A batch of transitions (observation, action, reward, next observation) is collected.
         - Transitions are moved to the CPU and serialized using PyTorch.
         - The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
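
The CPU-move-and-serialize step described above can be done with `torch.save` into an in-memory buffer; a sketch, with the helper name being an assumption:

import io

import torch

def serialize_transitions(transitions: list[dict]) -> bytes:
    # Tensors should already be moved to the CPU before serialization.
    buf = io.BytesIO()
    torch.save(transitions, buf)
    return buf.getvalue()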
@@ -522,7 +538,7 @@ def send_transitions(
 def send_interactions(
     cfg: TrainPipelineConfig,
     interactions_queue: Queue,
-    shutdown_event: any,  # Event,
+    shutdown_event: Event,  # type: ignore
     learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
     grpc_channel: grpc.Channel | None = None,
 ) -> hilserl_pb2.Empty:
@@ -531,7 +547,7 @@ def send_interactions(
 This function continuously retrieves messages from the queue and processes:

-    - **Interaction Messages:**
+    - Interaction Messages:
         - Contains useful statistics about episodic rewards and policy timings.
         - The message is serialized using `pickle` and sent to the learner.
 """
@@ -568,7 +584,7 @@ def send_interactions(
logging.info("[ACTOR] Interactions process stopped") logging.info("[ACTOR] Interactions process stopped")
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty: def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty: # type: ignore
while not shutdown_event.is_set(): while not shutdown_event.is_set():
try: try:
message = transitions_queue.get(block=True, timeout=5) message = transitions_queue.get(block=True, timeout=5)
@@ -584,7 +600,7 @@ def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilse
 def interactions_stream(
-    shutdown_event: any,  # Event,
+    shutdown_event: Event,  # type: ignore
     interactions_queue: Queue,
 ) -> hilserl_pb2.Empty:
     while not shutdown_event.is_set():
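
Both streams poll their queue with a timeout so that `shutdown_event` is re-checked regularly instead of blocking forever. The standalone pattern looks like this (`handle` is a hypothetical consumer):

import queue

def stream_sketch(shutdown_event, message_queue: "queue.Queue") -> None:
    while not shutdown_event.is_set():
        try:
            message = message_queue.get(block=True, timeout=5)
        except queue.Empty:
            continue  # nothing arrived; re-check the shutdown flag
        handle(message)  # hypothetical consumer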
@@ -643,6 +659,14 @@ def push_transitions_to_transport_queue(transitions: list, transitions_queue):
 def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
+    """Get the frequency statistics of the policy.
+
+    Args:
+        list_policy_time (list[float]): The list of policy times.
+
+    Returns:
+        dict[str, float]: The frequency statistics of the policy.
+    """
     stats = {}
     list_policy_fps = [1.0 / t for t in list_policy_time]
     if len(list_policy_fps) > 1:
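
A plausible body matching the docstring added here; the exact statistics computed and the key names are assumptions:

import statistics

def get_frequency_stats_sketch(list_policy_time: list[float]) -> dict[str, float]:
    stats: dict[str, float] = {}
    list_policy_fps = [1.0 / t for t in list_policy_time if t > 0]
    if len(list_policy_fps) > 1:
        stats["policy_fps_mean"] = statistics.fmean(list_policy_fps)
        stats["policy_fps_std"] = statistics.stdev(list_policy_fps)
    return stats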

View File

@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import functools
 import io
 import pickle  # nosec B403: Safe usage of pickle
@@ -194,6 +195,10 @@ class ReplayBuffer:
     optimize_memory: bool = False,
 ):
     """
+    Replay buffer for storing transitions.
+    It will allocate tensors on the specified device when the first transition is added.
+    NOTE: If you encounter memory issues, you can try the `optimize_memory` flag to save memory,
+    or use the `storage_device` flag to keep the buffer on a different device.
     Args:
         capacity (int): Maximum number of transitions to store in the buffer.
         device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
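
Illustrating the device split the docstring describes: keep storage on the CPU and move samples to the GPU. The import path and the `sample` signature are assumptions about this repo's API:

# Import path assumed; adjust to where ReplayBuffer lives in this repo.
from lerobot.scripts.server.buffer import ReplayBuffer

buffer = ReplayBuffer(capacity=100_000, device="cuda:0", storage_device="cpu", optimize_memory=True)
batch = buffer.sample(batch_size=256)  # sampled tensors arrive on cuda:0 (argument name assumed)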
@@ -368,7 +373,7 @@ class ReplayBuffer:
     all_images.append(batch_state[key])
     all_images.append(batch_next_state[key])

-# Batch all images and apply augmentation once
+# Optimization: Batch all images and apply augmentation once
 all_images_tensor = torch.cat(all_images, dim=0)
 augmented_images = self.image_augmentation_function(all_images_tensor)
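
The optimization noted in the comment: concatenate all camera images, run the augmentation once, then split back, instead of calling the augmentation per key. A generic, self-contained sketch:

import torch

def augment_batched(images: list[torch.Tensor], augment_fn) -> list[torch.Tensor]:
    sizes = [img.shape[0] for img in images]
    augmented = augment_fn(torch.cat(images, dim=0))  # single batched pass
    return list(torch.split(augmented, sizes, dim=0))  # restore per-key batches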

View File

@@ -256,7 +256,8 @@ def add_actor_information_and_train(
     interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
     parameters_queue (Queue): Queue for sending policy parameters to the actor.
 """
-# Extract all configuration variables at the beginning
+# Extract all configuration variables at the beginning; this improves speed
+# performance by about 7%
 device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
 storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
 clip_grad_norm_value = cfg.policy.grad_clip_norm
@@ -283,11 +284,11 @@ def add_actor_information_and_train(
 policy: SACPolicy = make_policy(
     cfg=cfg.policy,
-    # ds_meta=cfg.dataset,
     env_cfg=cfg.env,
 )

 assert isinstance(policy, nn.Module)
 policy.train()

 push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
@@ -295,6 +296,8 @@ def add_actor_information_and_train(
 last_time_policy_pushed = time.time()

 optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
+
+# If we are resuming, we need to load the training state
 resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)

 log_training_info(cfg=cfg, policy=policy)
@@ -330,6 +333,7 @@ def add_actor_information_and_train(
 # Initialize iterators
 online_iterator = None
 offline_iterator = None
+
 # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
 while True:
     # Exit the training loop if shutdown is requested
@@ -337,7 +341,7 @@ def add_actor_information_and_train(
logging.info("[LEARNER] Shutdown signal received. Exiting...") logging.info("[LEARNER] Shutdown signal received. Exiting...")
break break
# Process all available transitions # Process all available transitions to the replay buffer, send by the actor server
logging.debug("[LEARNER] Waiting for transitions") logging.debug("[LEARNER] Waiting for transitions")
process_transitions( process_transitions(
transition_queue=transition_queue, transition_queue=transition_queue,
@@ -349,7 +353,7 @@ def add_actor_information_and_train(
 )
 logging.debug("[LEARNER] Received transitions")

-# Process all available interaction messages
+# Process all available interaction messages sent by the actor server
 logging.debug("[LEARNER] Waiting for interactions")
 interaction_message = process_interaction_messages(
     interaction_message_queue=interaction_message_queue,
@@ -359,7 +363,7 @@ def add_actor_information_and_train(
 )
 logging.debug("[LEARNER] Received interactions")

-# Wait until the replay buffer has enough samples
+# Wait until the replay buffer has enough samples to start training
 if len(replay_buffer) < online_step_before_learning:
     continue
@@ -410,7 +414,7 @@ def add_actor_information_and_train(
"complementary_info": batch["complementary_info"], "complementary_info": batch["complementary_info"],
} }
# Use the forward method for critic loss (includes both main critic and discrete critic) # Use the forward method for critic loss
critic_output = policy.forward(forward_batch, model="critic") critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization # Main critic optimization
@@ -433,7 +437,7 @@ def add_actor_information_and_train(
 )
 optimizers["discrete_critic"].step()

-# Update target networks
+# Update target networks (main and discrete)
 policy.update_target_networks()

 # Sample for the last update in the UTD ratio
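
In SAC, `update_target_networks()` is typically a Polyak (soft) update; a generic sketch of that operation, with `tau` as an assumed hyperparameter rather than the value used in this repo:

import torch

@torch.no_grad()
def soft_update(target: torch.nn.Module, source: torch.nn.Module, tau: float = 0.005) -> None:
    # target <- (1 - tau) * target + tau * source, parameter by parameter
    for tp, sp in zip(target.parameters(), source.parameters()):
        tp.data.mul_(1.0 - tau).add_(sp.data, alpha=tau)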
@@ -468,10 +472,8 @@ def add_actor_information_and_train(
"next_observation_feature": next_observation_features, "next_observation_feature": next_observation_features,
} }
# Use the forward method for critic loss (includes both main critic and discrete critic)
critic_output = policy.forward(forward_batch, model="critic") critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
loss_critic = critic_output["loss_critic"] loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad() optimizers["critic"].zero_grad()
loss_critic.backward() loss_critic.backward()
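
The critic update in this hunk, spelled out as one function; gradient clipping with `clip_grad_norm_value` is inferred from the config variable extracted earlier, and the parameter set passed to the clipper is an assumption:

import torch

def critic_update_step(policy, optimizers: dict, forward_batch: dict, clip_grad_norm_value: float) -> float:
    critic_output = policy.forward(forward_batch, model="critic")
    loss_critic = critic_output["loss_critic"]
    optimizers["critic"].zero_grad()
    loss_critic.backward()
    torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad_norm_value)  # parameter set assumed
    optimizers["critic"].step()
    return loss_critic.item()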
@@ -541,7 +543,7 @@ def add_actor_information_and_train(
 push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
 last_time_policy_pushed = time.time()

-# Update target networks
+# Update target networks (main and discrete)
 policy.update_target_networks()

 # Log training metrics at specified intervals
@@ -601,6 +603,8 @@ def start_learner_server(
 ):
     """
     Start the learner server for training.
+    It will receive transitions and interaction messages from the actor server,
+    and send policy parameters to the actor server.

     Args:
         parameters_queue: Queue for sending policy parameters to the actor
@@ -756,7 +760,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
 It also initializes a learning rate scheduler, though currently it is set to `None`.

-**NOTE:**
+NOTE:
 - If the encoder is shared, its parameters are excluded from the actor's optimization process.
 - The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
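
Illustrating the second NOTE: a bare tensor must be wrapped in a list (or registered as a Parameter) to form an optimizer parameter group. The optimizer class and learning rate here are assumptions, not this repo's settings:

import torch

log_alpha = torch.tensor(0.0, requires_grad=True)
temperature_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)  # list-wrap makes the bare tensor a param group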