From 77a7f921390b09c80120933c0796dc3fea1f0dba Mon Sep 17 00:00:00 2001
From: KeWang1017
Date: Mon, 30 Dec 2024 18:46:25 +0000
Subject: [PATCH] 1. Add default input normalization modes in configuration_sac.py; 2. Mask padded steps in the loss computation

---
 .../common/policies/sac/configuration_sac.py | 10 ++++--
 lerobot/common/policies/sac/modeling_sac.py  | 36 ++++++++-----------
 .../configs/policy/sac_pusht_keypoints.yaml  |  6 ++--
 3 files changed, 26 insertions(+), 26 deletions(-)

diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 4ae6e5d42..159602e2c 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -28,12 +28,18 @@ class SACConfig:
     )
     output_shapes: dict[str, list[int]] = field(
         default_factory=lambda: {
-            "action": [4],
+            "action": [2],
         }
     )
 
     # Normalization / Unnormalization
-    input_normalization_modes: dict[str, str] | None = None
+    input_normalization_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "observation.image": "mean_std",
+            "observation.state": "min_max",
+            "observation.environment_state": "min_max",
+        }
+    )
     output_normalization_modes: dict[str, str] = field(
         default_factory=lambda: {"action": "min_max"},
     )
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index bd77408ec..71083b57c 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -177,25 +177,19 @@ class SACPolicy(
 
         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
-        critics_loss = F.mse_loss(
-            q_preds,  # shape: [num_critics, batch_size]
-            einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),  # expand td_target to match q_preds shape
-            reduction="none"
-        ).sum(0).mean()
-
-        # critics_loss = (
-        #     F.mse_loss(
-        #         q_preds,
-        #         einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
-        #         reduction="none",
-        #     ).sum(0)  # sum over ensemble
-        #     # `q_preds_ensemble` depends on the first observation and the actions.
-        #     * ~batch["observation.state_is_pad"][0]
-        #     * ~batch["action_is_pad"]
-        #     # q_targets depends on the reward and the next observations.
-        #     * ~batch["next.reward_is_pad"]
-        #     * ~batch["observation.state_is_pad"][1:]
-        # ).sum(0).mean()
+        critics_loss = (
+            F.mse_loss(
+                q_preds,
+                einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
+                reduction="none",
+            ).sum(0)  # sum over ensemble
+            # `q_preds` depends on the first observation and the actions.
+            * ~batch["observation.state_is_pad"][:, 0]  # state_is_pad: [batch_size, horizon+1]
+            * ~batch["action_is_pad"][:, 0]  # action_is_pad: [batch_size, horizon]
+            # td_target depends on the reward and the next observations.
+ * ~batch["next.reward_is_pad"][:,0] # shape: [batch_size, horizon] + * ~batch["observation.state_is_pad"][:,1] # shape: [batch_size, horizon+1] + ).mean() # calculate actors loss # 1- temperature @@ -207,8 +201,8 @@ class SACPolicy( q_preds = self.critic_forward(observations, actions, use_target=False) actor_loss = ( -(q_preds - temperature * log_probs).mean() - # * ~batch["observation.state_is_pad"][0] - # * ~batch["action_is_pad"] + * ~batch["observation.state_is_pad"][:,0] # shape: [batch_size, horizon+1] + * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon] ).mean() diff --git a/lerobot/configs/policy/sac_pusht_keypoints.yaml b/lerobot/configs/policy/sac_pusht_keypoints.yaml index 6d8971a24..dad1508bf 100644 --- a/lerobot/configs/policy/sac_pusht_keypoints.yaml +++ b/lerobot/configs/policy/sac_pusht_keypoints.yaml @@ -19,7 +19,7 @@ training: grad_clip_norm: 10.0 lr: 3e-4 - eval_freq: 50000 + eval_freq: 10000 log_freq: 500 save_freq: 50000 @@ -46,8 +46,8 @@ policy: # Input / output structure. n_action_repeats: 1 - horizon: 5 - n_action_steps: 5 + horizon: 2 + n_action_steps: 2 input_shapes: # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?