committed by
Adil Zouitine
parent
23c9441d5f
commit
dc1548fe1a
@@ -27,7 +27,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.distributions import MultivariateNormal, TransformedDistribution, TanhTransform, Transform
|
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||||
|
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
@@ -156,7 +156,9 @@ class SACPolicy(
|
|||||||
**asdict(config.policy_kwargs),
|
**asdict(config.policy_kwargs),
|
||||||
)
|
)
|
||||||
if config.target_entropy is None:
|
if config.target_entropy is None:
|
||||||
discrete_actions_dim: Literal[1] | Literal[0] = 1 if config.num_discrete_actions is None else 0
|
discrete_actions_dim: Literal[1] | Literal[0] = (
|
||||||
|
1 if config.num_discrete_actions is not None else 0
|
||||||
|
)
|
||||||
config.target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2 # (-dim(A)/2)
|
config.target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2 # (-dim(A)/2)
|
||||||
|
|
||||||
# TODO (azouitine): Handle the case where the temparameter is a fixed
|
# TODO (azouitine): Handle the case where the temparameter is a fixed
|
||||||
|
|||||||
Reference in New Issue
Block a user