forked from tangger/lerobot
1. Add input normalization in configuration_sac.py. 2. Add masking on the loss computation.
@@ -28,12 +28,18 @@ class SACConfig:
     )
     output_shapes: dict[str, list[int]] = field(
         default_factory=lambda: {
-            "action": [4],
+            "action": [2],
         }
     )

     # Normalization / Unnormalization
-    input_normalization_modes: dict[str, str] | None = None
+    input_normalization_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "observation.image": "mean_std",
+            "observation.state": "min_max",
+            "observation.environment_state": "min_max",
+        }
+    )
     output_normalization_modes: dict[str, str] = field(
         default_factory=lambda: {"action": "min_max"},
     )

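For reference, a minimal sketch of what the "mean_std" and "min_max" modes configured above typically compute. This is an illustration only, not lerobot's actual Normalize module; the `normalize` helper, the `stats` dict, and the example tensors are made up.

import torch

def normalize(x: torch.Tensor, mode: str, stats: dict[str, torch.Tensor]) -> torch.Tensor:
    """Apply one of the normalization modes named in *_normalization_modes (hypothetical helper)."""
    if mode == "mean_std":
        # Standardize to zero mean / unit variance.
        return (x - stats["mean"]) / (stats["std"] + 1e-8)
    if mode == "min_max":
        # Rescale to [-1, 1].
        return (x - stats["min"]) / (stats["max"] - stats["min"] + 1e-8) * 2.0 - 1.0
    raise ValueError(f"Unknown normalization mode: {mode}")

# Example: "observation.state" uses "min_max" per the config above.
state = torch.tensor([[0.0, 5.0], [10.0, 15.0]])
stats = {"min": state.min(dim=0).values, "max": state.max(dim=0).values}
print(normalize(state, "min_max", stats))  # values now lie in [-1, 1]
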
@@ -177,25 +177,19 @@ class SACPolicy(

         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
-        critics_loss = F.mse_loss(
-            q_preds,  # shape: [num_critics, batch_size]
-            einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),  # expand td_target to match q_preds shape
-            reduction="none"
-        ).sum(0).mean()
-
-        # critics_loss = (
-        #     F.mse_loss(
-        #         q_preds,
-        #         einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
-        #         reduction="none",
-        #     ).sum(0)  # sum over ensemble
-        #     # `q_preds_ensemble` depends on the first observation and the actions.
-        #     * ~batch["observation.state_is_pad"][0]
-        #     * ~batch["action_is_pad"]
-        #     # q_targets depends on the reward and the next observations.
-        #     * ~batch["next.reward_is_pad"]
-        #     * ~batch["observation.state_is_pad"][1:]
-        # ).sum(0).mean()
+        critics_loss = (
+            F.mse_loss(
+                q_preds,
+                einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
+                reduction="none",
+            ).sum(0)  # sum over ensemble
+            # `q_preds_ensemble` depends on the first observation and the actions.
+            * ~batch["observation.state_is_pad"][:, 0]  # shape: [batch_size, horizon+1]
+            * ~batch["action_is_pad"][:, 0]  # shape: [batch_size, horizon]
+            # q_targets depends on the reward and the next observations.
+            * ~batch["next.reward_is_pad"][:, 0]  # shape: [batch_size, horizon]
+            * ~batch["observation.state_is_pad"][:, 1]  # shape: [batch_size, horizon+1]
+        ).mean()

         # calculate actors loss
         # 1- temperature

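To make the masking change above concrete, here is a standalone toy version of the masked critic loss. The tensor shapes, values, and the pad mask are invented for the example; only the pattern mirrors the diff: elementwise MSE, sum over the critic ensemble, multiply by the inverted `*_is_pad` flags, then average.

import torch
import torch.nn.functional as F

num_critics, batch_size = 2, 4
q_preds = torch.randn(num_critics, batch_size)              # one row of Q estimates per critic
td_target = torch.randn(batch_size)                         # TD target shared across critics
action_is_pad = torch.tensor([False, False, True, False])   # True = padded (invalid) sample

per_sample = F.mse_loss(
    q_preds,
    td_target.expand(num_critics, batch_size),  # same effect as einops.repeat(td_target, "b -> e b", ...)
    reduction="none",
).sum(0)                                         # sum over the ensemble -> shape [batch_size]

# Multiplying by the inverted pad mask zeroes the contribution of padded samples before averaging.
critics_loss = (per_sample * ~action_is_pad).mean()
print(critics_loss)
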
@@ -207,8 +201,8 @@ class SACPolicy(
         q_preds = self.critic_forward(observations, actions, use_target=False)
         actor_loss = (
             -(q_preds - temperature * log_probs).mean()
-            # * ~batch["observation.state_is_pad"][0]
-            # * ~batch["action_is_pad"]
+            * ~batch["observation.state_is_pad"][:, 0]  # shape: [batch_size, horizon+1]
+            * ~batch["action_is_pad"][:, 0]  # shape: [batch_size, horizon]
         ).mean()

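Similarly, a toy version of the masked actor objective: maximize Q while keeping entropy high via the `temperature * log_probs` term, with padded samples masked out. All values below are made up, and unlike the diff the mask here is applied to the per-sample terms before the final average.

import torch

batch_size = 4
q_preds = torch.randn(batch_size)        # critic value of the sampled actions
log_probs = torch.randn(batch_size)      # log pi(a|s) of the sampled actions
temperature = 0.2                        # entropy coefficient (alpha)
state_is_pad = torch.tensor([False, True, False, False])

# SAC actor loss: minimize -(Q - alpha * log_prob), ignoring padded samples.
actor_loss = (-(q_preds - temperature * log_probs) * ~state_is_pad).mean()
print(actor_loss)
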
@@ -19,7 +19,7 @@ training:
   grad_clip_norm: 10.0
   lr: 3e-4

-  eval_freq: 50000
+  eval_freq: 10000
   log_freq: 500
   save_freq: 50000

@@ -46,8 +46,8 @@ policy:

   # Input / output structure.
   n_action_repeats: 1
-  horizon: 5
-  n_action_steps: 5
+  horizon: 2
+  n_action_steps: 2

   input_shapes:
     # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?