diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 2fabf5a8..e851afdc 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -140,7 +140,6 @@ class SACConfig(PreTrainedConfig):
     )

     # Architecture specifics
-    camera_number: int = 1
     device: str = "cpu"
     storage_device: str = "cpu"
     # Set to "helper2424/resnet10" for hil serl
@@ -184,6 +183,9 @@ class SACConfig(PreTrainedConfig):
     actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
     concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)

+    # Optimizations
+    use_torch_compile: bool = True
+
     def __post_init__(self):
         super().__post_init__()
         # Any validation specific to SAC configuration
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 4749b703..04145e12 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -79,8 +79,9 @@ class SACPolicy(
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
+        observations_features = None

-        if self.shared_encoder:
+        if self.shared_encoder and self.actor.encoder.has_images:
             # Cache and normalize image features
             observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)
@@ -365,7 +366,7 @@ class SACPolicy(
         # calculate temperature loss
         with torch.no_grad():
             _, log_probs, _ = self.actor(observations, observation_features)
-        temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
+        temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
         return temperature_loss

     def compute_loss_actor(
@@ -393,6 +394,7 @@ class SACPolicy(
         self.normalize_inputs = nn.Identity()
         self.normalize_targets = nn.Identity()
         self.unnormalize_outputs = nn.Identity()
+
         if self.config.dataset_stats:
             params = _convert_normalization_params_to_tensor(self.config.dataset_stats)
             self.normalize_inputs = Normalize(
@@ -440,8 +442,9 @@ class SACPolicy(
         )
         self.critic_target.load_state_dict(self.critic_ensemble.state_dict())

-        self.critic_ensemble = torch.compile(self.critic_ensemble)
-        self.critic_target = torch.compile(self.critic_target)
+        if self.config.use_torch_compile:
+            self.critic_ensemble = torch.compile(self.critic_ensemble)
+            self.critic_target = torch.compile(self.critic_target)

         if self.config.num_discrete_actions is not None:
             self._init_discrete_critics()
@@ -473,9 +476,11 @@ class SACPolicy(
             encoder_is_shared=self.shared_encoder,
             **asdict(self.config.policy_kwargs),
         )
-        if self.config.target_entropy is None:
+
+        self.target_entropy = self.config.target_entropy
+        if self.target_entropy is None:
             dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
-            self.config.target_entropy = -np.prod(dim) / 2
+            self.target_entropy = -np.prod(dim) / 2
@@ -997,14 +1002,6 @@ def orthogonal_init():
     return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)


-class Identity(nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x):
-        return x
-
-
 class SpatialLearnedEmbeddings(nn.Module):
     def __init__(self, height, width, channel, num_features=8):
         """
diff --git a/tests/policies/test_sac_config.py b/tests/policies/test_sac_config.py
index 3ee21672..98333e9f 100644
--- a/tests/policies/test_sac_config.py
+++ b/tests/policies/test_sac_config.py
@@ -24,7 +24,7 @@ from lerobot.common.policies.sac.configuration_sac import (
     PolicyConfig,
     SACConfig,
 )
-from lerobot.configs.types import NormalizationMode
+from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature


 def test_sac_config_default_initialization():
@@ -192,16 +192,16 @@ def test_sac_config_custom_initialization():
 def test_validate_features():
     config = SACConfig(
-        input_features={"observation.state": {"shape": (10,), "type": "float32"}},
-        output_features={"action": {"shape": (3,), "type": "float32"}},
+        input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
+        output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
     )
     config.validate_features()


 def test_validate_features_missing_observation():
     config = SACConfig(
-        input_features={"wrong_key": {"shape": (10,), "type": "float32"}},
-        output_features={"action": {"shape": (3,), "type": "float32"}},
+        input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
+        output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
     )
     with pytest.raises(
         ValueError, match="You must provide either 'observation.state' or an image observation"
@@ -211,8 +211,8 @@ def test_validate_features_missing_action():
     config = SACConfig(
-        input_features={"observation.state": {"shape": (10,), "type": "float32"}},
-        output_features={"wrong_key": {"shape": (3,), "type": "float32"}},
+        input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
+        output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
     )
     with pytest.raises(ValueError, match="You must provide 'action' in the output features"):
         config.validate_features()
diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py
index cfda877a..18e3b6f2 100644
--- a/tests/policies/test_sac_policy.py
+++ b/tests/policies/test_sac_policy.py
@@ -1,7 +1,13 @@
-import torch
-from torch import nn
+import math

-from lerobot.common.policies.sac.modeling_sac import MLP
+import pytest
+import torch
+from torch import Tensor, nn
+
+from lerobot.common.policies.sac.configuration_sac import SACConfig
+from lerobot.common.policies.sac.modeling_sac import MLP, SACPolicy
+from lerobot.common.utils.random_utils import seeded_context
+from lerobot.configs.types import FeatureType, PolicyFeature


 def test_mlp_with_default_args():
@@ -41,3 +47,465 @@ def test_mlp_with_custom_final_activation():
     y = mlp(x)
     assert y.shape == (1, 256)
     assert (y >= -1).all() and (y <= 1).all()
+
+
+def test_sac_policy_with_default_args():
+    with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"):
+        SACPolicy()
+
+
+def create_dummy_state(batch_size: int, state_dim: int = 10) -> dict[str, Tensor]:
+    return {
+        "observation.state": torch.randn(batch_size, state_dim),
+    }
+
+
+def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> dict[str, Tensor]:
+    return {
+        "observation.image": torch.randn(batch_size, 3, 84, 84),
+        "observation.state": torch.randn(batch_size, state_dim),
+    }
+
+
+def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor:
+    return torch.randn(batch_size, action_dim)
+
+
+def create_default_train_batch(
+    batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
+) -> dict[str, Tensor]:
+    return {
+        "action": create_dummy_action(batch_size, action_dim),
+        "reward": torch.randn(batch_size),
+        "state": create_dummy_state(batch_size, state_dim),
+        "next_state": create_dummy_state(batch_size, state_dim),
+        "done": torch.randn(batch_size),
+    }
+
+
+def create_train_batch_with_visual_input(
+    batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
+) -> dict[str, Tensor]:
+    return {
+        "action": create_dummy_action(batch_size, action_dim),
+        "reward": torch.randn(batch_size),
+        "state": create_dummy_with_visual_input(batch_size, state_dim),
+        "next_state": create_dummy_with_visual_input(batch_size, state_dim),
+        "done": torch.randn(batch_size),
+    }
+
+
+def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
+    return {
+        "observation.state": torch.randn(batch_size, state_dim),
+    }
+
+
+def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
+    return {
+        "observation.state": torch.randn(batch_size, state_dim),
+        "observation.image": torch.randn(batch_size, 3, 84, 84),
+    }
+
+
+def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]:
+    """Create optimizers for the SAC policy."""
+    optimizer_actor = torch.optim.Adam(
+        # Handle the case of a shared encoder, where the encoder weights are not optimized with the actor gradient
+        params=[
+            p
+            for n, p in policy.actor.named_parameters()
+            if not policy.config.shared_encoder or not n.startswith("encoder")
+        ],
+        lr=policy.config.actor_lr,
+    )
+    optimizer_critic = torch.optim.Adam(
+        params=policy.critic_ensemble.parameters(),
+        lr=policy.config.critic_lr,
+    )
+    optimizer_temperature = torch.optim.Adam(
+        params=[policy.log_alpha],
+        lr=policy.config.critic_lr,
+    )
+
+    optimizers = {
+        "actor": optimizer_actor,
+        "critic": optimizer_critic,
+        "temperature": optimizer_temperature,
+    }
+
+    if has_discrete_action:
+        optimizers["discrete_critic"] = torch.optim.Adam(
+            params=policy.discrete_critic.parameters(),
+            lr=policy.config.critic_lr,
+        )
+
+    return optimizers
+
+
+def create_default_config(
+    state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
+) -> SACConfig:
+    action_dim = continuous_action_dim
+    if has_discrete_action:
+        action_dim += 1
+
+    config = SACConfig(
+        input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
+        output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
+        dataset_stats={
+            "observation.state": {
+                "min": [0.0] * state_dim,
+                "max": [1.0] * state_dim,
+            },
+            "action": {
+                "min": [0.0] * continuous_action_dim,
+                "max": [1.0] * continuous_action_dim,
+            },
+        },
+    )
+    config.validate_features()
+    return config
+
+
+def create_config_with_visual_input(
+    state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
+) -> SACConfig:
+    config = create_default_config(
+        state_dim=state_dim,
+        continuous_action_dim=continuous_action_dim,
+        has_discrete_action=has_discrete_action,
+    )
+    config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
+    config.dataset_stats["observation.image"] = {
+        "mean": torch.randn(3, 1, 1),
+        "std": torch.randn(3, 1, 1),
+    }
+
+    # Make the tests a little bit faster
+    config.state_encoder_hidden_dim = 32
+    config.latent_dim = 32
+
+    config.validate_features()
+    return config
+
+
+@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
+def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int):
+    batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
+    config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
+
+    policy = SACPolicy(config=config)
+    policy.train()
+
+    optimizers = make_optimizers(policy)
+
+    critic_loss = policy.forward(batch, model="critic")["loss_critic"]
+    assert critic_loss.item() is not None
+    assert critic_loss.shape == ()
+    critic_loss.backward()
+    optimizers["critic"].step()
+
+    actor_loss = policy.forward(batch, model="actor")["loss_actor"]
+    assert actor_loss.item() is not None
+    assert actor_loss.shape == ()
+
+    actor_loss.backward()
+    optimizers["actor"].step()
+
+    temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
+    assert temperature_loss.item() is not None
+    assert temperature_loss.shape == ()
+
+    temperature_loss.backward()
+    optimizers["temperature"].step()
+
+    policy.eval()
+    with torch.no_grad():
+        observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
+        selected_action = policy.select_action(observation_batch)
+        assert selected_action.shape == (batch_size, action_dim)
+
+
+@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
+def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
+    config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
+    policy = SACPolicy(config=config)
+
+    batch = create_train_batch_with_visual_input(
+        batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
+    )
+
+    policy.train()
+
+    optimizers = make_optimizers(policy)
+
+    critic_loss = policy.forward(batch, model="critic")["loss_critic"]
+    assert critic_loss.item() is not None
+    assert critic_loss.shape == ()
+    critic_loss.backward()
+    optimizers["critic"].step()
+
+    actor_loss = policy.forward(batch, model="actor")["loss_actor"]
+    assert actor_loss.item() is not None
+    assert actor_loss.shape == ()
+
+    actor_loss.backward()
+    optimizers["actor"].step()
+
+    temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
+    assert temperature_loss.item() is not None
+    assert temperature_loss.shape == ()
+
+    temperature_loss.backward()
+    optimizers["temperature"].step()
+
+    policy.eval()
+    with torch.no_grad():
+        observation_batch = create_observation_batch_with_visual_input(
+            batch_size=batch_size, state_dim=state_dim
+        )
+        selected_action = policy.select_action(observation_batch)
+        assert selected_action.shape == (batch_size, action_dim)
+
+
+# Check the best candidates for pretrained encoders
+@pytest.mark.parametrize(
+    "batch_size,state_dim,action_dim,vision_encoder_name",
+    [(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
+)
+def test_sac_policy_with_pretrained_encoder(
+    batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
+):
+    config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
+    config.vision_encoder_name = vision_encoder_name
+    policy = SACPolicy(config=config)
+    policy.train()
+
+    batch = create_train_batch_with_visual_input(
+        batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
+    )
+
+    optimizers = make_optimizers(policy)
+
+    critic_loss = policy.forward(batch, model="critic")["loss_critic"]
+    assert critic_loss.item() is not None
+    assert critic_loss.shape == ()
+    critic_loss.backward()
+    optimizers["critic"].step()
+
+    actor_loss = policy.forward(batch, model="actor")["loss_actor"]
+    assert actor_loss.item() is not None
+    assert actor_loss.shape == ()
+
+
+def test_sac_policy_with_shared_encoder():
+    batch_size = 2
+    action_dim = 10
+    state_dim = 10
+    config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
+    config.shared_encoder = True
+
+    policy = SACPolicy(config=config)
+    policy.train()
+
+    batch = create_train_batch_with_visual_input(
+        batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
+    )
+
+    policy.train()
+
+    optimizers = make_optimizers(policy)
+
+    critic_loss = policy.forward(batch, model="critic")["loss_critic"]
+    assert critic_loss.item() is not None
+    assert critic_loss.shape == ()
+    critic_loss.backward()
+    optimizers["critic"].step()
+
+    actor_loss = policy.forward(batch, model="actor")["loss_actor"]
+    assert actor_loss.item() is not None
+    assert actor_loss.shape == ()
+
+    actor_loss.backward()
+    optimizers["actor"].step()
+
+
+def test_sac_policy_with_discrete_critic():
+    batch_size = 2
+    continuous_action_dim = 9
+    full_action_dim = continuous_action_dim + 1  # the last action is discrete
+    state_dim = 10
+    config = create_config_with_visual_input(
+        state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True
+    )
+
+    num_discrete_actions = 5
+    config.num_discrete_actions = num_discrete_actions
+
+    policy = SACPolicy(config=config)
+    policy.train()
+
+    batch = create_train_batch_with_visual_input(
+        batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim
+    )
+
+    policy.train()
+
+    optimizers = make_optimizers(policy, has_discrete_action=True)
+
+    critic_loss = policy.forward(batch, model="critic")["loss_critic"]
+    assert critic_loss.item() is not None
+    assert critic_loss.shape == ()
+    critic_loss.backward()
+    optimizers["critic"].step()
+
+    discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"]
+    assert discrete_critic_loss.item() is not None
+    assert discrete_critic_loss.shape == ()
+    discrete_critic_loss.backward()
+    optimizers["discrete_critic"].step()
+
+    actor_loss = policy.forward(batch, model="actor")["loss_actor"]
+    assert actor_loss.item() is not None
+    assert actor_loss.shape == ()
+
+    actor_loss.backward()
+    optimizers["actor"].step()
+
+    policy.eval()
+    with torch.no_grad():
+        observation_batch = create_observation_batch_with_visual_input(
+            batch_size=batch_size, state_dim=state_dim
+        )
+        selected_action = policy.select_action(observation_batch)
+        assert selected_action.shape == (batch_size, full_action_dim)
+
+        discrete_actions = selected_action[:, -1].long()
+        discrete_action_values = set(discrete_actions.tolist())
+
+        assert all(action in range(num_discrete_actions) for action in discrete_action_values), (
+            f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})"
+        )
+
+
+def test_sac_policy_with_default_entropy():
+    config = create_default_config(continuous_action_dim=10, state_dim=10)
+    policy = SACPolicy(config=config)
+    assert policy.target_entropy == -5.0
+
+
+def test_sac_policy_default_target_entropy_with_discrete_action():
+    config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
+    policy = SACPolicy(config=config)
+    assert policy.target_entropy == -3.0
+
+
+def test_sac_policy_with_predefined_entropy():
+    config = create_default_config(state_dim=10, continuous_action_dim=6)
+    config.target_entropy = -3.5
+
+    policy = SACPolicy(config=config)
+    assert policy.target_entropy == pytest.approx(-3.5)
+
+
+def test_sac_policy_update_temperature():
+    config = create_default_config(continuous_action_dim=10, state_dim=10)
+    policy = SACPolicy(config=config)
+
+    assert policy.temperature == pytest.approx(1.0)
+    policy.log_alpha.data = torch.tensor([math.log(0.1)])
+    policy.update_temperature()
+    assert policy.temperature == pytest.approx(0.1)
+
+
+def test_sac_policy_update_target_network():
+    config = create_default_config(state_dim=10, continuous_action_dim=6)
+    config.critic_target_update_weight = 1.0
+
+    policy = SACPolicy(config=config)
+    policy.train()
+
+    for p in policy.critic_ensemble.parameters():
+        p.data = torch.ones_like(p.data)
+
+    policy.update_target_networks()
+    for p in policy.critic_target.parameters():
+        assert torch.allclose(p.data, torch.ones_like(p.data)), (
+            f"Target network {p.data} is not equal to {torch.ones_like(p.data)}"
+        )
+
+
+@pytest.mark.parametrize("num_critics", [1, 3])
+def test_sac_policy_with_critics_number_of_heads(num_critics: int):
+    batch_size = 2
+    action_dim = 10
+    state_dim = 10
+    config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
+    config.num_critics = num_critics
+
+    policy = SACPolicy(config=config)
+    policy.train()
+
+    assert len(policy.critic_ensemble.critics) == num_critics
+
+    batch = create_train_batch_with_visual_input(
+        batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
+    )
+
+    policy.train()
+
+    optimizers = make_optimizers(policy)
+
+    critic_loss = policy.forward(batch, model="critic")["loss_critic"]
+    assert critic_loss.item() is not None
+    assert critic_loss.shape == ()
+    critic_loss.backward()
+    optimizers["critic"].step()
+
+
+def test_sac_policy_save_and_load(tmp_path):
+    root = tmp_path / "test_sac_save_and_load"
+
+    state_dim = 10
+    action_dim = 10
+    batch_size = 2
+
+    config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
+    policy = SACPolicy(config=config)
+    policy.eval()
+    policy.save_pretrained(root)
+    loaded_policy = SACPolicy.from_pretrained(root, config=config)
+    loaded_policy.eval()
+
+    batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10)
+
+    with torch.no_grad():
+        with seeded_context(12):
+            # Collect policy values before saving
+            critic_loss = policy.forward(batch, model="critic")["loss_critic"]
+            actor_loss = policy.forward(batch, model="actor")["loss_actor"]
+            temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
+
+            observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
+            actions = policy.select_action(observation_batch)
+
+        with seeded_context(12):
+            # Collect policy values after loading
+            loaded_critic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"]
+            loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"]
+            loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"]
+
+            loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
+            loaded_actions = loaded_policy.select_action(loaded_observation_batch)
+
+    assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
+    for k in policy.state_dict():
+        assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
+
+    # Compare values before and after saving and loading
+    # They should be the same
+    assert torch.allclose(critic_loss, loaded_critic_loss)
+    assert torch.allclose(actor_loss, loaded_actor_loss)
+    assert torch.allclose(temperature_loss, loaded_temperature_loss)
+    assert torch.allclose(actions, loaded_actions)
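
Usage sketch (not part of the patch): the snippet below mirrors the test helpers above and only uses names that appear in this diff (SACConfig, SACPolicy, PolicyFeature, FeatureType, validate_features, use_torch_compile, select_action); the dimensions and batch size are illustrative.

    import torch

    from lerobot.common.policies.sac.configuration_sac import SACConfig
    from lerobot.common.policies.sac.modeling_sac import SACPolicy
    from lerobot.configs.types import FeatureType, PolicyFeature

    # Features are declared as PolicyFeature objects rather than plain dicts.
    config = SACConfig(
        input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
        output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
        dataset_stats={
            "observation.state": {"min": [0.0] * 10, "max": [1.0] * 10},
            "action": {"min": [0.0] * 3, "max": [1.0] * 3},
        },
        use_torch_compile=False,  # new flag: skip torch.compile of the critic networks
    )
    config.validate_features()

    policy = SACPolicy(config=config)
    # target_entropy is now resolved on the policy (default: -action_dim / 2).
    assert policy.target_entropy == -1.5

    policy.eval()
    with torch.no_grad():
        action = policy.select_action({"observation.state": torch.randn(2, 10)})
    assert action.shape == (2, 3)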