Add mock gripper support and enhance SAC policy action handling

- Introduced a mock_gripper parameter in ManiskillEnvConfig to enable gripper simulation.
- Added ManiskillMockGripperWrapper to adjust the action space for environments with a discrete gripper action (see the sketch after this list).
- Updated SACPolicy to compute the continuous action dimension correctly, ensuring compatibility with the new gripper setup.
- Refactored action handling in the training loop to accommodate the changed action dimensions.
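The wrapper's implementation is not part of this excerpt; below is a minimal sketch of the idea, assuming a gymnasium-style Box action space and the convention from the diff that the gripper command occupies the last action dimension. The class body, parameter names, and the number of discrete actions are illustrative assumptions, not code from this commit.

```python
import gymnasium as gym
import numpy as np


class ManiskillMockGripperWrapper(gym.Wrapper):
    """Sketch: expose one extra trailing action dimension for a mock
    gripper command, then strip it off before stepping the real env."""

    def __init__(self, env: gym.Env, nb_discrete_actions: int = 3):
        super().__init__(env)
        # Widen the Box action space by one trailing dimension
        # (DISCRETE_DIMENSION_INDEX = -1 in the policy code below).
        low = np.concatenate([env.action_space.low, [0.0]])
        high = np.concatenate([env.action_space.high, [nb_discrete_actions - 1.0]])
        self.action_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

    def step(self, action):
        # The wrapped env never sees the mock gripper dimension.
        return self.env.step(action[:-1])
```

With a wrapper like this, the policy can emit one vector of [continuous arm action, discrete gripper index] while the underlying environment keeps its original interface.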
Author: AdilZouitine
Date: 2025-04-01 14:22:08 +00:00
Committed by: Michel Aractingi
Parent: f83d215e7a
Commit: d86d29fe21
3 changed files with 59 additions and 31 deletions


@@ -33,7 +33,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.utils import get_device_from_parameters
-DISCRETE_DIMENSION_INDEX = -1
+DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
class SACPolicy(
PreTrainedPolicy,
@@ -82,7 +82,7 @@ class SACPolicy(
# Create a list of critic heads
critic_heads = [
CriticHead(
-input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
+input_dim=encoder_critic.output_dim + continuous_action_dim,
**asdict(config.critic_network_kwargs),
)
for _ in range(config.num_critics)
@@ -97,7 +97,7 @@ class SACPolicy(
# Create target critic heads as deepcopies of the original critic heads
target_critic_heads = [
CriticHead(
-input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
+input_dim=encoder_critic.output_dim + continuous_action_dim,
**asdict(config.critic_network_kwargs),
)
for _ in range(config.num_critics)
@@ -117,7 +117,10 @@ class SACPolicy(
self.grasp_critic = None
self.grasp_critic_target = None
+continuous_action_dim = config.output_features["action"].shape[0]
if config.num_discrete_actions is not None:
+continuous_action_dim -= 1
# Create grasp critic
self.grasp_critic = GraspCritic(
encoder=encoder_critic,
@@ -139,15 +142,16 @@ class SACPolicy(
self.grasp_critic = torch.compile(self.grasp_critic)
self.grasp_critic_target = torch.compile(self.grasp_critic_target)
self.actor = Policy(
encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
-action_dim=config.output_features["action"].shape[0],
+action_dim=continuous_action_dim,
encoder_is_shared=config.shared_encoder,
**asdict(config.policy_kwargs),
)
if config.target_entropy is None:
-config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2 # (-dim(A)/2)
+config.target_entropy = -np.prod(continuous_action_dim) / 2 # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temperature parameter is a fixed value
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
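Restated outside the diff, the dimension bookkeeping these hunks introduce is a two-step computation. Below is a short sketch with illustrative numbers; the 7-dim action vector and the 3 gripper actions are assumptions for the example, not values from this commit.

```python
import numpy as np

# Example: a 7-dim action whose last entry is the discrete gripper command.
action_shape = (7,)
num_discrete_actions = 3  # e.g. close / stay / open

continuous_action_dim = action_shape[0]
if num_discrete_actions is not None:
    continuous_action_dim -= 1  # gripper sits at DISCRETE_DIMENSION_INDEX = -1

# Target-entropy heuristic from the diff: -dim(A) / 2.
target_entropy = -np.prod(continuous_action_dim) / 2
assert continuous_action_dim == 6
assert target_entropy == -3.0
```

The actor and the critic heads are then sized with continuous_action_dim, while the grasp critic handles the discrete gripper dimension separately.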
@@ -275,7 +279,9 @@ class SACPolicy(
next_observations=next_observations,
done=done,
)
return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
return {"loss_critic": loss_critic}
if model == "actor":