Enhance SACPolicy and learner server for improved grasp critic integration
- Updated SACPolicy to conditionally compute grasp critic losses based on the presence of discrete actions. - Refactored the forward method to handle grasp critic model selection and loss computation more clearly. - Adjusted learner server to utilize optimized parameters for grasp critic during training. - Improved action handling in the ManiskillMockGripperWrapper to accommodate both tuple and single action inputs.
This commit is contained in:
committed by
Michel Aractingi
parent
f9fb9d4594
commit
6167886472
@@ -405,12 +405,13 @@ def add_actor_information_and_train(
|
||||
optimizers["critic"].step()
|
||||
|
||||
# Grasp critic optimization (if available)
|
||||
if "loss_grasp_critic" in critic_output:
|
||||
loss_grasp_critic = critic_output["loss_grasp_critic"]
|
||||
if policy.config.num_discrete_actions is not None:
|
||||
discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
|
||||
loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
|
||||
optimizers["grasp_critic"].zero_grad()
|
||||
loss_grasp_critic.backward()
|
||||
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
|
||||
)
|
||||
optimizers["grasp_critic"].step()
|
||||
|
||||
@@ -467,12 +468,13 @@ def add_actor_information_and_train(
|
||||
}
|
||||
|
||||
# Grasp critic optimization (if available)
|
||||
if "loss_grasp_critic" in critic_output:
|
||||
loss_grasp_critic = critic_output["loss_grasp_critic"]
|
||||
if policy.config.num_discrete_actions is not None:
|
||||
discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
|
||||
loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
|
||||
optimizers["grasp_critic"].zero_grad()
|
||||
loss_grasp_critic.backward()
|
||||
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["grasp_critic"].step()
|
||||
|
||||
@@ -759,7 +761,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
|
||||
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
optimizer_grasp_critic = torch.optim.Adam(
|
||||
params=policy.grasp_critic.parameters(), lr=policy.critic_lr
|
||||
params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
|
||||
lr_scheduler = None
|
||||
|
||||
@@ -16,7 +16,6 @@ from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
|
||||
|
||||
def preprocess_maniskill_observation(
|
||||
observations: dict[str, np.ndarray],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
@@ -156,6 +155,7 @@ class TimeLimitWrapper(gym.Wrapper):
|
||||
self.current_step = 0
|
||||
return super().reset(seed=seed, options=options)
|
||||
|
||||
|
||||
class ManiskillMockGripperWrapper(gym.Wrapper):
|
||||
def __init__(self, env, nb_discrete_actions: int = 3):
|
||||
super().__init__(env)
|
||||
@@ -166,11 +166,16 @@ class ManiskillMockGripperWrapper(gym.Wrapper):
|
||||
self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1]))
|
||||
|
||||
def step(self, action):
|
||||
action_agent, telop_action = action
|
||||
if isinstance(action, tuple):
|
||||
action_agent, telop_action = action
|
||||
else:
|
||||
telop_action = 0
|
||||
action_agent = action
|
||||
real_action = action_agent[:-1]
|
||||
final_action = (real_action, telop_action)
|
||||
obs, reward, terminated, truncated, info = self.env.step(final_action)
|
||||
return obs, reward, terminated, truncated, info
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
|
||||
def make_maniskill(
|
||||
cfg: ManiskillEnvConfig,
|
||||
|
||||
Reference in New Issue
Block a user