1. Add input normalization defaults in configuration_sac.py; 2. Add masking to the loss computation

Author: KeWang1017
Date: 2024-12-30 18:46:25 +00:00
Committed by: Ke-Wang1017
Parent: 35de91ef2b
Commit: 77a7f92139
3 changed files with 26 additions and 26 deletions

configuration_sac.py

@@ -28,12 +28,18 @@ class SACConfig:
     )
     output_shapes: dict[str, list[int]] = field(
         default_factory=lambda: {
-            "action": [4],
+            "action": [2],
         }
     )
     # Normalization / Unnormalization
-    input_normalization_modes: dict[str, str] | None = None
+    input_normalization_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "observation.image": "mean_std",
+            "observation.state": "min_max",
+            "observation.environment_state": "min_max",
+        }
+    )
     output_normalization_modes: dict[str, str] = field(
         default_factory=lambda: {"action": "min_max"},
     )
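
Note (not part of the diff): a minimal sketch of what these normalization-mode strings typically select, assuming per-key dataset statistics are available. The normalize helper, the stats layout, and the [-1, 1] target range for "min_max" are illustrative assumptions, not code from this repo.

    import torch

    def normalize(x: torch.Tensor, mode: str, stats: dict[str, torch.Tensor]) -> torch.Tensor:
        """Hypothetical helper mirroring the configured modes above."""
        if mode == "mean_std":
            # e.g. "observation.image": standardize with dataset mean/std.
            return (x - stats["mean"]) / (stats["std"] + 1e-8)
        if mode == "min_max":
            # e.g. "observation.state": rescale to [-1, 1] using dataset min/max.
            x01 = (x - stats["min"]) / (stats["max"] - stats["min"] + 1e-8)
            return x01 * 2.0 - 1.0
        raise ValueError(f"Unsupported normalization mode: {mode}")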


@@ -177,25 +177,19 @@ class SACPolicy(
         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
-        critics_loss = F.mse_loss(
-            q_preds,  # shape: [num_critics, batch_size]
-            einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),  # expand td_target to match q_preds shape
-            reduction="none"
-        ).sum(0).mean()
-        # critics_loss = (
-        #     F.mse_loss(
-        #         q_preds,
-        #         einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
-        #         reduction="none",
-        #     ).sum(0)  # sum over ensemble
-        #     # `q_preds_ensemble` depends on the first observation and the actions.
-        #     * ~batch["observation.state_is_pad"][0]
-        #     * ~batch["action_is_pad"]
-        #     # q_targets depends on the reward and the next observations.
-        #     * ~batch["next.reward_is_pad"]
-        #     * ~batch["observation.state_is_pad"][1:]
-        # ).sum(0).mean()
+        critics_loss = (
+            F.mse_loss(
+                q_preds,
+                einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
+                reduction="none",
+            ).sum(0)  # sum over ensemble
+            # `q_preds_ensemble` depends on the first observation and the actions.
+            * ~batch["observation.state_is_pad"][:, 0]  # shape: [batch_size, horizon+1]
+            * ~batch["action_is_pad"][:, 0]  # shape: [batch_size, horizon]
+            # q_targets depends on the reward and the next observations.
+            * ~batch["next.reward_is_pad"][:, 0]  # shape: [batch_size, horizon]
+            * ~batch["observation.state_is_pad"][:, 1]  # shape: [batch_size, horizon+1]
+        ).mean()
         # calculate actors loss
         # 1- temperature
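
Note (illustration, not part of the commit): the masked TD loss above as a self-contained toy example; tensor names and shapes are assumed from the shape comments in the diff.

    import einops
    import torch
    import torch.nn.functional as F

    num_critics, batch_size = 2, 4
    q_preds = torch.randn(num_critics, batch_size)  # one Q estimate per critic
    td_target = torch.randn(batch_size)             # shared TD target per sample

    # *_is_pad flags are True where a timestep was padded; ~flag keeps real data.
    state_is_pad = torch.zeros(batch_size, 3, dtype=torch.bool)   # [batch, horizon+1]
    action_is_pad = torch.zeros(batch_size, 2, dtype=torch.bool)  # [batch, horizon]
    reward_is_pad = torch.zeros(batch_size, 2, dtype=torch.bool)  # [batch, horizon]

    critics_loss = (
        F.mse_loss(
            q_preds,
            einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
            reduction="none",
        ).sum(0)               # sum squared errors over the ensemble -> [batch_size]
        * ~state_is_pad[:, 0]  # drop samples whose first observation is padding
        * ~action_is_pad[:, 0]
        * ~reward_is_pad[:, 0]
        * ~state_is_pad[:, 1]  # ... or whose next observation is padding
    ).mean()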
@@ -207,8 +201,8 @@ class SACPolicy(
         q_preds = self.critic_forward(observations, actions, use_target=False)
         actor_loss = (
             -(q_preds - temperature * log_probs).mean()
-            # * ~batch["observation.state_is_pad"][0]
-            # * ~batch["action_is_pad"]
+            * ~batch["observation.state_is_pad"][:, 0]  # shape: [batch_size, horizon+1]
+            * ~batch["action_is_pad"][:, 0]  # shape: [batch_size, horizon]
         ).mean()
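
Note: the inner .mean() above reduces the actor objective to a scalar before the pad masks multiply in, so the masks rescale the whole loss by the fraction of unpadded samples rather than excluding individual padded ones. If per-sample masking were the intent, a sketch (illustrative only, with stand-in values for the names used in the diff) would defer the reduction:

    import torch

    num_critics, batch_size = 2, 4
    q_preds = torch.randn(num_critics, batch_size)
    log_probs = torch.randn(batch_size)
    temperature = 0.1
    state_is_pad = torch.zeros(batch_size, 3, dtype=torch.bool)   # [batch, horizon+1]
    action_is_pad = torch.zeros(batch_size, 2, dtype=torch.bool)  # [batch, horizon]

    # Defer the reduction so each sample is masked before averaging:
    per_sample = -(q_preds - temperature * log_probs)  # [num_critics, batch_size]
    mask = ~state_is_pad[:, 0] & ~action_is_pad[:, 0]  # [batch_size]
    actor_loss = (per_sample * mask).mean()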

View File

@@ -19,7 +19,7 @@ training:
   grad_clip_norm: 10.0
   lr: 3e-4
-  eval_freq: 50000
+  eval_freq: 10000
   log_freq: 500
   save_freq: 50000
@@ -46,8 +46,8 @@ policy:
   # Input / output structure.
   n_action_repeats: 1
-  horizon: 5
-  n_action_steps: 5
+  horizon: 2
+  n_action_steps: 2
   input_shapes:
     # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
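
Note (assumed behavior, not code from this commit): in lerobot-style policies, horizon is the length of the predicted action chunk and n_action_steps is how many of those actions get executed before the policy is queried again, which is why the two values are kept in sync here. A rough sketch with a hypothetical predict_chunk stand-in:

    from collections import deque
    from typing import Callable

    import torch

    horizon, n_action_steps = 2, 2  # matching the config change above

    def make_action_source(predict_chunk: Callable[[], torch.Tensor]):
        # predict_chunk is a hypothetical stand-in returning a [horizon, action_dim] chunk.
        queue: deque[torch.Tensor] = deque(maxlen=n_action_steps)

        def select_action() -> torch.Tensor:
            if not queue:
                chunk = predict_chunk()
                queue.extend(chunk[:n_action_steps])  # execute only the first n_action_steps
            return queue.popleft()

        return select_action

    # usage: act = make_action_source(lambda: torch.zeros(horizon, 2))()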