Enhance SACPolicy and learner server for improved grasp critic integration

- Updated SACPolicy to conditionally compute grasp critic losses based on the presence of discrete actions.
- Refactored the forward method to handle grasp critic model selection and loss computation more clearly.
- Adjusted learner server to utilize optimized parameters for grasp critic during training.
- Improved action handling in the ManiskillMockGripperWrapper to accommodate both tuple and single action inputs.
This commit is contained in:
AdilZouitine
2025-04-02 15:50:39 +00:00
committed by Michel Aractingi
parent f9fb9d4594
commit 6167886472
3 changed files with 72 additions and 50 deletions

View File

@@ -405,12 +405,13 @@ def add_actor_information_and_train(
optimizers["critic"].step()
# Grasp critic optimization (if available)
if "loss_grasp_critic" in critic_output:
loss_grasp_critic = critic_output["loss_grasp_critic"]
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
)
optimizers["grasp_critic"].step()
@@ -467,12 +468,13 @@ def add_actor_information_and_train(
}
# Grasp critic optimization (if available)
if "loss_grasp_critic" in critic_output:
loss_grasp_critic = critic_output["loss_grasp_critic"]
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
).item()
optimizers["grasp_critic"].step()
@@ -759,7 +761,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
if cfg.policy.num_discrete_actions is not None:
optimizer_grasp_critic = torch.optim.Adam(
params=policy.grasp_critic.parameters(), lr=policy.critic_lr
params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None

View File

@@ -16,7 +16,6 @@ from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.sac.modeling_sac import SACPolicy
def preprocess_maniskill_observation(
observations: dict[str, np.ndarray],
) -> dict[str, torch.Tensor]:
@@ -156,6 +155,7 @@ class TimeLimitWrapper(gym.Wrapper):
self.current_step = 0
return super().reset(seed=seed, options=options)
class ManiskillMockGripperWrapper(gym.Wrapper):
def __init__(self, env, nb_discrete_actions: int = 3):
super().__init__(env)
@@ -166,11 +166,16 @@ class ManiskillMockGripperWrapper(gym.Wrapper):
self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1]))
def step(self, action):
    """Step the wrapped env, accepting either a (agent, teleop) tuple or a bare agent action.

    Args:
        action: Either a tuple ``(action_agent, telop_action)`` or a single
            agent action array. When a bare action is given, the teleop
            component defaults to 0 (no-op).

    Returns:
        The usual Gym 5-tuple ``(obs, reward, terminated, truncated, info)``
        from the wrapped environment.
    """
    if isinstance(action, tuple):
        action_agent, telop_action = action
    else:
        # Bare action: no teleop command supplied, default to the no-op id.
        telop_action = 0
        action_agent = action
    # Drop the trailing (mock gripper) dimension appended by this wrapper;
    # assumes action_agent is indexable along its last axis — TODO confirm shape.
    real_action = action_agent[:-1]
    final_action = (real_action, telop_action)
    obs, reward, terminated, truncated, info = self.env.step(final_action)
    return obs, reward, terminated, truncated, info
def make_maniskill(
cfg: ManiskillEnvConfig,