Fix grasp critic softmax temperature (1.0 -> 0.15)

This commit is contained in:
Michel Aractingi
2025-04-08 13:35:25 +02:00
parent 10adadbc71
commit e36bee7560

View File

@@ -129,7 +129,7 @@ class SACPolicy(
encoder=encoder_critic,
input_dim=encoder_critic.output_dim,
output_dim=config.num_discrete_actions,
-softmax_temperature=1.0,
+softmax_temperature=.15,
**asdict(config.grasp_critic_network_kwargs),
)
@@ -138,7 +138,7 @@ class SACPolicy(
encoder=encoder_critic,
input_dim=encoder_critic.output_dim,
output_dim=config.num_discrete_actions,
-softmax_temperature=1.0,
+softmax_temperature=0.15,
**asdict(config.grasp_critic_network_kwargs),
)
@@ -786,6 +786,7 @@ class GraspCritic(nn.Module):
super().__init__()
self.encoder = encoder
self.output_dim = output_dim
+self.softmax_temperature = softmax_temperature
self.net = MLP(
input_dim=input_dim,