Refactor SACPolicy and configuration to replace 'grasp_critic' terminology with 'discrete_critic'. Update related methods and comments for clarity and consistency in handling discrete actions.

commit a7a51cfc9c (parent 0d70f0b85c)
AdilZouitine, 2025-04-18 14:38:22 +00:00
4 changed files with 70 additions and 68 deletions

View File

@@ -179,7 +179,7 @@ class SACConfig(PreTrainedConfig):
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
grasp_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
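For reference, these sub-config fields use dataclass default_factory so each SACConfig instance gets its own nested config, and the renamed kwargs are later unpacked into the critic constructor with asdict(...) (see the _init_discrete_critics hunk below). A minimal, self-contained sketch of that pattern; the field names inside the stand-in CriticNetworkKwargs are illustrative and not taken from the repository:

from dataclasses import dataclass, field, asdict

@dataclass
class CriticNetworkKwargs:  # illustrative stand-in for CriticNetworkConfig
    hidden_dims: tuple[int, ...] = (256, 256)
    activate_final: bool = True

@dataclass
class MinimalSACConfig:  # illustrative stand-in for SACConfig
    # default_factory gives every config instance its own nested sub-config.
    discrete_critic_network_kwargs: CriticNetworkKwargs = field(default_factory=CriticNetworkKwargs)

cfg = MinimalSACConfig()
# Later in the diff the kwargs are unpacked into the critic constructor via **asdict(...).
print(asdict(cfg.discrete_critic_network_kwargs))  # {'hidden_dims': (256, 256), 'activate_final': True}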

View File

@@ -71,7 +71,7 @@ class SACPolicy(
"temperature": self.log_alpha,
}
if self.config.num_discrete_actions is not None:
optim_params["grasp_critic"] = self.grasp_critic.parameters()
optim_params["discrete_critic"] = self.discrete_critic.parameters()
return optim_params
def reset(self):
@@ -90,7 +90,7 @@ class SACPolicy(
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.num_discrete_actions is not None:
discrete_action_value = self.grasp_critic(batch, observations_features)
discrete_action_value = self.discrete_critic(batch, observations_features)
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
actions = torch.cat([actions, discrete_action], dim=-1)
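The hunk above appends the greedy discrete action, the argmax over the discrete critic's Q-values, to the continuous action vector. A minimal sketch with made-up shapes; the explicit cast to float is an assumption added here to keep torch.cat dtypes consistent:

import torch

continuous_actions = torch.randn(4, 6)   # e.g. a batch of 6-DoF arm commands
discrete_q_values = torch.randn(4, 3)    # Q-values for e.g. 3 gripper modes

# Greedy discrete action from the discrete critic, kept as a (batch, 1) column.
discrete_action = torch.argmax(discrete_q_values, dim=-1, keepdim=True)
# Cast to float so both tensors share a dtype before concatenation.
actions = torch.cat([continuous_actions, discrete_action.float()], dim=-1)  # shape (4, 7)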
@@ -118,8 +118,10 @@ class SACPolicy(
q_values = critics(observations, actions, observation_features)
return q_values
def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor:
"""Forward pass through a grasp critic network
def discrete_critic_forward(
self, observations, use_target=False, observation_features=None
) -> torch.Tensor:
"""Forward pass through a discrete critic network
Args:
observations: Dictionary of observations
@@ -127,16 +129,16 @@ class SACPolicy(
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
Returns:
Tensor of Q-values from the grasp critic network
Tensor of Q-values from the discrete critic network
"""
grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic
q_values = grasp_critic(observations, observation_features)
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
q_values = discrete_critic(observations, observation_features)
return q_values
def forward(
self,
batch: dict[str, Tensor | dict[str, Tensor]],
model: Literal["actor", "critic", "temperature", "grasp_critic"] = "critic",
model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic",
) -> dict[str, Tensor]:
"""Compute the loss for the given model
@@ -149,7 +151,7 @@ class SACPolicy(
- done: Done mask tensor
- observation_feature: Optional pre-computed observation features
- next_observation_feature: Optional pre-computed next observation features
model: Which model to compute the loss for ("actor", "critic", "grasp_critic", or "temperature")
model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature")
Returns:
The computed loss tensor
@@ -178,14 +180,14 @@ class SACPolicy(
return {"loss_critic": loss_critic}
if model == "grasp_critic" and self.config.num_discrete_actions is not None:
if model == "discrete_critic" and self.config.num_discrete_actions is not None:
# Extract critic-specific components
rewards: Tensor = batch["reward"]
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
complementary_info = batch.get("complementary_info")
loss_grasp_critic = self.compute_loss_grasp_critic(
loss_discrete_critic = self.compute_loss_discrete_critic(
observations=observations,
actions=actions,
rewards=rewards,
@@ -195,7 +197,7 @@ class SACPolicy(
next_observation_features=next_observation_features,
complementary_info=complementary_info,
)
return {"loss_grasp_critic": loss_grasp_critic}
return {"loss_discrete_critic": loss_discrete_critic}
if model == "actor":
return {
"loss_actor": self.compute_loss_actor(
@@ -227,8 +229,8 @@ class SACPolicy(
)
if self.config.num_discrete_actions is not None:
for target_param, param in zip(
self.grasp_critic_target.parameters(),
self.grasp_critic.parameters(),
self.discrete_critic_target.parameters(),
self.discrete_critic.parameters(),
strict=False,
):
target_param.data.copy_(
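The target-network update above is a soft (Polyak) update: target <- tau * online + (1 - tau) * target. A standalone sketch, assuming a small coefficient such as 0.005; the config's actual name and value for this coefficient are not shown in the diff:

import copy
import torch
import torch.nn as nn

def soft_update(target: nn.Module, online: nn.Module, tau: float) -> None:
    # Polyak averaging: target <- tau * online + (1 - tau) * target
    with torch.no_grad():
        for t_param, o_param in zip(target.parameters(), online.parameters()):
            t_param.data.copy_(tau * o_param.data + (1.0 - tau) * t_param.data)

online_critic = nn.Linear(8, 3)                # stand-in for the discrete critic
target_critic = copy.deepcopy(online_critic)   # stand-in for its target network
soft_update(target_critic, online_critic, tau=0.005)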
@@ -302,7 +304,7 @@ class SACPolicy(
).sum()
return critics_loss
def compute_loss_grasp_critic(
def compute_loss_discrete_critic(
self,
observations,
actions,
@@ -320,46 +322,46 @@ class SACPolicy(
actions_discrete = torch.round(actions_discrete)
actions_discrete = actions_discrete.long()
gripper_penalties: Tensor | None = None
discrete_penalties: Tensor | None = None
if complementary_info is not None:
gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty")
discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty")
with torch.no_grad():
# For DQN, select actions using online network, evaluate with target network
next_grasp_qs = self.grasp_critic_forward(
next_discrete_qs = self.discrete_critic_forward(
next_observations, use_target=False, observation_features=next_observation_features
)
best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True)
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
# Get target Q-values from target network
target_next_grasp_qs = self.grasp_critic_forward(
target_next_discrete_qs = self.discrete_critic_forward(
observations=next_observations,
use_target=True,
observation_features=next_observation_features,
)
# Use gather to select Q-values for best actions
target_next_grasp_q = torch.gather(
target_next_grasp_qs, dim=1, index=best_next_grasp_action
target_next_discrete_q = torch.gather(
target_next_discrete_qs, dim=1, index=best_next_discrete_action
).squeeze(-1)
# Compute target Q-value with Bellman equation
rewards_gripper = rewards
if gripper_penalties is not None:
rewards_gripper = rewards + gripper_penalties
target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
rewards_discrete = rewards
if discrete_penalties is not None:
rewards_discrete = rewards + discrete_penalties
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
# Get predicted Q-values for current observations
predicted_grasp_qs = self.grasp_critic_forward(
predicted_discrete_qs = self.discrete_critic_forward(
observations=observations, use_target=False, observation_features=observation_features
)
# Use gather to select Q-values for taken actions
predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions_discrete).squeeze(-1)
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
# Compute MSE loss between predicted and target Q-values
grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q)
return grasp_critic_loss
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
return discrete_critic_loss
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
"""Compute the temperature loss"""
@@ -419,7 +421,7 @@ class SACPolicy(
)
def _init_critics(self, continuous_action_dim):
"""Build critic ensemble, targets, and optional grasp critic."""
"""Build critic ensemble, targets, and optional discrete critic."""
heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
@@ -446,25 +448,25 @@ class SACPolicy(
self.critic_target = torch.compile(self.critic_target)
if self.config.num_discrete_actions is not None:
self._init_grasp_critics()
self._init_discrete_critics()
def _init_grasp_critics(self):
"""Build discrete grasp critic ensemble and target networks."""
self.grasp_critic = GraspCritic(
def _init_discrete_critics(self):
"""Build discrete discrete critic ensemble and target networks."""
self.discrete_critic = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.grasp_critic_network_kwargs),
**asdict(self.config.discrete_critic_network_kwargs),
)
self.grasp_critic_target = GraspCritic(
self.discrete_critic_target = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.grasp_critic_network_kwargs),
**asdict(self.config.discrete_critic_network_kwargs),
)
# TODO: (maractingi, azouitine) Compile the grasp critic
self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
# TODO: (maractingi, azouitine) Compile the discrete critic
self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict())
def _init_actor(self, continuous_action_dim):
"""Initialize policy actor network and default target entropy."""
@@ -590,7 +592,7 @@ class SACObservationEncoder(nn.Module):
This function processes image observations through the vision encoder once and returns
the resulting features.
When the image encoder is shared between actor and critics AND frozen, these features can be safely cached and
reused across policy components (actor, critic, grasp_critic), avoiding redundant forward passes.
reused across policy components (actor, critic, discrete_critic), avoiding redundant forward passes.
Performance impact:
- The vision encoder forward pass is typically the main computational bottleneck during training and inference
@@ -794,7 +796,7 @@ class CriticEnsemble(nn.Module):
return q_values
class GraspCritic(nn.Module):
class DiscreteCritic(nn.Module):
def __init__(
self,
encoder: nn.Module,
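The renamed class keeps the same constructor shape (an encoder plus input_dim, output_dim, and the network kwargs). A hypothetical stand-in, not the repository's DiscreteCritic, showing the typical structure of such a head: an MLP over encoder features that outputs one Q-value per discrete action and can reuse pre-computed features:

import torch
import torch.nn as nn

class MinimalDiscreteCritic(nn.Module):
    # Hypothetical stand-in, not the repository's DiscreteCritic.
    def __init__(self, encoder: nn.Module, input_dim: int, output_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.encoder = encoder
        self.head = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, observations, observation_features: torch.Tensor | None = None) -> torch.Tensor:
        # Reuse cached encoder features when provided to avoid a second encoder pass.
        features = observation_features if observation_features is not None else self.encoder(observations)
        return self.head(features)  # (batch, num_discrete_actions) Q-values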

View File

@@ -833,7 +833,7 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
if self.gripper_penalty_in_reward:
reward += gripper_penalty
else:
info["gripper_penalty"] = gripper_penalty
info["discrete_penalty"] = gripper_penalty
return obs, reward, terminated, truncated, info
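The wrapper change above only renames the info key that the learner reads; the penalty itself is unchanged. A minimal gymnasium wrapper sketch of the same idea; the fixed per-step penalty and the class name are illustrative and not the repository's GripperPenaltyWrapper:

import gymnasium as gym

class PenaltyInfoWrapper(gym.Wrapper):
    # Illustrative only; the real GripperPenaltyWrapper derives the penalty from gripper state.
    def __init__(self, env: gym.Env, penalty: float = -0.01, penalty_in_reward: bool = False):
        super().__init__(env)
        self.penalty = penalty
        self.penalty_in_reward = penalty_in_reward

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        if self.penalty_in_reward:
            reward += self.penalty                   # fold the penalty into the scalar reward
        else:
            info["discrete_penalty"] = self.penalty  # expose it for the discrete critic loss
        return obs, reward, terminated, truncated, info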

View File

@@ -410,7 +410,7 @@ def add_actor_information_and_train(
"complementary_info": batch["complementary_info"],
}
# Use the forward method for critic loss (includes both main critic and grasp critic)
# Use the forward method for critic loss (includes both main critic and discrete critic)
critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
@@ -422,16 +422,16 @@ def add_actor_information_and_train(
)
optimizers["critic"].step()
# Grasp critic optimization (if available)
# Discrete critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
optimizers["discrete_critic"].zero_grad()
loss_discrete_critic.backward()
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
)
optimizers["grasp_critic"].step()
optimizers["discrete_critic"].step()
# Update target networks
policy.update_target_networks()
@@ -468,7 +468,7 @@ def add_actor_information_and_train(
"next_observation_feature": next_observation_features,
}
# Use the forward method for critic loss (includes both main critic and grasp critic)
# Use the forward method for critic loss (includes both main critic and discrete critic)
critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
@@ -486,20 +486,20 @@ def add_actor_information_and_train(
"critic_grad_norm": critic_grad_norm,
}
# Grasp critic optimization (if available)
# Discrete critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
optimizers["discrete_critic"].zero_grad()
loss_discrete_critic.backward()
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["grasp_critic"].step()
optimizers["discrete_critic"].step()
# Add grasp critic info to training info
training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
# Add discrete critic info to training info
training_infos["loss_discrete_critic"] = loss_discrete_critic.item()
training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm
# Actor and temperature optimization (at specified frequency)
if optimization_step % policy_update_freq == 0:
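The discrete-critic optimization block is repeated in both training paths above; consolidated into one helper it looks roughly like the sketch below, where policy, optimizers, forward_batch, clip_grad_norm_value, and training_infos are assumed to exist as in the surrounding code:

import torch

def discrete_critic_step(policy, optimizers, forward_batch, clip_grad_norm_value, training_infos):
    output = policy.forward(forward_batch, model="discrete_critic")
    loss = output["loss_discrete_critic"]
    optimizers["discrete_critic"].zero_grad()
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(
        parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
    ).item()
    optimizers["discrete_critic"].step()
    training_infos["loss_discrete_critic"] = loss.item()
    training_infos["discrete_critic_grad_norm"] = grad_norm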
@@ -782,8 +782,8 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
if cfg.policy.num_discrete_actions is not None:
optimizer_grasp_critic = torch.optim.Adam(
params=policy.grasp_critic.parameters(), lr=cfg.policy.critic_lr
optimizer_discrete_critic = torch.optim.Adam(
params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None
@@ -793,7 +793,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
"temperature": optimizer_temperature,
}
if cfg.policy.num_discrete_actions is not None:
optimizers["grasp_critic"] = optimizer_grasp_critic
optimizers["discrete_critic"] = optimizer_discrete_critic
return optimizers, lr_scheduler