Refactor SACPolicy and configuration for improved training dynamics

- Built the target critic networks in SACPolicy as a separately constructed ensemble (instead of a deepcopy of the online critics) to enhance stability during training.
- Updated the TD target calculation to fold the entropy term into the minimum Q-value and to mask bootstrapping on terminal transitions (~next.done), improving robustness.
- Increased online buffer capacity in configuration from 10,000 to 40,000 for better data handling.
- Raised the critic, actor, and temperature learning rates from 1e-4 to 3e-4 for faster optimization.

These changes aim to refine the SAC implementation, enhancing its robustness and performance during training and inference.
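
For reference, the updated target computes a pessimistic Q-value over the critic ensemble, subtracts the entropy term when backup entropy is enabled, and masks bootstrapping on terminal transitions. A minimal, runnable sketch of that arithmetic with dummy tensors standing in for the real batch (the actual code operates on the batch keys shown in the diff below):

import torch

# Dummy stand-ins: 2 critics, batch of 4 transitions.
q_targets = torch.randn(2, 4)                     # target-critic Q-values for sampled next actions
log_probs = torch.randn(4)                        # log pi(a'|s') of those actions
rewards = torch.randn(4)
done = torch.tensor([False, False, True, False])  # next.done
state_is_pad = torch.zeros(4, dtype=torch.bool)   # observation.state_is_pad[:, 0]
action_is_pad = torch.zeros(4, dtype=torch.bool)  # action_is_pad[:, 0]
temperature, discount, use_backup_entropy = 1.0, 0.99, True

min_q, _ = q_targets.min(dim=0)                   # pessimistic estimate across the ensemble
if use_backup_entropy:
    # entropy term is zeroed out on padded states/actions
    min_q = min_q - temperature * log_probs * ~state_is_pad * ~action_is_pad
# no bootstrapping past terminal transitions
td_target = rewards + discount * min_q * ~done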
Ke-Wang1017
2025-01-02 22:13:58 +00:00
parent eec28baa63
commit f99e670976
2 changed files with 21 additions and 10 deletions


@@ -77,8 +77,19 @@ class SACPolicy(
             )
             critic_nets.append(critic_net)
+        target_critic_nets = []
+        for _ in range(config.num_critics):
+            target_critic_net = Critic(
+                encoder=encoder_critic,
+                network=MLP(
+                    input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
+                    **config.critic_network_kwargs
+                )
+            )
+            target_critic_nets.append(target_critic_net)
         self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
-        self.critic_target = deepcopy(self.critic_ensemble)
+        self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)
         self.actor = Policy(
             encoder=encoder_actor,
@@ -169,12 +180,12 @@ class SACPolicy(
         # critics subsample size
         min_q, _ = q_targets.min(dim=0)  # Get values from min operation
         # compute td target
-        td_target = rewards + self.config.discount * min_q
-        # breakpoint()
         if self.config.use_backup_entropy:
-            td_target -= self.config.discount * self.temperature() * log_probs \
-                * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0]  # shape: [batch_size, horizon]
+            min_q -= self.temperature() * log_probs * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0]  # shape: [batch_size, horizon]
+        td_target = rewards + self.config.discount * min_q * ~batch["next.done"]
+        # td_target -= self.config.discount * self.temperature() * log_probs \
+        #     * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0]  # shape: [batch_size, horizon]
+        # print(f"Target Q-values: mean={td_target.mean():.3f}, max={td_target.max():.3f}")
         # 3- compute predicted qs
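
Since the target critics are now built as fresh modules rather than deep copies, their parameters start out different from the online critics and only track them once synchronized, typically through the Polyak update driven by critic_target_update_weight (0.005 in the config below). That update is not part of this diff; the following is only a sketch of the usual form, with the module names assumed from the hunk above:

import torch

def soft_update(target: torch.nn.Module, online: torch.nn.Module, tau: float) -> None:
    # theta_target <- tau * theta_online + (1 - tau) * theta_target
    with torch.no_grad():
        for target_param, online_param in zip(target.parameters(), online.parameters()):
            target_param.data.mul_(1.0 - tau).add_(online_param.data, alpha=tau)

# hypothetical call site:
# soft_update(self.critic_target, self.critic_ensemble, tau=config.critic_target_update_weight)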


@@ -29,7 +29,7 @@ training:
   online_steps_between_rollouts: 1000
   online_sampling_ratio: 1.0
   online_env_seed: 10000
-  online_buffer_capacity: 10000
+  online_buffer_capacity: 40000
   online_buffer_seed_size: 0
   do_online_rollout_async: false
@@ -70,9 +70,9 @@ policy:
   temperature_init: 1.0
   num_critics: 2
   num_subsample_critics: None
-  critic_lr: 1e-4
-  actor_lr: 1e-4
-  temperature_lr: 1e-4
+  critic_lr: 3e-4
+  actor_lr: 3e-4
+  temperature_lr: 3e-4
   critic_target_update_weight: 0.005
   utd_ratio: 2
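
For context, the three learning rates above typically drive separate optimizers for the critic ensemble, the actor, and the temperature parameter. The wiring lives in the training code rather than in this diff, so the sketch below uses stand-in modules, not the repository's actual classes:

import torch
from torch import nn

# Stand-ins for the real actor / critic / temperature parameters.
actor = nn.Linear(16, 4)
critic_ensemble = nn.ModuleList([nn.Linear(20, 1) for _ in range(2)])  # num_critics: 2
log_alpha = nn.Parameter(torch.zeros(1))                               # temperature_init: 1.0 -> exp(0) = 1

# One optimizer per component, using the 3e-4 rates from the config above.
critic_optimizer = torch.optim.Adam(critic_ensemble.parameters(), lr=3e-4)
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=3e-4)
temperature_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)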