[PORT HIL-SERL] Better unit tests coverage for SAC policy (#1074)

This commit is contained in:
Eugene Mironov
2025-05-14 21:41:36 +07:00
committed by AdilZouitine
parent f8a963b86f
commit 5902f8fcc7
4 changed files with 492 additions and 25 deletions

View File

@@ -140,7 +140,6 @@ class SACConfig(PreTrainedConfig):
)
# Architecture specifics
camera_number: int = 1
device: str = "cpu"
storage_device: str = "cpu"
# Set to "helper2424/resnet10" for hil serl
@@ -184,6 +183,9 @@ class SACConfig(PreTrainedConfig):
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
# Optimizations
use_torch_compile: bool = True
def __post_init__(self):
super().__post_init__()
# Any validation specific to SAC configuration

View File

@@ -79,8 +79,9 @@ class SACPolicy(
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation"""
observations_features = None
if self.shared_encoder:
if self.shared_encoder and self.actor.encoder.has_images:
# Cache and normalize image features
observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)
@@ -365,7 +366,7 @@ class SACPolicy(
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations, observation_features)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
return temperature_loss
def compute_loss_actor(
@@ -393,6 +394,7 @@ class SACPolicy(
self.normalize_inputs = nn.Identity()
self.normalize_targets = nn.Identity()
self.unnormalize_outputs = nn.Identity()
if self.config.dataset_stats:
params = _convert_normalization_params_to_tensor(self.config.dataset_stats)
self.normalize_inputs = Normalize(
@@ -440,8 +442,9 @@ class SACPolicy(
)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
if self.config.use_torch_compile:
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
if self.config.num_discrete_actions is not None:
self._init_discrete_critics()
@@ -473,9 +476,11 @@ class SACPolicy(
encoder_is_shared=self.shared_encoder,
**asdict(self.config.policy_kwargs),
)
if self.config.target_entropy is None:
self.target_entropy = self.config.target_entropy
if self.target_entropy is None:
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
self.config.target_entropy = -np.prod(dim) / 2
self.target_entropy = -np.prod(dim) / 2
def _init_temperature(self):
"""Set up temperature parameter and initial log_alpha."""
@@ -997,14 +1002,6 @@ def orthogonal_init():
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
class Identity(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
class SpatialLearnedEmbeddings(nn.Module):
def __init__(self, height, width, channel, num_features=8):
"""