Compare commits
38 Commits
user/miche
...
user/miche
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a22fe8a6de | ||
|
|
49b5f379a7 | ||
|
|
7a3d8756b4 | ||
|
|
dc1548fe1a | ||
|
|
23c9441d5f | ||
|
|
870e3efb92 | ||
|
|
bfd48a8b70 | ||
|
|
5dc7ff6d3c | ||
|
|
ee4ebeac9b | ||
|
|
fe7b47f459 | ||
|
|
044ca3b039 | ||
|
|
bc36c69b71 | ||
|
|
2b9b05f1ba | ||
|
|
9eec7b8bb0 | ||
|
|
a80a9cf379 | ||
|
|
7a42af835e | ||
|
|
9751328783 | ||
|
|
7225bc74a3 | ||
|
|
03b1644bf7 | ||
|
|
9b6e5a383f | ||
|
|
86466b025f | ||
|
|
54745f111d | ||
|
|
82584cca78 | ||
|
|
d3a8c2c247 | ||
|
|
74c11c4a75 | ||
|
|
2d932b710c | ||
|
|
a54baceabb | ||
|
|
077d18b439 | ||
|
|
c6cd1475a7 | ||
|
|
e35ee47b07 | ||
|
|
c3f2487026 | ||
|
|
c621077b62 | ||
|
|
f5cfd9fd48 | ||
|
|
22da1739b1 | ||
|
|
d38d5f988d | ||
|
|
8d1936ffe0 | ||
|
|
cef944e1b1 | ||
|
|
384eb2cd07 |
@@ -201,8 +201,8 @@ class EnvWrapperConfig:
|
||||
joint_masking_action_space: Optional[Any] = None
|
||||
ee_action_space_params: Optional[EEActionSpaceConfig] = None
|
||||
use_gripper: bool = False
|
||||
gripper_quantization_threshold: float | None = None
|
||||
gripper_penalty: float = 0.0
|
||||
gripper_quantization_threshold: float | None = 0.8
|
||||
gripper_penalty: float = 0.0
|
||||
gripper_penalty_in_reward: bool = False
|
||||
open_gripper_on_reset: bool = False
|
||||
|
||||
|
||||
@@ -51,8 +51,8 @@ class ActorNetworkConfig:
|
||||
@dataclass
|
||||
class PolicyConfig:
|
||||
use_tanh_squash: bool = True
|
||||
log_std_min: int = -5
|
||||
log_std_max: int = 2
|
||||
log_std_min: float = 1e-5
|
||||
log_std_max: float = 10.0
|
||||
init_final: float = 0.05
|
||||
|
||||
|
||||
@@ -86,6 +86,7 @@ class SACConfig(PreTrainedConfig):
|
||||
image_encoder_hidden_dim: Hidden dimension size for the image encoder.
|
||||
shared_encoder: Whether to use a shared encoder for actor and critic.
|
||||
num_discrete_actions: Number of discrete actions, e.g. for gripper actions.
|
||||
image_embedding_pooling_dim: Dimension of the image embedding pooling.
|
||||
concurrency: Configuration for concurrency settings.
|
||||
actor_learner: Configuration for actor-learner architecture.
|
||||
online_steps: Number of steps for online training.
|
||||
@@ -120,7 +121,7 @@ class SACConfig(PreTrainedConfig):
|
||||
}
|
||||
)
|
||||
|
||||
dataset_stats: dict[str, dict[str, list[float]]] = field(
|
||||
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
@@ -147,6 +148,7 @@ class SACConfig(PreTrainedConfig):
|
||||
image_encoder_hidden_dim: int = 32
|
||||
shared_encoder: bool = True
|
||||
num_discrete_actions: int | None = None
|
||||
image_embedding_pooling_dim: int = 8
|
||||
|
||||
# Training parameter
|
||||
online_steps: int = 1000000
|
||||
|
||||
@@ -22,13 +22,12 @@ from dataclasses import asdict
|
||||
from typing import Callable, List, Literal, Optional, Tuple
|
||||
|
||||
import einops
|
||||
from importlib_metadata import distribution
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.distributions import Categorical
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
@@ -53,153 +52,46 @@ class SACPolicy(
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
# Determine action dimension and initialize all components
|
||||
continuous_action_dim = config.output_features["action"].shape[0]
|
||||
|
||||
if config.dataset_stats is not None:
|
||||
input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_features,
|
||||
config.normalization_mapping,
|
||||
input_normalization_params,
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
|
||||
output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
|
||||
|
||||
# HACK: This is hacky and should be removed
|
||||
dataset_stats = dataset_stats or output_normalization_params
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# NOTE: For images the encoder should be shared between the actor and critic
|
||||
if config.shared_encoder:
|
||||
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
|
||||
encoder_actor: SACObservationEncoder = encoder_critic
|
||||
else:
|
||||
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
|
||||
encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
|
||||
self.shared_encoder = config.shared_encoder
|
||||
|
||||
# Create a list of critic heads
|
||||
critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
|
||||
self.critic_ensemble = CriticEnsemble(
|
||||
encoder=encoder_critic,
|
||||
ensemble=critic_heads,
|
||||
output_normalization=self.normalize_targets,
|
||||
)
|
||||
|
||||
# Create target critic heads as deepcopies of the original critic heads
|
||||
target_critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
|
||||
self.critic_target = CriticEnsemble(
|
||||
encoder=encoder_critic,
|
||||
ensemble=target_critic_heads,
|
||||
output_normalization=self.normalize_targets,
|
||||
)
|
||||
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
self.grasp_critic = None
|
||||
self.grasp_critic_target = None
|
||||
|
||||
if config.num_discrete_actions is not None:
|
||||
# Create grasp critic
|
||||
self.grasp_critic = GraspCritic(
|
||||
encoder=encoder_critic,
|
||||
input_dim=encoder_critic.output_dim,
|
||||
output_dim=config.num_discrete_actions,
|
||||
softmax_temperature=.15,
|
||||
**asdict(config.grasp_critic_network_kwargs),
|
||||
)
|
||||
|
||||
# Create target grasp critic
|
||||
self.grasp_critic_target = GraspCritic(
|
||||
encoder=encoder_critic,
|
||||
input_dim=encoder_critic.output_dim,
|
||||
output_dim=config.num_discrete_actions,
|
||||
softmax_temperature=0.15,
|
||||
**asdict(config.grasp_critic_network_kwargs),
|
||||
)
|
||||
|
||||
self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
|
||||
|
||||
# self.grasp_critic = torch.compile(self.grasp_critic)
|
||||
# self.grasp_critic_target = torch.compile(self.grasp_critic_target)
|
||||
|
||||
self.actor = Policy(
|
||||
encoder=encoder_actor,
|
||||
network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
|
||||
action_dim=continuous_action_dim,
|
||||
encoder_is_shared=config.shared_encoder,
|
||||
**asdict(config.policy_kwargs),
|
||||
)
|
||||
if config.target_entropy is None:
|
||||
config.target_entropy = -np.prod(continuous_action_dim) / 2 # (-dim(A)/2)
|
||||
|
||||
# TODO (azouitine): Handle the case where the temperature parameter is fixed
|
||||
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
|
||||
# it triggers "can't optimize a non-leaf Tensor"
|
||||
|
||||
temperature_init = config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
self._init_normalization(dataset_stats)
|
||||
self._init_encoders()
|
||||
self._init_critics(continuous_action_dim)
|
||||
self._init_actor(continuous_action_dim)
|
||||
self._init_temperature()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
optim_params = {
|
||||
"actor": self.actor.parameters_to_optimize,
|
||||
"critic": self.critic_ensemble.parameters_to_optimize,
|
||||
"actor": [
|
||||
p
|
||||
for n, p in self.actor.named_parameters()
|
||||
if not n.startswith("encoder") or not self.shared_encoder
|
||||
],
|
||||
"critic": self.critic_ensemble.parameters(),
|
||||
"temperature": self.log_alpha,
|
||||
}
|
||||
if self.config.num_discrete_actions is not None:
|
||||
optim_params["grasp_critic"] = self.grasp_critic.parameters_to_optimize
|
||||
optim_params["grasp_critic"] = self.grasp_critic.parameters()
|
||||
return optim_params
|
||||
|
||||
def reset(self):
|
||||
"""Reset the policy"""
|
||||
pass
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""Override the .to(device) method so that it also moves log_alpha and fixed_std"""
|
||||
if self.actor.fixed_std is not None:
|
||||
self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
|
||||
# self.log_alpha = self.log_alpha.to(*args, **kwargs)
|
||||
super().to(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select action for inference/evaluation"""
|
||||
# We cached the encoder output to avoid recomputing it
|
||||
observations_features = None
|
||||
if self.shared_encoder:
|
||||
observations_features = self.actor.encoder.get_image_features(batch)
|
||||
# Cache and normalize image features
|
||||
observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)
|
||||
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
_, discrete_action_distribution = self.grasp_critic(batch, observations_features)
|
||||
discrete_action = discrete_action_distribution.sample().unsqueeze(-1).float()
|
||||
discrete_action_value = self.grasp_critic(batch, observations_features)
|
||||
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
|
||||
actions = torch.cat([actions, discrete_action], dim=-1)
|
||||
|
||||
return actions
|
||||
@@ -434,13 +326,13 @@ class SACPolicy(
|
||||
|
||||
with torch.no_grad():
|
||||
# For DQN, select actions using online network, evaluate with target network
|
||||
next_grasp_qs, next_grasp_distribution = self.grasp_critic_forward(
|
||||
next_grasp_qs = self.grasp_critic_forward(
|
||||
next_observations, use_target=False, observation_features=next_observation_features
|
||||
)
|
||||
best_next_grasp_action = next_grasp_distribution.sample().unsqueeze(-1)
|
||||
best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True)
|
||||
|
||||
# Get target Q-values from target network
|
||||
target_next_grasp_qs, _ = self.grasp_critic_forward(
|
||||
target_next_grasp_qs = self.grasp_critic_forward(
|
||||
observations=next_observations,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
@@ -451,14 +343,14 @@ class SACPolicy(
|
||||
target_next_grasp_qs, dim=1, index=best_next_grasp_action
|
||||
).squeeze(-1)
|
||||
|
||||
# Compute target Q-value with Bellman equation
|
||||
rewards_gripper = rewards
|
||||
if gripper_penalties is not None:
|
||||
rewards_gripper = rewards + gripper_penalties
|
||||
target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
|
||||
# Compute target Q-value with Bellman equation
|
||||
rewards_gripper = rewards
|
||||
if gripper_penalties is not None:
|
||||
rewards_gripper = rewards + gripper_penalties
|
||||
target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
|
||||
|
||||
# Get predicted Q-values for current observations
|
||||
predicted_grasp_qs, _ = self.grasp_critic_forward(
|
||||
predicted_grasp_qs = self.grasp_critic_forward(
|
||||
observations=observations, use_target=False, observation_features=observation_features
|
||||
)
|
||||
|
||||
@@ -498,107 +390,265 @@ class SACPolicy(
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
def _init_normalization(self, dataset_stats):
|
||||
"""Initialize input/output normalization modules."""
|
||||
self.normalize_inputs = nn.Identity()
|
||||
self.normalize_targets = nn.Identity()
|
||||
self.unnormalize_outputs = nn.Identity()
|
||||
if self.config.dataset_stats:
|
||||
params = _convert_normalization_params_to_tensor(self.config.dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
self.config.input_features, self.config.normalization_mapping, params
|
||||
)
|
||||
stats = dataset_stats or params
|
||||
self.normalize_targets = Normalize(
|
||||
self.config.output_features, self.config.normalization_mapping, stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
self.config.output_features, self.config.normalization_mapping, stats
|
||||
)
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize shared or separate encoders for actor and critic."""
|
||||
self.shared_encoder = self.config.shared_encoder
|
||||
self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs)
|
||||
self.encoder_actor = (
|
||||
self.encoder_critic
|
||||
if self.shared_encoder
|
||||
else SACObservationEncoder(self.config, self.normalize_inputs)
|
||||
)
|
||||
|
||||
def _init_critics(self, continuous_action_dim):
|
||||
"""Build critic ensemble, targets, and optional grasp critic."""
|
||||
heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(
|
||||
encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets
|
||||
)
|
||||
target_heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(
|
||||
encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets
|
||||
)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
self._init_grasp_critics()
|
||||
|
||||
def _init_grasp_critics(self):
|
||||
"""Build discrete grasp critic ensemble and target networks."""
|
||||
self.grasp_critic = GraspCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.grasp_critic_network_kwargs),
|
||||
)
|
||||
self.grasp_critic_target = GraspCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.grasp_critic_network_kwargs),
|
||||
)
|
||||
|
||||
# TODO: (maractingi, azouitine) Compile the grasp critic
|
||||
self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
|
||||
|
||||
def _init_actor(self, continuous_action_dim):
|
||||
"""Initialize policy actor network and default target entropy."""
|
||||
self.actor = Policy(
|
||||
encoder=self.encoder_actor,
|
||||
network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)),
|
||||
action_dim=continuous_action_dim,
|
||||
encoder_is_shared=self.shared_encoder,
|
||||
**asdict(self.config.policy_kwargs),
|
||||
)
|
||||
if self.config.target_entropy is None:
|
||||
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
||||
self.config.target_entropy = -np.prod(dim) / 2
|
||||
|
||||
def _init_temperature(self):
|
||||
"""Set up temperature parameter and initial log_alpha."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: SACConfig, input_normalizer: nn.Module):
|
||||
"""
|
||||
Creates encoders for pixel and/or state modalities.
|
||||
"""
|
||||
super().__init__()
|
||||
def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
|
||||
super(SACObservationEncoder, self).__init__()
|
||||
self.config = config
|
||||
self.input_normalization = input_normalizer
|
||||
self.has_pretrained_vision_encoder = False
|
||||
self.parameters_to_optimize = []
|
||||
self._init_image_layers()
|
||||
self._init_state_layers()
|
||||
self._compute_output_dim()
|
||||
|
||||
self.aggregation_size: int = 0
|
||||
if any("observation.image" in key for key in config.input_features):
|
||||
self.camera_number = config.camera_number
|
||||
def _init_image_layers(self) -> None:
|
||||
self.image_keys = [k for k in self.config.input_features if k.startswith("observation.image")]
|
||||
self.has_images = bool(self.image_keys)
|
||||
if not self.has_images:
|
||||
return
|
||||
|
||||
if self.config.vision_encoder_name is not None:
|
||||
self.image_enc_layers = PretrainedImageEncoder(config)
|
||||
self.has_pretrained_vision_encoder = True
|
||||
else:
|
||||
self.image_enc_layers = DefaultImageEncoder(config)
|
||||
if self.config.vision_encoder_name:
|
||||
self.image_encoder = PretrainedImageEncoder(self.config)
|
||||
else:
|
||||
self.image_encoder = DefaultImageEncoder(self.config)
|
||||
|
||||
self.aggregation_size += config.latent_dim * self.camera_number
|
||||
if self.config.freeze_vision_encoder:
|
||||
freeze_image_encoder(self.image_encoder)
|
||||
|
||||
if config.freeze_vision_encoder:
|
||||
freeze_image_encoder(self.image_enc_layers)
|
||||
else:
|
||||
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
||||
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
|
||||
dummy = torch.zeros(1, *self.config.input_features[self.image_keys[0]].shape)
|
||||
with torch.no_grad():
|
||||
_, channels, height, width = self.image_encoder(dummy).shape
|
||||
|
||||
if "observation.state" in config.input_features:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
self.spatial_embeddings = nn.ModuleDict()
|
||||
self.post_encoders = nn.ModuleDict()
|
||||
|
||||
for key in self.image_keys:
|
||||
name = key.replace(".", "_")
|
||||
self.spatial_embeddings[name] = SpatialLearnedEmbeddings(
|
||||
height=height,
|
||||
width=width,
|
||||
channel=channels,
|
||||
num_features=self.config.image_embedding_pooling_dim,
|
||||
)
|
||||
self.post_encoders[name] = nn.Sequential(
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(
|
||||
in_features=config.input_features["observation.state"].shape[0],
|
||||
out_features=config.latent_dim,
|
||||
in_features=channels * self.config.image_embedding_pooling_dim,
|
||||
out_features=self.config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.LayerNorm(normalized_shape=self.config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.aggregation_size += config.latent_dim
|
||||
|
||||
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
|
||||
|
||||
if "observation.environment_state" in config.input_features:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=config.input_features["observation.environment_state"].shape[0],
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
def _init_state_layers(self) -> None:
|
||||
self.has_env = "observation.environment_state" in self.config.input_features
|
||||
self.has_state = "observation.state" in self.config.input_features
|
||||
if self.has_env:
|
||||
dim = self.config.input_features["observation.environment_state"].shape[0]
|
||||
self.env_encoder = nn.Sequential(
|
||||
nn.Linear(dim, self.config.latent_dim),
|
||||
nn.LayerNorm(self.config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
if self.has_state:
|
||||
dim = self.config.input_features["observation.state"].shape[0]
|
||||
self.state_encoder = nn.Sequential(
|
||||
nn.Linear(dim, self.config.latent_dim),
|
||||
nn.LayerNorm(self.config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.aggregation_size += config.latent_dim
|
||||
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
|
||||
|
||||
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
|
||||
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
|
||||
def _compute_output_dim(self) -> None:
|
||||
out = 0
|
||||
if self.has_images:
|
||||
out += len(self.image_keys) * self.config.latent_dim
|
||||
if self.has_env:
|
||||
out += self.config.latent_dim
|
||||
if self.has_state:
|
||||
out += self.config.latent_dim
|
||||
self._out_dim = out
|
||||
|
||||
def forward(
|
||||
self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
|
||||
self, obs: dict[str, Tensor], cache: Optional[dict[str, Tensor]] = None, detach: bool = False
|
||||
) -> Tensor:
|
||||
"""Encode the image and/or state vector.
|
||||
obs = self.input_normalization(obs)
|
||||
parts = []
|
||||
if self.has_images:
|
||||
if cache is None:
|
||||
cache = self.get_cached_image_features(obs, normalize=False)
|
||||
parts.append(self._encode_images(cache, detach))
|
||||
if self.has_env:
|
||||
parts.append(self.env_encoder(obs["observation.environment_state"]))
|
||||
if self.has_state:
|
||||
parts.append(self.state_encoder(obs["observation.state"]))
|
||||
if parts:
|
||||
return torch.cat(parts, dim=-1)
|
||||
|
||||
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
|
||||
over all features.
|
||||
raise ValueError(
|
||||
"No parts to concatenate, you should have at least one image or environment state or state"
|
||||
)
|
||||
|
||||
def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]:
|
||||
"""Extract and optionally cache image features from observations.
|
||||
|
||||
This function processes image observations through the vision encoder once and returns
|
||||
the resulting features.
|
||||
When the image encoder is shared between actor and critics AND frozen, these features can be safely cached and
|
||||
reused across policy components (actor, critic, grasp_critic), avoiding redundant forward passes.
|
||||
|
||||
Performance impact:
|
||||
- The vision encoder forward pass is typically the main computational bottleneck during training and inference
|
||||
- Caching these features can provide 2-4x speedup in training and inference
|
||||
|
||||
Normalization behavior:
|
||||
- When called from inside forward(): set normalize=False since inputs are already normalized
|
||||
- When called from outside forward(): set normalize=True to ensure proper input normalization
|
||||
|
||||
Usage patterns:
|
||||
- Called in select_action() with normalize=True
|
||||
- Called in learner_server.py's get_observation_features() to pre-compute features for all policy components
|
||||
- Called internally by forward() with normalize=False
|
||||
|
||||
Args:
|
||||
obs: Dictionary of observation tensors containing image keys
|
||||
normalize: Whether to normalize observations before encoding
|
||||
Set to True when calling directly from outside the encoder's forward method
|
||||
Set to False when calling from within forward() where inputs are already normalized
|
||||
|
||||
Returns:
|
||||
Dictionary mapping image keys to their corresponding encoded features
|
||||
"""
|
||||
feat = []
|
||||
obs_dict = self.input_normalization(obs_dict)
|
||||
if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
|
||||
vision_encoder_cache = self.get_image_features(obs_dict)
|
||||
if normalize:
|
||||
obs = self.input_normalization(obs)
|
||||
batched = torch.cat([obs[k] for k in self.image_keys], dim=0)
|
||||
out = self.image_encoder(batched)
|
||||
chunks = torch.chunk(out, len(self.image_keys), dim=0)
|
||||
return dict(zip(self.image_keys, chunks, strict=False))
|
||||
|
||||
if vision_encoder_cache is not None:
|
||||
feat.append(vision_encoder_cache)
|
||||
def _encode_images(self, cache: dict[str, Tensor], detach: bool) -> Tensor:
|
||||
"""Encode image features from cached observations.
|
||||
|
||||
if "observation.environment_state" in self.config.input_features:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_features:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
This function takes pre-encoded image features from the cache and applies spatial embeddings and post-encoders.
|
||||
It also supports detaching the encoded features if specified.
|
||||
|
||||
features = torch.cat(tensors=feat, dim=-1)
|
||||
features = self.aggregation_layer(features)
|
||||
Args:
|
||||
cache (dict[str, Tensor]): The cached image features.
|
||||
detach (bool): Usually when the encoder is shared between actor and critics,
|
||||
we want to detach the encoded features on the policy side to avoid backprop through the encoder.
|
||||
More detail here `https://cdn.aaai.org/ojs/17276/17276-13-20770-1-2-20210518.pdf`
|
||||
|
||||
return features
|
||||
|
||||
def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor:
|
||||
# [N*B, C, H, W]
|
||||
if len(self.all_image_keys) > 0:
|
||||
# Batch all images along the batch dimension, then encode them.
|
||||
images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
|
||||
images_batched = self.image_enc_layers(images_batched)
|
||||
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
|
||||
embeddings_image = torch.cat(embeddings_chunks, dim=-1)
|
||||
return embeddings_image
|
||||
return None
|
||||
Returns:
|
||||
Tensor: The encoded image features.
|
||||
"""
|
||||
feats = []
|
||||
for k, feat in cache.items():
|
||||
safe_key = k.replace(".", "_")
|
||||
x = self.spatial_embeddings[safe_key](feat)
|
||||
x = self.post_encoders[safe_key](x)
|
||||
if detach:
|
||||
x = x.detach()
|
||||
feats.append(x)
|
||||
return torch.cat(feats, dim=-1)
|
||||
|
||||
@property
|
||||
def output_dim(self) -> int:
|
||||
"""Returns the dimension of the encoder output"""
|
||||
return self.config.latent_dim
|
||||
return self._out_dim
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
@@ -734,12 +784,6 @@ class CriticEnsemble(nn.Module):
|
||||
self.output_normalization = output_normalization
|
||||
self.critics = nn.ModuleList(ensemble)
|
||||
|
||||
self.parameters_to_optimize = []
|
||||
# Handle the case where a part of the encoder is frozen
|
||||
if self.encoder is not None:
|
||||
self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
|
||||
self.parameters_to_optimize += list(self.critics.parameters())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
@@ -755,7 +799,7 @@ class CriticEnsemble(nn.Module):
|
||||
actions = self.output_normalization(actions)["action"]
|
||||
actions = actions.to(device)
|
||||
|
||||
obs_enc = self.encoder(observations, observation_features)
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
|
||||
@@ -781,12 +825,10 @@ class GraspCritic(nn.Module):
|
||||
dropout_rate: Optional[float] = None,
|
||||
init_final: Optional[float] = None,
|
||||
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
|
||||
softmax_temperature: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.output_dim = output_dim
|
||||
self.softmax_temperature = softmax_temperature
|
||||
|
||||
self.net = MLP(
|
||||
input_dim=input_dim,
|
||||
@@ -798,27 +840,20 @@ class GraspCritic(nn.Module):
|
||||
)
|
||||
|
||||
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim)
|
||||
init_final = 0.05
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
self.parameters_to_optimize = []
|
||||
self.parameters_to_optimize += list(self.net.parameters())
|
||||
self.parameters_to_optimize += list(self.output_layer.parameters())
|
||||
|
||||
def forward(
|
||||
self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device by cloning first to avoid inplace operations
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
|
||||
q_values = self.output_layer(self.net(obs_enc))
|
||||
distribution = Categorical(logits=q_values / self.softmax_temperature)
|
||||
return q_values, distribution
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
return self.output_layer(self.net(obs_enc))
|
||||
|
||||
|
||||
class Policy(nn.Module):
|
||||
@@ -842,12 +877,8 @@ class Policy(nn.Module):
|
||||
self.log_std_max = log_std_max
|
||||
self.fixed_std = fixed_std
|
||||
self.use_tanh_squash = use_tanh_squash
|
||||
self.parameters_to_optimize = []
|
||||
self.encoder_is_shared = encoder_is_shared
|
||||
|
||||
self.parameters_to_optimize += list(self.network.parameters())
|
||||
|
||||
if self.encoder is not None and not encoder_is_shared:
|
||||
self.parameters_to_optimize += list(self.encoder.parameters())
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
@@ -861,7 +892,6 @@ class Policy(nn.Module):
|
||||
else:
|
||||
orthogonal_init()(self.mean_layer.weight)
|
||||
|
||||
self.parameters_to_optimize += list(self.mean_layer.parameters())
|
||||
# Standard deviation layer or parameter
|
||||
if fixed_std is None:
|
||||
self.std_layer = nn.Linear(out_features, action_dim)
|
||||
@@ -870,15 +900,15 @@ class Policy(nn.Module):
|
||||
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.std_layer.weight)
|
||||
self.parameters_to_optimize += list(self.std_layer.parameters())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Encode observations if encoder exists
|
||||
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# We detach the encoder if it is shared to avoid backprop through it
|
||||
# This is important to prevent the encoder from being updated through the policy
|
||||
obs_enc = self.encoder(observations, cache=observation_features, detach=self.encoder_is_shared)
|
||||
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
@@ -887,29 +917,20 @@ class Policy(nn.Module):
|
||||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
log_std = self.std_layer(outputs)
|
||||
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
|
||||
|
||||
if self.use_tanh_squash:
|
||||
log_std = torch.tanh(log_std)
|
||||
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
|
||||
else:
|
||||
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
|
||||
std = torch.exp(log_std) # Match JAX "exp"
|
||||
std = torch.clamp(std, self.log_std_min, self.log_std_max) # Match JAX default clip
|
||||
else:
|
||||
log_std = self.fixed_std.expand_as(means)
|
||||
|
||||
# uses tanh activation function to squash the action to be in the range of [-1, 1]
|
||||
normal = torch.distributions.Normal(means, torch.exp(log_std))
|
||||
x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1))
|
||||
log_probs = normal.log_prob(x_t) # Base log probability before Tanh
|
||||
# Build transformed distribution
|
||||
dist = TanhMultivariateNormalDiag(loc=means, scale_diag=std)
|
||||
|
||||
if self.use_tanh_squash:
|
||||
actions = torch.tanh(x_t)
|
||||
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
|
||||
else:
|
||||
actions = x_t # No Tanh; raw Gaussian sample
|
||||
# Sample actions (reparameterized)
|
||||
actions = dist.rsample()
|
||||
|
||||
# Compute log_probs
|
||||
log_probs = dist.log_prob(actions)
|
||||
|
||||
log_probs = log_probs.sum(-1) # Sum over action dimensions
|
||||
means = torch.tanh(means) if self.use_tanh_squash else means
|
||||
return actions, log_probs, means
|
||||
|
||||
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
|
||||
@@ -957,21 +978,16 @@ class DefaultImageEncoder(nn.Module):
|
||||
nn.ReLU(),
|
||||
)
|
||||
# Get first image key from input features
|
||||
image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image")) # noqa: SIM118
|
||||
dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
|
||||
with torch.inference_mode():
|
||||
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||
self.image_enc_layers.extend(
|
||||
nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.image_enc_layers(x)
|
||||
x = self.image_enc_layers(x)
|
||||
return x
|
||||
|
||||
|
||||
def freeze_image_encoder(image_encoder: nn.Module):
    """Turn off gradient tracking for every parameter of ``image_encoder``.

    The module is modified in place; nothing is returned.
    """
    for weight in image_encoder.parameters():
        # A frozen parameter is skipped by autograd, so it receives no gradient.
        weight.requires_grad = False
|
||||
|
||||
|
||||
class PretrainedImageEncoder(nn.Module):
|
||||
@@ -979,18 +995,12 @@ class PretrainedImageEncoder(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||
self.image_enc_proj = nn.Sequential(
|
||||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
def _load_pretrained_vision_encoder(self, config: SACConfig):
|
||||
"""Set up CNN encoder"""
|
||||
from transformers import AutoModel
|
||||
|
||||
self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
|
||||
# self.image_enc_layers.pooler = Identity()
|
||||
|
||||
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
|
||||
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
|
||||
@@ -1001,19 +1011,10 @@ class PretrainedImageEncoder(nn.Module):
|
||||
return self.image_enc_layers, self.image_enc_out_shape
|
||||
|
||||
def forward(self, x):
|
||||
# TODO: (maractingi, azouitine) check the forward pass of the pretrained model
|
||||
# doesn't reach the classifier layer because we don't need it
|
||||
enc_feat = self.image_enc_layers(x).pooler_output
|
||||
enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
|
||||
enc_feat = self.image_enc_layers(x).last_hidden_state
|
||||
return enc_feat
|
||||
|
||||
|
||||
def freeze_image_encoder(image_encoder: nn.Module):
    """Turn off gradient tracking for every parameter of ``image_encoder``.

    The module is modified in place; nothing is returned.
    """
    for weight in image_encoder.parameters():
        # A frozen parameter is skipped by autograd, so it receives no gradient.
        weight.requires_grad = False
|
||||
|
||||
|
||||
def orthogonal_init():
    """Build an initializer that fills a tensor in place with a (semi-)orthogonal matrix.

    Returns:
        A callable taking a tensor and returning that same tensor after
        ``torch.nn.init.orthogonal_`` has been applied with ``gain=1.0``.
    """

    def _apply_orthogonal(x):
        return torch.nn.init.orthogonal_(x, gain=1.0)

    return _apply_orthogonal
|
||||
|
||||
@@ -1026,6 +1027,112 @@ class Identity(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class SpatialLearnedEmbeddings(nn.Module):
    """Learned spatial pooling of a feature map.

    Every channel owns ``num_features`` learned spatial masks of size
    ``height x width``. The forward pass takes, per channel, the weighted sum
    of the feature map under each mask and concatenates the results.
    """

    def __init__(self, height, width, channel, num_features=8):
        """
        PyTorch implementation of learned spatial embeddings

        Args:
            height: Spatial height of input features
            width: Spatial width of input features
            channel: Number of input channels
            num_features: Number of output embedding dimensions
        """
        super().__init__()
        self.height = height
        self.width = width
        self.channel = channel
        self.num_features = num_features

        # One (height, width) mask per (channel, feature) pair.
        self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features))

        nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear")

    def forward(self, features):
        """
        Forward pass for spatial embedding

        Args:
            features: Input tensor of shape [B, C, H, W] where B is batch size,
                     C is number of channels, H is height, and W is width
        Returns:
            Output tensor of shape [B, C*F] where F is the number of features
        """
        # Contract H and W directly. This equals
        # (features[..., None] * kernel[None]).sum(dim=(2, 3)) but never
        # materialises the [B, C, H, W, F] intermediate.
        output = torch.einsum("bchw,chwf->bcf", features, self.kernel)

        # Merge channel and feature dimensions: [B, C, F] -> [B, C*F].
        # reshape (not view) so a non-contiguous result cannot raise.
        return output.reshape(output.size(0), -1)
|
||||
|
||||
|
||||
class RescaleFromTanh(Transform):
    """Elementwise affine map taking values in (-1, 1) onto (low, high).

    Intended to be applied after a tanh squash so that samples land inside
    custom action bounds.

    Args:
        low: Lower bound of the target interval (number or tensor).
        high: Upper bound of the target interval (number or tensor).
        cache_size: Forwarded to ``Transform``; with 1, the latest
            ``(x, y)`` pair is cached so ``inv(y)`` returns the exact ``x``.
    """

    # Metadata read by ComposeTransform / TransformedDistribution. Without it
    # the transform cannot be composed into a distribution at all.
    domain = torch.distributions.constraints.real
    codomain = torch.distributions.constraints.real
    bijective = True
    sign = +1  # monotonically increasing, assuming high > low

    def __init__(self, low: float = -1, high: float = 1, cache_size: int = 0):
        super().__init__(cache_size=cache_size)

        self.low = low

        self.high = high

    def _call(self, x):
        # Rescale from (-1, 1) to (low, high)

        return 0.5 * (x + 1.0) * (self.high - self.low) + self.low

    def _inverse(self, y):
        # Rescale from (low, high) back to (-1, 1)

        return 2.0 * (y - self.low) / (self.high - self.low) - 1.0

    def log_abs_det_jacobian(self, x, y):
        # log|d(rescale)/dx| = log(0.5 * (high - low)), per element.
        # The bounds may be plain Python numbers, so convert before torch.log.
        scale = torch.as_tensor(0.5 * (self.high - self.low), dtype=x.dtype, device=x.device)

        # Return one term per element of x. TransformedDistribution sums the
        # event dimensions itself for elementwise (event_dim == 0) transforms.
        return torch.log(scale) + torch.zeros_like(x)


class TanhMultivariateNormalDiag(TransformedDistribution):
    """Diagonal Gaussian squashed by tanh and optionally rescaled to (low, high).

    Args:
        loc: Mean of the base Gaussian, shape ``[..., D]``.
        scale_diag: Per-dimension standard deviation of the base Gaussian,
            same shape as ``loc``.
        low: Optional lower action bound; used only when ``high`` is also given.
        high: Optional upper action bound; used only when ``low`` is also given.
    """

    def __init__(self, loc, scale_diag, low=None, high=None):
        # scale_diag holds standard deviations, so it must be passed as the
        # Cholesky factor. The second positional argument of MultivariateNormal
        # is the covariance matrix, which would make the std sqrt(scale_diag).
        base_dist = MultivariateNormal(loc, scale_tril=torch.diag_embed(scale_diag))

        transforms = [TanhTransform(cache_size=1)]

        if low is not None and high is not None:
            low = torch.as_tensor(low, dtype=loc.dtype, device=loc.device)

            high = torch.as_tensor(high, dtype=loc.dtype, device=loc.device)

            # Transforms run in list order: squash to (-1, 1) first, then
            # rescale. Caching here lets log_prob on a fresh sample recover the
            # exact tanh output, and through it the pre-tanh value.
            transforms.append(RescaleFromTanh(low, high, cache_size=1))

        super().__init__(base_dist, transforms)

    def mode(self):
        """Return the base distribution's mean pushed through the transforms."""
        x = self.base_dist.mean

        for transform in self.transforms:
            x = transform(x)

        return x

    def stddev(self):
        """Return the base standard deviation pushed through the transforms.

        NOTE(review): this is not the true standard deviation of the
        transformed distribution -- confirm that callers only need a proxy.
        """
        x = self.base_dist.stddev

        for transform in self.transforms:
            x = transform(x)

        return x
|
||||
|
||||
|
||||
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
||||
converted_params = {}
|
||||
for outer_key, inner_dict in normalization_params.items():
|
||||
@@ -1036,90 +1143,3 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
||||
converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
|
||||
|
||||
return converted_params
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# # Benchmark the CriticEnsemble performance
|
||||
# import time
|
||||
|
||||
# # Configuration
|
||||
# num_critics = 10
|
||||
# batch_size = 32
|
||||
# action_dim = 7
|
||||
# obs_dim = 64
|
||||
# hidden_dims = [256, 256]
|
||||
# num_iterations = 100
|
||||
|
||||
# print("Creating test environment...")
|
||||
|
||||
# # Create a simple dummy encoder
|
||||
# class DummyEncoder(nn.Module):
|
||||
# def __init__(self):
|
||||
# super().__init__()
|
||||
# self.output_dim = obs_dim
|
||||
# self.parameters_to_optimize = []
|
||||
|
||||
# def forward(self, obs):
|
||||
# # Just return a random tensor of the right shape
|
||||
# # In practice, this would encode the observations
|
||||
# return torch.randn(batch_size, obs_dim, device=device)
|
||||
|
||||
# # Create critic heads
|
||||
# print(f"Creating {num_critics} critic heads...")
|
||||
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
# critic_heads = [
|
||||
# CriticHead(
|
||||
# input_dim=obs_dim + action_dim,
|
||||
# hidden_dims=hidden_dims,
|
||||
# ).to(device)
|
||||
# for _ in range(num_critics)
|
||||
# ]
|
||||
|
||||
# # Create the critic ensemble
|
||||
# print("Creating CriticEnsemble...")
|
||||
# critic_ensemble = CriticEnsemble(
|
||||
# encoder=DummyEncoder().to(device),
|
||||
# ensemble=critic_heads,
|
||||
# output_normalization=nn.Identity(),
|
||||
# ).to(device)
|
||||
|
||||
# # Create random input data
|
||||
# print("Creating input data...")
|
||||
# obs_dict = {
|
||||
# "observation.state": torch.randn(batch_size, obs_dim, device=device),
|
||||
# }
|
||||
# actions = torch.randn(batch_size, action_dim, device=device)
|
||||
|
||||
# # Warmup run
|
||||
# print("Warming up...")
|
||||
# _ = critic_ensemble(obs_dict, actions)
|
||||
|
||||
# # Time the forward pass
|
||||
# print(f"Running benchmark with {num_iterations} iterations...")
|
||||
# start_time = time.perf_counter()
|
||||
# for _ in range(num_iterations):
|
||||
# q_values = critic_ensemble(obs_dict, actions)
|
||||
# end_time = time.perf_counter()
|
||||
|
||||
# # Print results
|
||||
# elapsed_time = end_time - start_time
|
||||
# print(f"Total time: {elapsed_time:.4f} seconds")
|
||||
# print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
|
||||
# print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size]
|
||||
|
||||
# Verify that all critic heads produce different outputs
|
||||
# This confirms each critic head is unique
|
||||
# print("\nVerifying critic outputs are different:")
|
||||
# for i in range(num_critics):
|
||||
# for j in range(i + 1, num_critics):
|
||||
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
|
||||
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
|
||||
|
||||
from lerobot.configs import parser
|
||||
|
||||
@parser.wrap()
|
||||
def main(config: SACConfig):
|
||||
policy = SACPolicy(config=config)
|
||||
print("yolo")
|
||||
|
||||
main()
|
||||
|
||||
@@ -363,8 +363,6 @@ def replay(
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action = actions[idx]["action"]
|
||||
# if replay_delta_actions:
|
||||
# action = action + current_joint_positions
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
|
||||
@@ -231,6 +231,7 @@ def act_with_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy = policy.eval()
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
@@ -250,18 +251,28 @@ def act_with_policy(
|
||||
logging.info("[ACTOR] Shutting down act_with_policy")
|
||||
return
|
||||
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with TimerManager(
|
||||
elapsed_time_list=list_policy_time,
|
||||
label="Policy inference time",
|
||||
log=False,
|
||||
) as timer: # noqa: F841
|
||||
action = policy.select_action(batch=obs)
|
||||
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
|
||||
if interaction_step >= cfg.policy.online_step_before_learning:
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with TimerManager(
|
||||
elapsed_time_list=list_policy_time,
|
||||
label="Policy inference time",
|
||||
log=False,
|
||||
) as timer: # noqa: F841
|
||||
action = policy.select_action(batch=obs)
|
||||
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
|
||||
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
|
||||
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
|
||||
else:
|
||||
# TODO (azouitine): Make a custom space for torch tensor
|
||||
action = online_env.action_space.sample()
|
||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||
|
||||
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
|
||||
action = (
|
||||
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
|
||||
)
|
||||
|
||||
sum_reward_episode += float(reward)
|
||||
# Increment total steps counter for intervention rate
|
||||
@@ -281,7 +292,7 @@ def act_with_policy(
|
||||
for key, tensor in obs.items():
|
||||
if torch.isnan(tensor).any():
|
||||
logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
|
||||
|
||||
|
||||
list_transition_to_send_to_learner.append(
|
||||
Transition(
|
||||
state=obs,
|
||||
|
||||
@@ -78,9 +78,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
|
||||
if isinstance(val, torch.Tensor):
|
||||
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
|
||||
elif isinstance(val, (int, float, bool)):
|
||||
transition["complementary_info"][key] = torch.tensor(
|
||||
val, device=device
|
||||
)
|
||||
transition["complementary_info"][key] = torch.tensor(val, device=device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
|
||||
return transition
|
||||
@@ -269,7 +267,7 @@ class ReplayBuffer:
|
||||
self.complementary_info[key] = torch.empty(
|
||||
(self.capacity, *value_shape), device=self.storage_device
|
||||
)
|
||||
elif isinstance(value, (int, float, bool)):
|
||||
elif isinstance(value, (int, float)):
|
||||
# Handle scalar values similar to reward
|
||||
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
|
||||
else:
|
||||
@@ -563,7 +561,6 @@ class ReplayBuffer:
|
||||
else:
|
||||
first_action = first_action[:, action_mask]
|
||||
|
||||
|
||||
# Get complementary info if available
|
||||
first_complementary_info = None
|
||||
if (
|
||||
@@ -594,7 +591,6 @@ class ReplayBuffer:
|
||||
else:
|
||||
action = action[:, action_mask]
|
||||
|
||||
|
||||
replay_buffer.add(
|
||||
state=data["state"],
|
||||
action=action,
|
||||
|
||||
@@ -258,25 +258,25 @@ class GamepadController(InputController):
|
||||
elif event.button == 0:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
|
||||
# LT button for closing gripper
|
||||
# RB button (6) for opening gripper
|
||||
elif event.button == 6:
|
||||
self.close_gripper_command = True
|
||||
|
||||
# RB button for opening gripper
|
||||
elif event.button == 7:
|
||||
self.open_gripper_command = True
|
||||
|
||||
# LT button (7) for closing gripper
|
||||
elif event.button == 7:
|
||||
self.close_gripper_command = True
|
||||
|
||||
# Reset episode status on button release
|
||||
elif event.type == pygame.JOYBUTTONUP:
|
||||
if event.button in [0, 2, 3]:
|
||||
self.episode_end_status = None
|
||||
|
||||
if event.button == 6:
|
||||
self.close_gripper_command = False
|
||||
|
||||
if event.button == 7:
|
||||
|
||||
elif event.button == 6:
|
||||
self.open_gripper_command = False
|
||||
|
||||
elif event.button == 7:
|
||||
self.close_gripper_command = False
|
||||
|
||||
# Check for RB button (typically button 5) for intervention flag
|
||||
if self.joystick.get_button(5):
|
||||
self.intervention_flag = True
|
||||
|
||||
@@ -553,9 +553,6 @@ class ImageCropResizeWrapper(gym.Wrapper):
|
||||
# TODO(michel-aractingi): Bug in resize, it returns values outside [0, 1]
|
||||
obs[k] = obs[k].clamp(0.0, 1.0)
|
||||
|
||||
# import cv2
|
||||
# cv2.imwrite(f"tmp_img/{k}.jpg", cv2.cvtColor(obs[k].squeeze(0).permute(1,2,0).cpu().numpy()*255, cv2.COLOR_RGB2BGR))
|
||||
|
||||
# Check for NaNs after processing
|
||||
if torch.isnan(obs[k]).any():
|
||||
logging.error(f"NaN values detected in observation {k} after crop and resize")
|
||||
@@ -721,7 +718,7 @@ class ResetWrapper(gym.Wrapper):
|
||||
env: HILSerlRobotEnv,
|
||||
reset_pose: np.ndarray | None = None,
|
||||
reset_time_s: float = 5,
|
||||
open_gripper_on_reset: bool = False
|
||||
open_gripper_on_reset: bool = False,
|
||||
):
|
||||
super().__init__(env)
|
||||
self.reset_time_s = reset_time_s
|
||||
@@ -730,8 +727,6 @@ class ResetWrapper(gym.Wrapper):
|
||||
self.open_gripper_on_reset = open_gripper_on_reset
|
||||
|
||||
def reset(self, *, seed=None, options=None):
|
||||
|
||||
|
||||
if self.reset_pose is not None:
|
||||
start_time = time.perf_counter()
|
||||
log_say("Reset the environment.", play_sounds=True)
|
||||
@@ -780,14 +775,13 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
|
||||
self.penalty = penalty
|
||||
self.gripper_penalty_in_reward = gripper_penalty_in_reward
|
||||
self.last_gripper_state = None
|
||||
|
||||
|
||||
def reward(self, reward, action):
|
||||
gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND
|
||||
|
||||
action_normalized = action - 1.0 #action / MAX_GRIPPER_COMMAND
|
||||
action_normalized = action - 1.0 # action / MAX_GRIPPER_COMMAND
|
||||
|
||||
gripper_penalty_bool = (gripper_state_normalized < 0.75 and action_normalized > 0.5) or (
|
||||
gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or (
|
||||
gripper_state_normalized > 0.75 and action_normalized < -0.5
|
||||
)
|
||||
|
||||
@@ -806,7 +800,7 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
|
||||
reward += gripper_penalty
|
||||
else:
|
||||
info["gripper_penalty"] = gripper_penalty
|
||||
|
||||
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, **kwargs):
|
||||
@@ -816,6 +810,7 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
|
||||
info["gripper_penalty"] = 0.0
|
||||
return obs, info
|
||||
|
||||
|
||||
class GripperActionWrapper(gym.ActionWrapper):
|
||||
def __init__(self, env, quantization_threshold: float = 0.2):
|
||||
super().__init__(env)
|
||||
@@ -1192,11 +1187,13 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
||||
# env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
|
||||
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
|
||||
if cfg.wrapper.use_gripper:
|
||||
env = GripperActionWrapper(
|
||||
env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
|
||||
)
|
||||
env = GripperActionWrapper(env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold)
|
||||
if cfg.wrapper.gripper_penalty is not None:
|
||||
env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward)
|
||||
env = GripperPenaltyWrapper(
|
||||
env=env,
|
||||
penalty=cfg.wrapper.gripper_penalty,
|
||||
gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward,
|
||||
)
|
||||
|
||||
if cfg.wrapper.ee_action_space_params is not None:
|
||||
env = EEActionWrapper(
|
||||
@@ -1221,7 +1218,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
||||
env=env,
|
||||
reset_pose=cfg.wrapper.fixed_reset_joint_positions,
|
||||
reset_time_s=cfg.wrapper.reset_time_s,
|
||||
open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset
|
||||
open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset,
|
||||
)
|
||||
if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
|
||||
env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
|
||||
|
||||
@@ -380,7 +380,6 @@ def add_actor_information_and_train(
|
||||
for _ in range(utd_ratio - 1):
|
||||
# Sample from the iterators
|
||||
batch = next(online_iterator)
|
||||
# batch = replay_buffer.sample(batch_size)
|
||||
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = next(offline_iterator)
|
||||
@@ -408,7 +407,7 @@ def add_actor_information_and_train(
|
||||
"done": done,
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
"complementary_info": batch.get("complementary_info", None),
|
||||
"complementary_info": batch["complementary_info"],
|
||||
}
|
||||
|
||||
# Use the forward method for critic loss (includes both main critic and grasp critic)
|
||||
@@ -430,7 +429,7 @@ def add_actor_information_and_train(
|
||||
optimizers["grasp_critic"].zero_grad()
|
||||
loss_grasp_critic.backward()
|
||||
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
|
||||
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
)
|
||||
optimizers["grasp_critic"].step()
|
||||
|
||||
@@ -439,11 +438,9 @@ def add_actor_information_and_train(
|
||||
|
||||
# Sample for the last update in the UTD ratio
|
||||
batch = next(online_iterator)
|
||||
# batch = replay_buffer.sample(batch_size)
|
||||
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = next(offline_iterator)
|
||||
# batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
@@ -496,7 +493,7 @@ def add_actor_information_and_train(
|
||||
optimizers["grasp_critic"].zero_grad()
|
||||
loss_grasp_critic.backward()
|
||||
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
|
||||
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["grasp_critic"].step()
|
||||
|
||||
@@ -513,7 +510,7 @@ def add_actor_information_and_train(
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
|
||||
parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["actor"].step()
|
||||
|
||||
@@ -775,15 +772,18 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
|
||||
|
||||
"""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
|
||||
params=policy.actor.parameters_to_optimize,
|
||||
params=[
|
||||
p
|
||||
for n, p in policy.actor.named_parameters()
|
||||
if not policy.config.shared_encoder or not n.startswith("encoder")
|
||||
],
|
||||
lr=cfg.policy.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr)
|
||||
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
|
||||
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
optimizer_grasp_critic = torch.optim.Adam(
|
||||
params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr
|
||||
params=policy.grasp_critic.parameters(), lr=cfg.policy.critic_lr
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
|
||||
lr_scheduler = None
|
||||
@@ -1027,8 +1027,10 @@ def get_observation_features(
|
||||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = policy.actor.encoder.get_image_features(observations)
|
||||
next_observation_features = policy.actor.encoder.get_image_features(next_observations)
|
||||
observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True)
|
||||
next_observation_features = policy.actor.encoder.get_cached_image_features(
|
||||
next_observations, normalize=True
|
||||
)
|
||||
|
||||
return observation_features, next_observation_features
|
||||
|
||||
@@ -1090,6 +1092,44 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
|
||||
parameters_queue.put(state_bytes)
|
||||
|
||||
|
||||
def check_weight_gradients(module: nn.Module) -> dict[str, bool]:
    """
    Checks whether each parameter in the module has a gradient.

    Args:
        module (nn.Module): A PyTorch module whose parameters will be inspected.

    Returns:
        dict[str, bool]: A dictionary where each key is the parameter name and the value is
                         True if the parameter has an associated gradient (i.e. .grad is not None),
                         otherwise False.
    """
    return {name: param.grad is not None for name, param in module.named_parameters()}
|
||||
|
||||
|
||||
def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]:
    """
    Returns the entries of grad_status whose keys are also parameter names of the model.

    Args:
        model (nn.Module): The model whose parameter names are used as the filter.
        grad_status (dict[str, bool]): A dictionary where keys are parameter names and values indicate
                                       whether each parameter has a gradient.

    Returns:
        dict[str, bool]: A dictionary containing only the overlapping parameter names and their gradient status.
    """
    # Collect the model's parameter names once so each membership test is O(1).
    model_param_names = {name for name, _ in model.named_parameters()}

    # Keep grad_status entries (in their original order) that name a model parameter.
    return {name: has_grad for name, has_grad in grad_status.items() if name in model_param_names}
|
||||
|
||||
|
||||
def process_interaction_message(
|
||||
message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user