committed by
Michel Aractingi
parent
c37936f2c9
commit
3424644ecd
@@ -27,7 +27,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.distributions import MultivariateNormal, TransformedDistribution, TanhTransform, Transform
|
||||
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
@@ -156,7 +156,9 @@ class SACPolicy(
|
||||
**asdict(config.policy_kwargs),
|
||||
)
|
||||
if config.target_entropy is None:
|
||||
discrete_actions_dim: Literal[1] | Literal[0] = 1 if config.num_discrete_actions is None else 0
|
||||
discrete_actions_dim: Literal[1] | Literal[0] = (
|
||||
1 if config.num_discrete_actions is not None else 0
|
||||
)
|
||||
config.target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2 # (-dim(A)/2)
|
||||
|
||||
# TODO (azouitine): Handle the case where the temparameter is a fixed
|
||||
|
||||
Reference in New Issue
Block a user