Refactor SACPolicy and configuration for improved training dynamics
- Introduced target critic networks in SACPolicy to enhance stability during training.
- Updated the TD target calculation to incorporate entropy adjustments, improving robustness.
- Increased the online buffer capacity in the configuration from 10,000 to 40,000 for better data handling.
- Adjusted the learning rates for the critic, actor, and temperature to 3e-4 for improved training performance.

These changes refine the SAC implementation and improve its robustness and performance during training and inference.
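With separate target critics in place, the usual SAC recipe keeps them frozen during the critic update and refreshes them with a soft (Polyak) update governed by `critic_target_update_weight` (0.005 in the config below). Here is a minimal sketch of such an update, assuming both ensembles expose standard `parameters()` iterators; the `soft_update_target` helper is illustrative and not a function from this repository:

```python
import torch

def soft_update_target(critic_ensemble: torch.nn.Module,
                       critic_target: torch.nn.Module,
                       tau: float = 0.005) -> None:
    """Polyak-average online critic weights into the target critics.

    target <- tau * online + (1 - tau) * target
    """
    with torch.no_grad():
        for online_p, target_p in zip(critic_ensemble.parameters(),
                                      critic_target.parameters()):
            target_p.mul_(1.0 - tau).add_(tau * online_p)
```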
@@ -77,8 +77,19 @@ class SACPolicy(
             )
             critic_nets.append(critic_net)

+        target_critic_nets = []
+        for _ in range(config.num_critics):
+            target_critic_net = Critic(
+                encoder=encoder_critic,
+                network=MLP(
+                    input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
+                    **config.critic_network_kwargs
+                )
+            )
+            target_critic_nets.append(target_critic_net)
+
         self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
-        self.critic_target = deepcopy(self.critic_ensemble)
+        self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)

         self.actor = Policy(
             encoder=encoder_actor,
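The diff does not show `create_critic_ensemble` itself. As a rough mental model only, an ensemble wrapper of this kind can be sketched as below, assuming each `Critic` maps an (observation, action) pair to a per-sample Q-value; the `CriticEnsemble` class here is hypothetical, not the repository's implementation:

```python
import torch
from torch import nn

class CriticEnsemble(nn.Module):
    """Illustrative ensemble wrapper: stacks Q-values from independent critics."""

    def __init__(self, critics: list[nn.Module]):
        super().__init__()
        self.critics = nn.ModuleList(critics)

    def forward(self, observations: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        # Assumes each critic returns a [batch_size] tensor of Q-values;
        # the stacked result has shape [num_critics, batch_size].
        return torch.stack([critic(observations, actions) for critic in self.critics], dim=0)
```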
@@ -169,12 +180,12 @@ class SACPolicy(

         # critics subsample size
         min_q, _ = q_targets.min(dim=0)  # Get values from min operation

         # compute td target
-        td_target = rewards + self.config.discount * min_q
-        # breakpoint()
-        if self.config.use_backup_entropy:
-            td_target -= self.config.discount * self.temperature() * log_probs \
-                * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0]  # shape: [batch_size, horizon]
+        min_q -= self.temperature() * log_probs * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0]  # shape: [batch_size, horizon]
+        td_target = rewards + self.config.discount * min_q * ~batch["next.done"]
+        # td_target -= self.config.discount * self.temperature() * log_probs \
+        #     * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0]  # shape: [batch_size, horizon]
+        # print(f"Target Q-values: mean={td_target.mean():.3f}, max={td_target.max():.3f}")

         # 3- compute predicted qs
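Pulling the updated target computation out of the diff, the idea is: take the minimum over the target-critic ensemble (clipped double-Q), subtract the temperature-weighted entropy term (zeroed on padded transitions), then discount and mask out terminal transitions. A self-contained sketch, with tensor shapes assumed from the in-line comments:

```python
import torch

def compute_td_target(rewards, q_targets, log_probs, done, state_is_pad, action_is_pad,
                      discount: float, temperature: float) -> torch.Tensor:
    """Illustrative TD target mirroring the updated computation above.

    q_targets: [num_critics, batch_size] Q-values from the target critic ensemble
    log_probs: [batch_size] log-probabilities of the sampled next actions
    done, state_is_pad, action_is_pad: boolean masks of shape [batch_size]
    """
    # Clipped double-Q: take the minimum over the target ensemble.
    min_q, _ = q_targets.min(dim=0)
    # Entropy bonus, zeroed out on padded transitions.
    min_q = min_q - temperature * log_probs * ~state_is_pad * ~action_is_pad
    # Bootstrap only on non-terminal transitions.
    return rewards + discount * min_q * ~done
```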
@@ -29,7 +29,7 @@ training:
   online_steps_between_rollouts: 1000
   online_sampling_ratio: 1.0
   online_env_seed: 10000
-  online_buffer_capacity: 10000
+  online_buffer_capacity: 40000
   online_buffer_seed_size: 0
   do_online_rollout_async: false
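The online buffer capacity is quadrupled from 10,000 to 40,000 transitions, so more off-policy experience is retained before old data is evicted. As a generic illustration only (not the repository's buffer class), a fixed-capacity FIFO buffer behaves like this:

```python
from collections import deque
import random

class FIFOReplayBuffer:
    """Illustrative fixed-capacity buffer; the oldest transitions are evicted first."""

    def __init__(self, capacity: int = 40_000):
        self.storage = deque(maxlen=capacity)

    def add(self, transition: dict) -> None:
        self.storage.append(transition)

    def sample(self, batch_size: int) -> list[dict]:
        return random.sample(list(self.storage), min(batch_size, len(self.storage)))
```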
@@ -70,9 +70,9 @@ policy:
   temperature_init: 1.0
   num_critics: 2
   num_subsample_critics: None
-  critic_lr: 1e-4
-  actor_lr: 1e-4
-  temperature_lr: 1e-4
+  critic_lr: 3e-4
+  actor_lr: 3e-4
+  temperature_lr: 3e-4
   critic_target_update_weight: 0.005
   utd_ratio: 2
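All three learning rates move from 1e-4 to 3e-4, the value commonly used in reference SAC implementations. Below is a sketch of how such values are typically wired into separate Adam optimizers for the critic ensemble, the actor, and the temperature parameter; the attribute names used here (`critic_ensemble`, `actor`, `log_alpha`) are assumptions, not the exact fields of `SACPolicy`:

```python
import torch

def make_optimizers(policy, critic_lr=3e-4, actor_lr=3e-4, temperature_lr=3e-4):
    # One optimizer per component so each can use its own learning rate.
    critic_optimizer = torch.optim.Adam(policy.critic_ensemble.parameters(), lr=critic_lr)
    actor_optimizer = torch.optim.Adam(policy.actor.parameters(), lr=actor_lr)
    # The temperature is a single learnable scalar (log of alpha in standard SAC).
    temperature_optimizer = torch.optim.Adam([policy.log_alpha], lr=temperature_lr)
    return critic_optimizer, actor_optimizer, temperature_optimizer
```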