Clean the code

AdilZouitine
2025-04-24 17:22:54 +02:00
parent b8c2b0bb93
commit a8da4a347e
4 changed files with 56 additions and 39 deletions

View File

@@ -474,14 +474,6 @@ class ManipulatorRobot:
before_fwrite_t = time.perf_counter()
goal_pos = leader_pos[name]
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
# if self.config.joint_position_relative_bounds is not None:
# goal_pos = torch.clamp(
# goal_pos,
# self.config.joint_position_relative_bounds["min"],
# self.config.joint_position_relative_bounds["max"],
# )
# Cap goal position when too far away from present position.
# Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
@@ -603,14 +595,6 @@ class ManipulatorRobot:
goal_pos = action[from_idx:to_idx]
from_idx = to_idx
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
# if self.config.joint_position_relative_bounds is not None:
# goal_pos = torch.clamp(
# goal_pos,
# self.config.joint_position_relative_bounds["min"],
# self.config.joint_position_relative_bounds["max"],
# )
# Cap goal position when too far away from present position.
# Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
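The removed block above was dead code that clipped goal positions to configured joint bounds; what remains is the relative cap driven by `max_relative_target`. A minimal sketch of that capping idea, with illustrative names (`cap_relative_target`, `present_pos`) that are not taken from the repository:

```python
import torch

def cap_relative_target(
    goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float
) -> torch.Tensor:
    # Limit how far each joint's goal may move away from its present position.
    diff = torch.clamp(goal_pos - present_pos, -max_relative_target, max_relative_target)
    return present_pos + diff
```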

View File

@@ -356,10 +356,19 @@ def act_with_policy(
def establish_learner_connection(
stub,
shutdown_event: any, # Event,
attempts=30,
stub: hilserl_pb2_grpc.LearnerServiceStub,
shutdown_event: Event, # type: ignore
attempts: int = 30,
):
"""Establish a connection with the learner.
Args:
stub (hilserl_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
shutdown_event (Event): The event checked to determine whether a shutdown has been requested.
attempts (int): The number of attempts to establish the connection.
Returns:
bool: True if the connection is established, False otherwise.
"""
for _ in range(attempts):
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down establish_learner_connection")
@@ -378,7 +387,8 @@ def establish_learner_connection(
@lru_cache(maxsize=1)
def learner_service_client(
host="127.0.0.1", port=50051
host: str = "127.0.0.1",
port: int = 50051,
) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
import json
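The function returns a cached (stub, channel) pair. A minimal sketch of building such a client, assuming an insecure channel and the generated `hilserl_pb2_grpc` module; any channel options read from JSON in the real code are omitted:

```python
import grpc
import hilserl_pb2_grpc  # module generated from the project's .proto (import path is illustrative)

def make_learner_client_sketch(host: str = "127.0.0.1", port: int = 50051):
    # Build one channel and one stub; callers should reuse the returned pair,
    # which is what the lru_cache(maxsize=1) above achieves.
    channel = grpc.insecure_channel(f"{host}:{port}")
    stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
    return stub, channel
```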
@@ -426,12 +436,18 @@ def learner_service_client(
def receive_policy(
cfg: TrainPipelineConfig,
parameters_queue: Queue,
shutdown_event: any, # Event,
shutdown_event: Event, # type: ignore
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
):
logging.info("[ACTOR] Start receiving parameters from the Learner")
"""Receive parameters from the learner.
Args:
cfg (TrainPipelineConfig): The configuration for the actor.
parameters_queue (Queue): The queue to receive the parameters.
shutdown_event (Event): The event to check if the process should shut down.
"""
logging.info("[ACTOR] Start receiving parameters from the Learner")
if not use_threads(cfg):
# Create a process-specific log file
log_dir = os.path.join(cfg.output_dir, "logs")
@@ -481,7 +497,7 @@ def send_transitions(
This function continuously retrieves messages from the queue and processes:
- **Transition Data:**
- Transition Data:
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
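The serialization step described above (CPU transfer, then `torch.save`) can be sketched like this; the batch structure and the protobuf field that carries the bytes are assumptions, not shown in the diff:

```python
import io
import torch

def serialize_transitions_sketch(batch: list[dict]) -> bytes:
    # Move every tensor to CPU, then serialize the whole batch with torch.save
    # into an in-memory buffer; the resulting bytes would go into the
    # hilserl_pb2.Transition message (exact field name not shown here).
    cpu_batch = [
        {k: (v.cpu() if isinstance(v, torch.Tensor) else v) for k, v in transition.items()}
        for transition in batch
    ]
    buffer = io.BytesIO()
    torch.save(cpu_batch, buffer)
    return buffer.getvalue()
```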
@@ -522,7 +538,7 @@ def send_transitions(
def send_interactions(
cfg: TrainPipelineConfig,
interactions_queue: Queue,
shutdown_event: any, # Event,
shutdown_event: Event, # type: ignore
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> hilserl_pb2.Empty:
@@ -531,7 +547,7 @@ def send_interactions(
This function continuously retrieves messages from the queue and processes:
- **Interaction Messages:**
- Interaction Messages:
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
"""
@@ -568,7 +584,7 @@ def send_interactions(
logging.info("[ACTOR] Interactions process stopped")
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty:
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty: # type: ignore
while not shutdown_event.is_set():
try:
message = transitions_queue.get(block=True, timeout=5)
@@ -584,7 +600,7 @@ def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilse
def interactions_stream(
shutdown_event: any, # Event,
shutdown_event: Event, # type: ignore
interactions_queue: Queue,
) -> hilserl_pb2.Empty:
while not shutdown_event.is_set():
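Both stream helpers follow the same queue-draining pattern: block with a timeout, re-check the shutdown event, and yield whatever arrives. A generic sketch:

```python
from queue import Empty, Queue

def stream_from_queue_sketch(shutdown_event, messages_queue: Queue):
    # Yield messages as they arrive, waking up every 5 s to check whether a
    # shutdown was requested; an Empty timeout just means "nothing to send yet".
    while not shutdown_event.is_set():
        try:
            message = messages_queue.get(block=True, timeout=5)
        except Empty:
            continue
        yield message
```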
@@ -643,6 +659,14 @@ def push_transitions_to_transport_queue(transitions: list, transitions_queue):
def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
"""Get the frequency statistics of the policy.
Args:
list_policy_time (list[float]): The list of per-step policy execution times, in seconds.
Returns:
dict[str, float]: The frequency statistics of the policy.
"""
stats = {}
list_policy_fps = [1.0 / t for t in list_policy_time]
if len(list_policy_fps) > 1:
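A sketch of what such frequency statistics can look like; the exact keys returned by the repository's implementation are not shown in this hunk:

```python
import statistics

def frequency_stats_sketch(list_policy_time: list[float]) -> dict[str, float]:
    # Convert per-step durations (seconds) into frequencies (Hz) and report
    # simple aggregates; key names here are illustrative.
    list_policy_fps = [1.0 / t for t in list_policy_time if t > 0]
    if len(list_policy_fps) <= 1:
        return {}
    return {
        "policy_fps_mean": statistics.mean(list_policy_fps),
        "policy_fps_std": statistics.stdev(list_policy_fps),
        "policy_fps_min": min(list_policy_fps),
    }
```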

View File

@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import io
import pickle # nosec B403: Safe usage of pickle
@@ -194,6 +195,10 @@ class ReplayBuffer:
optimize_memory: bool = False,
):
"""
Replay buffer for storing transitions.
It allocates tensors on the specified device when the first transition is added.
NOTE: If you encounter memory issues, you can use the `optimize_memory` flag to save memory
and/or the `storage_device` flag to store the buffer on a different device.
Args:
capacity (int): Maximum number of transitions to store in the buffer.
device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
@@ -368,7 +373,7 @@ class ReplayBuffer:
all_images.append(batch_state[key])
all_images.append(batch_next_state[key])
# Batch all images and apply augmentation once
# Optimization: Batch all images and apply augmentation once
all_images_tensor = torch.cat(all_images, dim=0)
augmented_images = self.image_augmentation_function(all_images_tensor)
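The optimization noted above concatenates every image into one tensor so the augmentation function runs once instead of per-view. A sketch of that pattern, with `augment_fn` standing in for `self.image_augmentation_function`:

```python
import torch

def augment_batched_sketch(image_groups: list[torch.Tensor], augment_fn) -> list[torch.Tensor]:
    # Concatenate all image tensors into one batch, run the (possibly GPU-bound)
    # augmentation a single time, then split the result back into the original groups.
    sizes = [img.shape[0] for img in image_groups]
    all_images = torch.cat(image_groups, dim=0)
    augmented = augment_fn(all_images)
    return list(torch.split(augmented, sizes, dim=0))
```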

View File

@@ -256,7 +256,8 @@ def add_actor_information_and_train(
interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
parameters_queue (Queue): Queue for sending policy parameters to the actor.
"""
# Extract all configuration variables at the beginning
# Extract all configuration variables at the beginning; this improves speed performance
# by roughly 7%
device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
clip_grad_norm_value = cfg.policy.grad_clip_norm
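The reworded comment above refers to hoisting nested config lookups into local variables before the training loop. A toy illustration of the pattern (only `grad_clip_norm` appears in this diff; the other field is made up):

```python
from types import SimpleNamespace

# Toy config standing in for cfg.
cfg = SimpleNamespace(policy=SimpleNamespace(grad_clip_norm=10.0, utd_ratio=2))

# Hoist nested attribute lookups into locals once, before the hot loop;
# the diff's comment attributes roughly a 7% speedup to this.
clip_grad_norm_value = cfg.policy.grad_clip_norm
utd_ratio = cfg.policy.utd_ratio

for _step in range(1_000):
    # Inside the loop, use the locals instead of cfg.policy.<attr>.
    _ = clip_grad_norm_value * utd_ratio
```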
@@ -283,11 +284,11 @@ def add_actor_information_and_train(
policy: SACPolicy = make_policy(
cfg=cfg.policy,
# ds_meta=cfg.dataset,
env_cfg=cfg.env,
)
assert isinstance(policy, nn.Module)
policy.train()
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
@@ -295,6 +296,8 @@ def add_actor_information_and_train(
last_time_policy_pushed = time.time()
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
# If we are resuming, we need to load the training state
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
log_training_info(cfg=cfg, policy=policy)
@@ -330,6 +333,7 @@ def add_actor_information_and_train(
# Initialize iterators
online_iterator = None
offline_iterator = None
# NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
while True:
# Exit the training loop if shutdown is requested
@@ -337,7 +341,7 @@ def add_actor_information_and_train(
logging.info("[LEARNER] Shutdown signal received. Exiting...")
break
# Process all available transitions
# Process all available transitions sent by the actor server into the replay buffer
logging.debug("[LEARNER] Waiting for transitions")
process_transitions(
transition_queue=transition_queue,
@@ -349,7 +353,7 @@ def add_actor_information_and_train(
)
logging.debug("[LEARNER] Received transitions")
# Process all available interaction messages
# Process all available interaction messages sent by the actor server
logging.debug("[LEARNER] Waiting for interactions")
interaction_message = process_interaction_messages(
interaction_message_queue=interaction_message_queue,
@@ -359,7 +363,7 @@ def add_actor_information_and_train(
)
logging.debug("[LEARNER] Received interactions")
# Wait until the replay buffer has enough samples
# Wait until the replay buffer has enough samples to start training
if len(replay_buffer) < online_step_before_learning:
continue
@@ -410,7 +414,7 @@ def add_actor_information_and_train(
"complementary_info": batch["complementary_info"],
}
# Use the forward method for critic loss (includes both main critic and discrete critic)
# Use the forward method for critic loss
critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
@@ -433,7 +437,7 @@ def add_actor_information_and_train(
)
optimizers["discrete_critic"].step()
# Update target networks
# Update target networks (main and discrete)
policy.update_target_networks()
# Sample for the last update in the UTD ratio
@@ -468,10 +472,8 @@ def add_actor_information_and_train(
"next_observation_feature": next_observation_features,
}
# Use the forward method for critic loss (includes both main critic and discrete critic)
critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
loss_critic.backward()
@@ -541,7 +543,7 @@ def add_actor_information_and_train(
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
last_time_policy_pushed = time.time()
# Update target networks
# Update target networks (main and discrete)
policy.update_target_networks()
# Log training metrics at specified intervals
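`policy.update_target_networks()` is not shown in this diff; for SAC-style learners the target update is typically a Polyak (soft) average, sketched below as an assumption rather than the repository's actual implementation:

```python
import torch
from torch import nn

@torch.no_grad()
def soft_update_sketch(online: nn.Module, target: nn.Module, tau: float = 0.005) -> None:
    # Polyak averaging: target <- tau * online + (1 - tau) * target,
    # applied parameter by parameter.
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(1.0 - tau).add_(p_online, alpha=tau)
```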
@@ -601,6 +603,8 @@ def start_learner_server(
):
"""
Start the learner server for training.
It receives transitions and interaction messages from the actor server
and sends policy parameters back to it.
Args:
parameters_queue: Queue for sending policy parameters to the actor
@@ -756,7 +760,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
It also initializes a learning rate scheduler, though it is currently set to `None`.
**NOTE:**
NOTE:
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
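A sketch of what those two notes imply for optimizer construction; all names except `log_alpha` are illustrative and the learning rate is arbitrary:

```python
import torch
from torch import nn

def make_optimizers_sketch(
    actor: nn.Module,
    critic: nn.Module,
    log_alpha: torch.Tensor,  # expected to be a leaf tensor with requires_grad=True
    shared_encoder: nn.Module | None = None,
    lr: float = 3e-4,
):
    # If the encoder is shared, exclude its parameters from the actor optimizer
    # so it is only updated through the critic.
    encoder_ids = {id(p) for p in shared_encoder.parameters()} if shared_encoder is not None else set()
    actor_params = [p for p in actor.parameters() if id(p) not in encoder_ids]
    optimizers = {
        "actor": torch.optim.Adam(actor_params, lr=lr),
        "critic": torch.optim.Adam(critic.parameters(), lr=lr),
        # A bare tensor is not an nn.Module, so it is wrapped in a list to give
        # Adam the iterable of parameters it expects.
        "temperature": torch.optim.Adam([log_alpha], lr=lr),
    }
    lr_scheduler = None  # currently None, per the docstring above
    return optimizers, lr_scheduler
```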