Refactor modeling_sac and parameter handling for clarity and reusability.

Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
AdilZouitine
2025-04-14 14:00:57 +00:00
committed by Michel Aractingi
parent 267a837a2c
commit 9386892f8e
2 changed files with 67 additions and 45 deletions

@@ -510,7 +510,7 @@ def add_actor_information_and_train(
        optimizers["actor"].zero_grad()
        loss_actor.backward()
        actor_grad_norm = torch.nn.utils.clip_grad_norm_(
-            parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
+            parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
        ).item()
        optimizers["actor"].step()
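
For context on the clipping call in the hunk above: torch.nn.utils.clip_grad_norm_ rescales the parameters' .grad tensors in place so their combined norm is at most max_norm, and returns the total norm measured before clipping, which is why the script calls .item() on the result to log actor_grad_norm. A minimal, self-contained sketch with a toy module and made-up values (not code from this repository):

# Minimal sketch (toy module, hypothetical values) of the clip-then-log pattern above.
import torch
import torch.nn as nn

actor = nn.Linear(4, 2)  # stand-in for policy.actor
loss = actor(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# clip_grad_norm_ rescales the .grad tensors in place so their total norm is at
# most max_norm, and returns the norm measured *before* clipping.
pre_clip_norm = torch.nn.utils.clip_grad_norm_(actor.parameters(), max_norm=1.0).item()
post_clip_norm = torch.stack([p.grad.norm() for p in actor.parameters()]).norm().item()
print(pre_clip_norm, post_clip_norm)  # post_clip_norm <= 1.0 (up to float error)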
@@ -773,12 +773,14 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
    """
    optimizer_actor = torch.optim.Adam(
        # NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
-        params=policy.actor.parameters_to_optimize,
+        params=[
+            p
+            for n, p in policy.actor.named_parameters()
+            if not n.startswith("encoder") or not policy.config.shared_encoder
+        ],
        lr=cfg.policy.actor_lr,
    )
-    optimizer_critic = torch.optim.Adam(
-        params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
-    )
+    optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
    if cfg.policy.num_discrete_actions is not None:
        optimizer_grasp_critic = torch.optim.Adam(
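
The NOTE in the hunk above is the motivation for the new list comprehension: when the encoder is shared between actor and critic, the actor optimizer skips every parameter whose name starts with "encoder", so the actor loss never updates the shared encoder weights (they are presumably left to the critic optimizer). A minimal sketch of the same predicate with a toy actor; the module names and the shared_encoder flag here are stand-ins, not the repository's classes:

import torch
import torch.nn as nn

# Toy actor: a submodule literally named "encoder" plus a policy head.
actor = nn.Module()
actor.encoder = nn.Linear(4, 8)
actor.head = nn.Linear(8, 2)

shared_encoder = True  # stand-in for policy.config.shared_encoder

# Same predicate as the diff: drop a parameter only if it belongs to the
# encoder AND the encoder is shared with the critic.
actor_params = [
    p
    for n, p in actor.named_parameters()
    if not n.startswith("encoder") or not shared_encoder
]
optimizer_actor = torch.optim.Adam(params=actor_params, lr=3e-4)

kept = [n for n, _ in actor.named_parameters() if not n.startswith("encoder") or not shared_encoder]
print(kept)  # ['head.weight', 'head.bias'] when shared_encoder is True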
@@ -1089,6 +1091,44 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
    parameters_queue.put(state_bytes)

+def check_weight_gradients(module: nn.Module) -> dict[str, bool]:
+    """
+    Checks whether each parameter in the module has a gradient.
+
+    Args:
+        module (nn.Module): A PyTorch module whose parameters will be inspected.
+
+    Returns:
+        dict[str, bool]: A dictionary where each key is the parameter name and the value is
+            True if the parameter has an associated gradient (i.e. .grad is not None),
+            otherwise False.
+    """
+    grad_status = {}
+    for name, param in module.named_parameters():
+        grad_status[name] = param.grad is not None
+    return grad_status
+
+
+def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]:
+    """
+    Returns a dictionary of parameters from the given model that also exist in the grad_status dictionary.
+
+    Args:
+        model (nn.Module): The model whose parameter names are intersected with grad_status.
+        grad_status (dict[str, bool]): A dictionary where keys are parameter names and values indicate
+            whether each parameter has a gradient.
+
+    Returns:
+        dict[str, bool]: A dictionary containing only the overlapping parameter names and their gradient status.
+    """
+    # Get the model's parameter names as a set.
+    model_param_names = {name for name, _ in model.named_parameters()}
+    # Intersect parameter names between the model and grad_status.
+    overlapping = {name: grad_status[name] for name in grad_status if name in model_param_names}
+    return overlapping
def process_interaction_message(
message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
):
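
A hedged usage sketch for the two debugging helpers added in this hunk: called right after loss_actor.backward(), they report which parameters received gradients, restricted to names that also appear in another module. The policy.actor / policy.critic_ensemble attributes and the idea of checking a shared encoder are assumptions based on the surrounding script, and whether any names overlap depends on how the modules are nested:

# Hypothetical debugging snippet; assumes `loss_actor`, `policy.actor`, and
# `policy.critic_ensemble` exist as in the training script above.
loss_actor.backward()

# Map every actor parameter name to True/False depending on whether .grad is set.
actor_grad_status = check_weight_gradients(policy.actor)

# Keep only the entries whose names also exist in the critic ensemble
# (e.g. a shared encoder), to see whether the actor loss touched them.
shared = get_overlapping_parameters(policy.critic_ensemble, actor_grad_status)
for name, has_grad in shared.items():
    print(f"{name}: {'has grad' if has_grad else 'no grad'}")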