fix log_alpha in modeling_sac: change to nn.Parameter

added pretrained vision model to the policy

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi
2025-02-13 11:26:24 +01:00
committed by AdilZouitine
parent 57e09828ce
commit a0e0a9a9b1
4 changed files with 7 additions and 8 deletions

View File

@@ -411,7 +411,7 @@ def add_actor_information_and_train(
next_observations = batch["next_state"]
done = batch["done"]
assert_and_breakpoint(observations=observations, actions=actions, next_state=next_observations)
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
with policy_lock:
loss_critic = policy.compute_loss_critic(
@@ -533,7 +533,6 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
)
# We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
lr_scheduler = None
optimizers = {