Clean the code
@@ -474,14 +474,6 @@ class ManipulatorRobot:
 before_fwrite_t = time.perf_counter()
 goal_pos = leader_pos[name]

-# If specified, clip the goal positions within predefined bounds specified in the config of the robot
-# if self.config.joint_position_relative_bounds is not None:
-# goal_pos = torch.clamp(
-# goal_pos,
-# self.config.joint_position_relative_bounds["min"],
-# self.config.joint_position_relative_bounds["max"],
-# )
-
 # Cap goal position when too far away from present position.
 # Slower fps expected due to reading from the follower.
 if self.config.max_relative_target is not None:
@@ -603,14 +595,6 @@ class ManipulatorRobot:
 goal_pos = action[from_idx:to_idx]
 from_idx = to_idx

-# If specified, clip the goal positions within predefined bounds specified in the config of the robot
-# if self.config.joint_position_relative_bounds is not None:
-# goal_pos = torch.clamp(
-# goal_pos,
-# self.config.joint_position_relative_bounds["min"],
-# self.config.joint_position_relative_bounds["max"],
-# )
-
 # Cap goal position when too far away from present position.
 # Slower fps expected due to reading from the follower.
 if self.config.max_relative_target is not None:
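Both hunks above keep the `max_relative_target` safety cap that limits how far a goal position may jump from the present position in a single write. A minimal sketch of how such a cap can be computed, using hypothetical joint values rather than the robot's actual read/write path:

import torch

def cap_relative_target(goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float) -> torch.Tensor:
    """Clamp the goal so it never moves more than `max_relative_target` away from the present position."""
    diff = goal_pos - present_pos
    diff = torch.clamp(diff, -max_relative_target, max_relative_target)
    return present_pos + diff

# Hypothetical joint positions (units depend on the robot: degrees or motor ticks).
present = torch.tensor([0.0, 10.0, 20.0])
goal = torch.tensor([50.0, 12.0, -30.0])
print(cap_relative_target(goal, present, max_relative_target=5.0))  # tensor([ 5., 12., 15.])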
@@ -356,10 +356,19 @@ def act_with_policy(


 def establish_learner_connection(
-stub,
-shutdown_event: any, # Event,
-attempts=30,
+stub: hilserl_pb2_grpc.LearnerServiceStub,
+shutdown_event: Event, # type: ignore
+attempts: int = 30,
 ):
+"""Establish a connection with the learner.
+
+Args:
+stub (hilserl_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
+shutdown_event (Event): The event to check if the connection should be established.
+attempts (int): The number of attempts to establish the connection.
+Returns:
+bool: True if the connection is established, False otherwise.
+"""
 for _ in range(attempts):
 if shutdown_event.is_set():
 logging.info("[ACTOR] Shutting down establish_learner_connection")
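The docstring added above describes a bounded retry loop. The sketch below shows one common way to write such a loop; the `Ready` health-check RPC and the `hilserl_pb2.Empty` message are assumptions, since the service definition is not part of this diff, and the import path is simplified.

import logging
import time

import grpc

import hilserl_pb2  # the repo's generated protobuf module (import path simplified here)

def establish_learner_connection_sketch(stub, shutdown_event, attempts: int = 30) -> bool:
    for _ in range(attempts):
        if shutdown_event.is_set():
            logging.info("[ACTOR] Shutting down establish_learner_connection")
            return False
        try:
            # Hypothetical health-check RPC; the real proto may expose a different call.
            stub.Ready(hilserl_pb2.Empty())
            return True
        except grpc.RpcError:
            logging.info("[ACTOR] Waiting for the learner to become reachable...")
            time.sleep(2)
    return False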
@@ -378,7 +387,8 @@ def establish_learner_connection(

 @lru_cache(maxsize=1)
 def learner_service_client(
-host="127.0.0.1", port=50051
+host: str = "127.0.0.1",
+port: int = 50051,
 ) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
 import json

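`@lru_cache(maxsize=1)` above turns the client factory into a per-process singleton, so the channel and stub are created once and reused across calls. A sketch of what such a factory typically looks like; the channel options shown are illustrative, not values taken from this commit, and the generated-module import path is simplified.

from functools import lru_cache

import grpc

import hilserl_pb2_grpc  # the repo's generated gRPC module (import path simplified here)

@lru_cache(maxsize=1)  # one channel/stub pair per process, reused on every call
def learner_service_client_sketch(host: str = "127.0.0.1", port: int = 50051):
    # Unlimited message sizes are illustrative; serialized policy parameters can exceed gRPC's 4 MB default.
    options = [
        ("grpc.max_send_message_length", -1),
        ("grpc.max_receive_message_length", -1),
    ]
    channel = grpc.insecure_channel(f"{host}:{port}", options=options)
    stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
    return stub, channel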
@@ -426,12 +436,18 @@ def learner_service_client(
 def receive_policy(
 cfg: TrainPipelineConfig,
 parameters_queue: Queue,
-shutdown_event: any, # Event,
+shutdown_event: Event, # type: ignore
 learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
 grpc_channel: grpc.Channel | None = None,
 ):
-logging.info("[ACTOR] Start receiving parameters from the Learner")
+"""Receive parameters from the learner.
+
+Args:
+cfg (TrainPipelineConfig): The configuration for the actor.
+parameters_queue (Queue): The queue to receive the parameters.
+shutdown_event (Event): The event to check if the process should shutdown.
+"""
+logging.info("[ACTOR] Start receiving parameters from the Learner")
 if not use_threads(cfg):
 # Create a process-specific log file
 log_dir = os.path.join(cfg.output_dir, "logs")
@@ -481,7 +497,7 @@ def send_transitions(

 This function continuously retrieves messages from the queue and processes:

-- **Transition Data:**
+- Transition Data:
 - A batch of transitions (observation, action, reward, next observation) is collected.
 - Transitions are moved to the CPU and serialized using PyTorch.
 - The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
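The docstring above says transitions are moved to the CPU and serialized using PyTorch before being wrapped in a `hilserl_pb2.Transition` message. A minimal sketch of that move-and-serialize step (the protobuf wrapping is omitted):

import io

import torch

def serialize_transitions(transitions: list[dict]) -> bytes:
    """Move every tensor in a batch of transition dicts to CPU and serialize the batch with torch.save."""
    cpu_transitions = [
        {k: (v.cpu() if isinstance(v, torch.Tensor) else v) for k, v in t.items()}
        for t in transitions
    ]
    buf = io.BytesIO()
    torch.save(cpu_transitions, buf)
    return buf.getvalue()  # bytes ready to be wrapped in a transport message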
@@ -522,7 +538,7 @@ def send_transitions(
 def send_interactions(
 cfg: TrainPipelineConfig,
 interactions_queue: Queue,
-shutdown_event: any, # Event,
+shutdown_event: Event, # type: ignore
 learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
 grpc_channel: grpc.Channel | None = None,
 ) -> hilserl_pb2.Empty:
@@ -531,7 +547,7 @@ def send_interactions(

 This function continuously retrieves messages from the queue and processes:

-- **Interaction Messages:**
+- Interaction Messages:
 - Contains useful statistics about episodic rewards and policy timings.
 - The message is serialized using `pickle` and sent to the learner.
 """
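Interaction messages are plain statistics dictionaries serialized with `pickle`, as noted above. A tiny sketch with illustrative keys (the real message schema is not shown in this diff):

import pickle

interaction_message = {
    "Episodic reward": 12.5,          # illustrative keys only
    "Interaction step": 1024,
    "Policy frequency [Hz]": 28.7,
}
payload = pickle.dumps(interaction_message)          # bytes sent to the learner
restored = pickle.loads(payload)  # nosec B301: trusted, local traffic only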
@@ -568,7 +584,7 @@ def send_interactions(
 logging.info("[ACTOR] Interactions process stopped")


-def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty:
+def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty: # type: ignore
 while not shutdown_event.is_set():
 try:
 message = transitions_queue.get(block=True, timeout=5)
@@ -584,7 +600,7 @@ def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty:


 def interactions_stream(
-shutdown_event: any, # Event,
+shutdown_event: Event, # type: ignore
 interactions_queue: Queue,
 ) -> hilserl_pb2.Empty:
 while not shutdown_event.is_set():
@@ -643,6 +659,14 @@ def push_transitions_to_transport_queue(transitions: list, transitions_queue):


 def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
+"""Get the frequency statistics of the policy.
+
+Args:
+list_policy_time (list[float]): The list of policy times.
+
+Returns:
+dict[str, float]: The frequency statistics of the policy.
+"""
 stats = {}
 list_policy_fps = [1.0 / t for t in list_policy_time]
 if len(list_policy_fps) > 1:
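The docstring added above describes turning per-step policy times into frequency statistics, where each frequency is simply 1 / t. A hedged sketch of one way to compute them; the exact dictionary keys are assumptions:

import statistics

def get_frequency_stats_sketch(list_policy_time: list[float]) -> dict[str, float]:
    stats = {}
    list_policy_fps = [1.0 / t for t in list_policy_time if t > 0]
    if len(list_policy_fps) > 1:
        stats = {
            "Policy frequency [Hz]": statistics.mean(list_policy_fps),
            "Policy frequency std [Hz]": statistics.stdev(list_policy_fps),
            "Policy frequency min [Hz]": min(list_policy_fps),
        }
    return stats

# Example: 25 ms, 40 ms and 50 ms per inference -> 40, 25 and 20 FPS.
print(get_frequency_stats_sketch([0.025, 0.040, 0.050]))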
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import functools
 import io
 import pickle # nosec B403: Safe usage of pickle
@@ -194,6 +195,10 @@ class ReplayBuffer:
 optimize_memory: bool = False,
 ):
 """
 Replay buffer for storing transitions.
+It will allocate tensors on the specified device when the first transition is added.
+
+NOTE: If you encounter memory issues, you can try to use the `optimize_memory` flag to save memory,
+or use the `storage_device` flag to store the buffer on a different device.
 Args:
 capacity (int): Maximum number of transitions to store in the buffer.
 device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
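The expanded docstring explains that the buffer allocates its tensors lazily, on the first add, and that the sampling `device` can differ from the `storage_device`. A stripped-down sketch of that idea for a single state tensor; the real `ReplayBuffer` also stores actions, rewards, dones and images, and has an `optimize_memory` mode:

import torch

class TinyReplayBuffer:
    def __init__(self, capacity: int, device: str = "cpu", storage_device: str = "cpu"):
        self.capacity = capacity
        self.device = device                  # device returned by sample()
        self.storage_device = storage_device  # device where the buffer lives
        self.states = None                    # allocated lazily on the first add()
        self.size = 0
        self.idx = 0

    def add(self, state: torch.Tensor):
        if self.states is None:
            # Allocate storage only now that the state shape is known.
            self.states = torch.empty((self.capacity, *state.shape), device=self.storage_device)
        self.states[self.idx] = state.to(self.storage_device)
        self.idx = (self.idx + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int) -> torch.Tensor:
        indices = torch.randint(0, self.size, (batch_size,))
        return self.states[indices].to(self.device)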
@@ -368,7 +373,7 @@ class ReplayBuffer:
 all_images.append(batch_state[key])
 all_images.append(batch_next_state[key])

-# Batch all images and apply augmentation once
+# Optimization: Batch all images and apply augmentation once
 all_images_tensor = torch.cat(all_images, dim=0)
 augmented_images = self.image_augmentation_function(all_images_tensor)

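The updated comment points at the optimization of concatenating every camera image from both the current and next observations, running the augmentation once, and splitting the result back. A sketch of that pattern; the write-back details of the actual sampling code may differ:

import torch

def augment_once(batch_state: dict, batch_next_state: dict, image_keys: list[str], augment_fn) -> None:
    """Concatenate all image tensors, augment them in a single call, and write the results back in place."""
    all_images = []
    for key in image_keys:
        all_images.append(batch_state[key])
        all_images.append(batch_next_state[key])

    all_images_tensor = torch.cat(all_images, dim=0)  # one large (num_keys * 2 * B, C, H, W) batch
    augmented = augment_fn(all_images_tensor)         # single augmentation pass

    batch_size = batch_state[image_keys[0]].shape[0]
    chunks = torch.split(augmented, batch_size, dim=0)
    for i, key in enumerate(image_keys):
        batch_state[key] = chunks[2 * i]
        batch_next_state[key] = chunks[2 * i + 1]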
@@ -256,7 +256,8 @@ def add_actor_information_and_train(
 interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
 parameters_queue (Queue): Queue for sending policy parameters to the actor.
 """
-# Extract all configuration variables at the beginning
+# Extract all configuration variables at the beginning; this improves speed
+# performance by about 7%
 device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
 storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
 clip_grad_norm_value = cfg.policy.grad_clip_norm
@@ -283,11 +284,11 @@ def add_actor_information_and_train(

 policy: SACPolicy = make_policy(
 cfg=cfg.policy,
-# ds_meta=cfg.dataset,
 env_cfg=cfg.env,
 )

 assert isinstance(policy, nn.Module)

 policy.train()

 push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
@@ -295,6 +296,8 @@ def add_actor_information_and_train(
 last_time_policy_pushed = time.time()

 optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
+
+# If we are resuming, we need to load the training state
 resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)

 log_training_info(cfg=cfg, policy=policy)
@@ -330,6 +333,7 @@ def add_actor_information_and_train(
 # Initialize iterators
 online_iterator = None
 offline_iterator = None

+# NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
 while True:
 # Exit the training loop if shutdown is requested
@@ -337,7 +341,7 @@ def add_actor_information_and_train(
 logging.info("[LEARNER] Shutdown signal received. Exiting...")
 break

-# Process all available transitions
+# Process all available transitions to the replay buffer, sent by the actor server
 logging.debug("[LEARNER] Waiting for transitions")
 process_transitions(
 transition_queue=transition_queue,
@@ -349,7 +353,7 @@ def add_actor_information_and_train(
 )
 logging.debug("[LEARNER] Received transitions")

-# Process all available interaction messages
+# Process all available interaction messages sent by the actor server
 logging.debug("[LEARNER] Waiting for interactions")
 interaction_message = process_interaction_messages(
 interaction_message_queue=interaction_message_queue,
@@ -359,7 +363,7 @@ def add_actor_information_and_train(
 )
 logging.debug("[LEARNER] Received interactions")

-# Wait until the replay buffer has enough samples
+# Wait until the replay buffer has enough samples to start training
 if len(replay_buffer) < online_step_before_learning:
 continue

@@ -410,7 +414,7 @@ def add_actor_information_and_train(
 "complementary_info": batch["complementary_info"],
 }

-# Use the forward method for critic loss (includes both main critic and discrete critic)
+# Use the forward method for critic loss
 critic_output = policy.forward(forward_batch, model="critic")

 # Main critic optimization
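Around the `critic_output = policy.forward(forward_batch, model="critic")` call above, the learner runs a standard optimizer step on `loss_critic` with the gradient-norm clipping configured by `cfg.policy.grad_clip_norm`. A generic sketch of that step; the actual learner may handle the returned norm differently:

import torch
from torch import nn

def critic_update_step(loss_critic: torch.Tensor, critic_params, optimizer: torch.optim.Optimizer, clip_grad_norm_value: float) -> float:
    """One optimization step: zero grads, backprop, clip the gradient norm, then step."""
    optimizer.zero_grad()
    loss_critic.backward()
    grad_norm = nn.utils.clip_grad_norm_(critic_params, clip_grad_norm_value)
    optimizer.step()
    return float(grad_norm)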
@@ -433,7 +437,7 @@ def add_actor_information_and_train(
 )
 optimizers["discrete_critic"].step()

-# Update target networks
+# Update target networks (main and discrete)
 policy.update_target_networks()

 # Sample for the last update in the UTD ratio
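`policy.update_target_networks()` refreshes the target critics (main and discrete). In SAC this is usually a Polyak (exponential moving average) update; a hedged sketch of that rule, with `tau` as an illustrative value:

import torch
from torch import nn

@torch.no_grad()
def soft_update(target: nn.Module, source: nn.Module, tau: float = 0.005) -> None:
    """Polyak averaging: target <- tau * source + (1 - tau) * target."""
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.mul_(1.0 - tau).add_(param.data, alpha=tau)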
@@ -468,10 +472,8 @@ def add_actor_information_and_train(
 "next_observation_feature": next_observation_features,
 }

-# Use the forward method for critic loss (includes both main critic and discrete critic)
 critic_output = policy.forward(forward_batch, model="critic")

-# Main critic optimization
 loss_critic = critic_output["loss_critic"]
 optimizers["critic"].zero_grad()
 loss_critic.backward()
@@ -541,7 +543,7 @@ def add_actor_information_and_train(
 push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
 last_time_policy_pushed = time.time()

-# Update target networks
+# Update target networks (main and discrete)
 policy.update_target_networks()

 # Log training metrics at specified intervals
@@ -601,6 +603,8 @@ def start_learner_server(
 ):
 """
 Start the learner server for training.
+It will receive transitions and interaction messages from the actor server,
+and send policy parameters to the actor server.

 Args:
 parameters_queue: Queue for sending policy parameters to the actor
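The docstring above describes a learner server that receives transitions and interaction messages and sends policy parameters back. A minimal sketch of bringing up such a gRPC server; the servicer registration function follows the usual grpcio codegen naming and, like the simplified import, is an assumption here:

from concurrent import futures

import grpc

import hilserl_pb2_grpc  # the repo's generated gRPC module (import path simplified here)

def start_learner_server_sketch(servicer, port: int = 50051, max_workers: int = 4) -> grpc.Server:
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=max_workers),
        options=[
            ("grpc.max_send_message_length", -1),
            ("grpc.max_receive_message_length", -1),
        ],
    )
    # Assumed codegen name: add_<ServiceName>Servicer_to_server for the LearnerService proto.
    hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(servicer, server)
    server.add_insecure_port(f"[::]:{port}")
    server.start()
    return server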
@@ -756,7 +760,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):

 It also initializes a learning rate scheduler, though currently, it is set to `None`.

-**NOTE:**
+NOTE:
 - If the encoder is shared, its parameters are excluded from the actor's optimization process.
 - The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.

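The NOTE above says shared-encoder parameters are excluded from the actor optimizer and that `log_alpha` is wrapped in a list. A hedged sketch of how such optimizers could be assembled; attribute names like `policy.actor`, `policy.critic_ensemble`, `policy.log_alpha`, `policy.shared_encoder` and the learning-rate fields are assumptions:

import torch

def make_optimizers_and_scheduler_sketch(cfg, policy):
    # Exclude the shared encoder from the actor optimizer so only the critic trains it.
    actor_params = [
        p for n, p in policy.actor.named_parameters()
        if not (policy.shared_encoder and n.startswith("encoder"))
    ]
    optimizers = {
        "actor": torch.optim.Adam(actor_params, lr=cfg.policy.actor_lr),
        "critic": torch.optim.Adam(policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr),
        # log_alpha is a single tensor, so it is wrapped in a list for the optimizer.
        "temperature": torch.optim.Adam([policy.log_alpha], lr=cfg.policy.temperature_lr),
    }
    lr_scheduler = None  # currently unused, as the docstring notes
    return optimizers, lr_scheduler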