1, add input normalization in configuration_sac.py 2, add masking on loss computation
This commit is contained in:
@@ -28,12 +28,18 @@ class SACConfig:
|
|||||||
)
|
)
|
||||||
output_shapes: dict[str, list[int]] = field(
|
output_shapes: dict[str, list[int]] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"action": [4],
|
"action": [2],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Normalization / Unnormalization
|
# Normalization / Unnormalization
|
||||||
input_normalization_modes: dict[str, str] | None = None
|
input_normalization_modes: dict[str, str] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"observation.image": "mean_std",
|
||||||
|
"observation.state": "min_max",
|
||||||
|
"observation.environment_state": "min_max",
|
||||||
|
}
|
||||||
|
)
|
||||||
output_normalization_modes: dict[str, str] = field(
|
output_normalization_modes: dict[str, str] = field(
|
||||||
default_factory=lambda: {"action": "min_max"},
|
default_factory=lambda: {"action": "min_max"},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -177,25 +177,19 @@ class SACPolicy(
|
|||||||
|
|
||||||
# 4- Calculate loss
|
# 4- Calculate loss
|
||||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||||
critics_loss = F.mse_loss(
|
critics_loss = (
|
||||||
q_preds, # shape: [num_critics, batch_size]
|
F.mse_loss(
|
||||||
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), # expand td_target to match q_preds shape
|
q_preds,
|
||||||
reduction="none"
|
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
|
||||||
).sum(0).mean()
|
reduction="none",
|
||||||
|
).sum(0) # sum over ensemble
|
||||||
# critics_loss = (
|
# `q_preds_ensemble` depends on the first observation and the actions.
|
||||||
# F.mse_loss(
|
* ~batch["observation.state_is_pad"][:,0] # shape: [batch_size, horizon+1]
|
||||||
# q_preds,
|
* ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
|
||||||
# einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
|
# q_targets depends on the reward and the next observations.
|
||||||
# reduction="none",
|
* ~batch["next.reward_is_pad"][:,0] # shape: [batch_size, horizon]
|
||||||
# ).sum(0) # sum over ensemble
|
* ~batch["observation.state_is_pad"][:,1] # shape: [batch_size, horizon+1]
|
||||||
# # `q_preds_ensemble` depends on the first observation and the actions.
|
).mean()
|
||||||
# * ~batch["observation.state_is_pad"][0]
|
|
||||||
# * ~batch["action_is_pad"]
|
|
||||||
# # q_targets depends on the reward and the next observations.
|
|
||||||
# * ~batch["next.reward_is_pad"]
|
|
||||||
# * ~batch["observation.state_is_pad"][1:]
|
|
||||||
# ).sum(0).mean()
|
|
||||||
|
|
||||||
# calculate actors loss
|
# calculate actors loss
|
||||||
# 1- temperature
|
# 1- temperature
|
||||||
@@ -207,8 +201,8 @@ class SACPolicy(
|
|||||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||||
actor_loss = (
|
actor_loss = (
|
||||||
-(q_preds - temperature * log_probs).mean()
|
-(q_preds - temperature * log_probs).mean()
|
||||||
# * ~batch["observation.state_is_pad"][0]
|
* ~batch["observation.state_is_pad"][:,0] # shape: [batch_size, horizon+1]
|
||||||
# * ~batch["action_is_pad"]
|
* ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
|
||||||
).mean()
|
).mean()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ training:
|
|||||||
grad_clip_norm: 10.0
|
grad_clip_norm: 10.0
|
||||||
lr: 3e-4
|
lr: 3e-4
|
||||||
|
|
||||||
eval_freq: 50000
|
eval_freq: 10000
|
||||||
log_freq: 500
|
log_freq: 500
|
||||||
save_freq: 50000
|
save_freq: 50000
|
||||||
|
|
||||||
@@ -46,8 +46,8 @@ policy:
|
|||||||
|
|
||||||
# Input / output structure.
|
# Input / output structure.
|
||||||
n_action_repeats: 1
|
n_action_repeats: 1
|
||||||
horizon: 5
|
horizon: 2
|
||||||
n_action_steps: 5
|
n_action_steps: 2
|
||||||
|
|
||||||
input_shapes:
|
input_shapes:
|
||||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||||
|
|||||||
Reference in New Issue
Block a user