Add mock gripper support and enhance SAC policy action handling
- Introduced a mock_gripper parameter in ManiskillEnvConfig to enable gripper simulation.
- Added ManiskillMockGripperWrapper to adjust the action space for environments with discrete actions.
- Updated SACPolicy to compute the continuous action dimension correctly, ensuring compatibility with the new gripper setup.
- Refactored action handling in the training loop to accommodate the change in action dimensions.
commit d86d29fe21 (parent f83d215e7a), committed by Michel Aractingi
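The ManiskillMockGripperWrapper itself is not part of the diff below. As a rough illustration of the idea only, here is a minimal sketch of a wrapper that widens a continuous action space by one trailing slot for a discrete gripper command and strips it before stepping the wrapped env; the class name, attribute layout, and default number of discrete actions are assumptions, not the actual implementation.

```python
# Minimal sketch (NOT the actual ManiskillMockGripperWrapper): append a
# trailing discrete-gripper slot to a continuous Box action space, then
# drop that slot before forwarding the action to the underlying env.
import gymnasium as gym
import numpy as np


class MockGripperWrapper(gym.Wrapper):  # hypothetical stand-in
    def __init__(self, env: gym.Env, num_discrete_actions: int = 3):
        super().__init__(env)
        low = np.concatenate([env.action_space.low, [0.0]])
        high = np.concatenate([env.action_space.high, [num_discrete_actions - 1]])
        # The last dimension carries the discrete gripper command.
        self.action_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

    def step(self, action):
        continuous, _gripper = action[:-1], action[-1]
        # The mock gripper command is discarded; only the arm command runs.
        return self.env.step(continuous)
```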
@@ -33,7 +33,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.sac.configuration_sac import SACConfig
 from lerobot.common.policies.utils import get_device_from_parameters
 
-DISCRETE_DIMENSION_INDEX = -1
+DISCRETE_DIMENSION_INDEX = -1  # Gripper is always the last dimension
 
 class SACPolicy(
     PreTrainedPolicy,
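The new comment pins down the convention the rest of the patch relies on: the gripper occupies the last slot of the action vector. A hedged illustration of how such a constant is typically used to split a batched action tensor (the shapes here are invented for the example):

```python
import torch

DISCRETE_DIMENSION_INDEX = -1  # Gripper is always the last dimension

# Example: a batch of 32 seven-dim actions, 6 continuous + 1 gripper slot.
actions = torch.randn(32, 7)
continuous_actions = actions[:, :DISCRETE_DIMENSION_INDEX]  # shape (32, 6)
gripper_command = actions[:, DISCRETE_DIMENSION_INDEX]      # shape (32,)
```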
@@ -82,7 +82,7 @@ class SACPolicy(
         # Create a list of critic heads
         critic_heads = [
             CriticHead(
-                input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
+                input_dim=encoder_critic.output_dim + continuous_action_dim,
                 **asdict(config.critic_network_kwargs),
             )
             for _ in range(config.num_critics)
@@ -97,7 +97,7 @@ class SACPolicy(
         # Create target critic heads as deepcopies of the original critic heads
         target_critic_heads = [
             CriticHead(
-                input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
+                input_dim=encoder_critic.output_dim + continuous_action_dim,
                 **asdict(config.critic_network_kwargs),
             )
             for _ in range(config.num_critics)
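Both hunks above make the same substitution: a Q-head consumes the encoder features concatenated with only the continuous part of the action, so its input width must use continuous_action_dim rather than the full action shape. A minimal sketch of that concatenation with invented dimensions (the real CriticHead is configured via critic_network_kwargs):

```python
import torch
import torch.nn as nn

# Invented sizes for illustration: 64-dim encoder output, 6 continuous dims.
encoder_output_dim, continuous_action_dim = 64, 6

critic_head = nn.Sequential(
    nn.Linear(encoder_output_dim + continuous_action_dim, 256),
    nn.ReLU(),
    nn.Linear(256, 1),  # one Q-value per (obs, action) pair
)

obs_features = torch.randn(32, encoder_output_dim)
continuous_actions = torch.randn(32, continuous_action_dim)
q_values = critic_head(torch.cat([obs_features, continuous_actions], dim=-1))  # (32, 1)
```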
@@ -117,7 +117,10 @@ class SACPolicy(
         self.grasp_critic = None
         self.grasp_critic_target = None
 
+        continuous_action_dim = config.output_features["action"].shape[0]
         if config.num_discrete_actions is not None:
+
+            continuous_action_dim -= 1
             # Create grasp critic
             self.grasp_critic = GraspCritic(
                 encoder=encoder_critic,
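The bookkeeping here: when num_discrete_actions is set, the last action slot is the discrete gripper command, so one dimension is subtracted before sizing the continuous networks, while the grasp critic handles the discrete slot. A worked example with assumed numbers:

```python
# Assumed example: a 7-dim action feature with a 3-way discrete gripper.
action_shape = (7,)       # stands in for config.output_features["action"].shape
num_discrete_actions = 3  # stands in for config.num_discrete_actions

continuous_action_dim = action_shape[0]
if num_discrete_actions is not None:
    continuous_action_dim -= 1  # drop the trailing gripper slot

assert continuous_action_dim == 6
# Actor and critic heads are sized with 6 dims; the grasp critic predicts
# one Q-value per discrete gripper action (3 outputs in this example).
```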
@@ -139,15 +142,16 @@ class SACPolicy(
             self.grasp_critic = torch.compile(self.grasp_critic)
             self.grasp_critic_target = torch.compile(self.grasp_critic_target)
 
+
         self.actor = Policy(
             encoder=encoder_actor,
             network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
-            action_dim=config.output_features["action"].shape[0],
+            action_dim=continuous_action_dim,
             encoder_is_shared=config.shared_encoder,
             **asdict(config.policy_kwargs),
         )
         if config.target_entropy is None:
-            config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2  # (-dim(A)/2)
+            config.target_entropy = -np.prod(continuous_action_dim) / 2  # (-dim(A)/2)
 
         # TODO (azouitine): Handle the case where the temparameter is a fixed
         # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
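The target-entropy default follows the common SAC heuristic, halved: -dim(A)/2, now computed over the continuous dimensions only. np.prod on a scalar dimension simply returns it, so with the assumed 6 continuous dims:

```python
import numpy as np

continuous_action_dim = 6  # assumed, as in the example above
target_entropy = -np.prod(continuous_action_dim) / 2  # (-dim(A)/2) -> -3.0
assert target_entropy == -3.0
```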
@@ -275,7 +279,9 @@ class SACPolicy(
                 next_observations=next_observations,
                 done=done,
             )
+
             return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
 
         return {"loss_critic": loss_critic}
+
         if model == "actor":
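On the consumer side, the commit message notes the training loop was refactored to match; presumably it picks up loss_grasp_critic when the policy reports it. A hypothetical sketch of that dispatch (the forward call signature and optimizer setup are assumptions, not the repository's actual training loop):

```python
# Hypothetical training-loop fragment: include the grasp-critic loss in the
# update whenever the policy returns it alongside the main critic loss.
losses = policy.forward(batch, model="critic")  # assumed call signature

total_loss = losses["loss_critic"]
if "loss_grasp_critic" in losses:  # only present when num_discrete_actions is set
    total_loss = total_loss + losses["loss_grasp_critic"]

optimizer.zero_grad()
total_loss.backward()
optimizer.step()
```

The actual loop may instead use separate optimizers for the critic ensemble and the grasp critic; the point is only that the returned dict now carries two losses when the discrete gripper is enabled.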