forked from tangger/lerobot
Compare commits
38 Commits
user/azoui
...
user/miche
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a22fe8a6de | ||
|
|
49b5f379a7 | ||
|
|
7a3d8756b4 | ||
|
|
dc1548fe1a | ||
|
|
23c9441d5f | ||
|
|
870e3efb92 | ||
|
|
bfd48a8b70 | ||
|
|
5dc7ff6d3c | ||
|
|
ee4ebeac9b | ||
|
|
fe7b47f459 | ||
|
|
044ca3b039 | ||
|
|
bc36c69b71 | ||
|
|
2b9b05f1ba | ||
|
|
9eec7b8bb0 | ||
|
|
a80a9cf379 | ||
|
|
7a42af835e | ||
|
|
9751328783 | ||
|
|
7225bc74a3 | ||
|
|
03b1644bf7 | ||
|
|
9b6e5a383f | ||
|
|
86466b025f | ||
|
|
54745f111d | ||
|
|
82584cca78 | ||
|
|
d3a8c2c247 | ||
|
|
74c11c4a75 | ||
|
|
2d932b710c | ||
|
|
a54baceabb | ||
|
|
077d18b439 | ||
|
|
c6cd1475a7 | ||
|
|
e35ee47b07 | ||
|
|
c3f2487026 | ||
|
|
c621077b62 | ||
|
|
f5cfd9fd48 | ||
|
|
22da1739b1 | ||
|
|
d38d5f988d | ||
|
|
8d1936ffe0 | ||
|
|
cef944e1b1 | ||
|
|
384eb2cd07 |
@@ -171,7 +171,6 @@ class VideoRecordConfig:
|
||||
class WrapperConfig:
|
||||
"""Configuration for environment wrappers."""
|
||||
|
||||
delta_action: float | None = None
|
||||
joint_masking_action_space: list[bool] | None = None
|
||||
|
||||
|
||||
@@ -191,7 +190,6 @@ class EnvWrapperConfig:
|
||||
"""Configuration for environment wrappers."""
|
||||
|
||||
display_cameras: bool = False
|
||||
delta_action: float = 0.1
|
||||
use_relative_joint_positions: bool = True
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
@@ -203,8 +201,9 @@ class EnvWrapperConfig:
|
||||
joint_masking_action_space: Optional[Any] = None
|
||||
ee_action_space_params: Optional[EEActionSpaceConfig] = None
|
||||
use_gripper: bool = False
|
||||
gripper_quantization_threshold: float = 0.8
|
||||
gripper_quantization_threshold: float | None = 0.8
|
||||
gripper_penalty: float = 0.0
|
||||
gripper_penalty_in_reward: bool = False
|
||||
open_gripper_on_reset: bool = False
|
||||
|
||||
|
||||
|
||||
@@ -51,8 +51,8 @@ class ActorNetworkConfig:
|
||||
@dataclass
|
||||
class PolicyConfig:
|
||||
use_tanh_squash: bool = True
|
||||
log_std_min: int = -5
|
||||
log_std_max: int = 2
|
||||
log_std_min: float = 1e-5
|
||||
log_std_max: float = 10.0
|
||||
init_final: float = 0.05
|
||||
|
||||
|
||||
@@ -86,6 +86,7 @@ class SACConfig(PreTrainedConfig):
|
||||
image_encoder_hidden_dim: Hidden dimension size for the image encoder.
|
||||
shared_encoder: Whether to use a shared encoder for actor and critic.
|
||||
num_discrete_actions: Number of discrete actions, eg for gripper actions.
|
||||
image_embedding_pooling_dim: Dimension of the image embedding pooling.
|
||||
concurrency: Configuration for concurrency settings.
|
||||
actor_learner: Configuration for actor-learner architecture.
|
||||
online_steps: Number of steps for online training.
|
||||
@@ -147,6 +148,7 @@ class SACConfig(PreTrainedConfig):
|
||||
image_encoder_hidden_dim: int = 32
|
||||
shared_encoder: bool = True
|
||||
num_discrete_actions: int | None = None
|
||||
image_embedding_pooling_dim: int = 8
|
||||
|
||||
# Training parameter
|
||||
online_steps: int = 1000000
|
||||
|
||||
@@ -22,13 +22,12 @@ from dataclasses import asdict
|
||||
from typing import Callable, List, Literal, Optional, Tuple
|
||||
|
||||
import einops
|
||||
from importlib_metadata import distribution
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.distributions import Categorical
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
@@ -53,158 +52,46 @@ class SACPolicy(
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
# Determine action dimension and initialize all components
|
||||
continuous_action_dim = config.output_features["action"].shape[0]
|
||||
|
||||
if config.dataset_stats is not None:
|
||||
input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_features,
|
||||
config.normalization_mapping,
|
||||
input_normalization_params,
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
|
||||
|
||||
if config.dataset_stats is not None:
|
||||
output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
|
||||
|
||||
# HACK: This is hacky and should be removed
|
||||
dataset_stats = dataset_stats or output_normalization_params
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
else:
|
||||
self.normalize_targets = nn.Identity()
|
||||
self.unnormalize_outputs = nn.Identity()
|
||||
|
||||
# NOTE: For images the encoder should be shared between the actor and critic
|
||||
if config.shared_encoder:
|
||||
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
|
||||
encoder_actor: SACObservationEncoder = encoder_critic
|
||||
else:
|
||||
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
|
||||
encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
|
||||
self.shared_encoder = config.shared_encoder
|
||||
|
||||
# Create a list of critic heads
|
||||
critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
|
||||
self.critic_ensemble = CriticEnsemble(
|
||||
encoder=encoder_critic,
|
||||
ensemble=critic_heads,
|
||||
output_normalization=self.normalize_targets,
|
||||
)
|
||||
|
||||
# Create target critic heads as deepcopies of the original critic heads
|
||||
target_critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
|
||||
self.critic_target = CriticEnsemble(
|
||||
encoder=encoder_critic,
|
||||
ensemble=target_critic_heads,
|
||||
output_normalization=self.normalize_targets,
|
||||
)
|
||||
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
self.grasp_critic = None
|
||||
self.grasp_critic_target = None
|
||||
|
||||
if config.num_discrete_actions is not None:
|
||||
# Create grasp critic
|
||||
self.grasp_critic = GraspCritic(
|
||||
encoder=encoder_critic,
|
||||
input_dim=encoder_critic.output_dim,
|
||||
output_dim=config.num_discrete_actions,
|
||||
softmax_temperature=1.0,
|
||||
**asdict(config.grasp_critic_network_kwargs),
|
||||
)
|
||||
|
||||
# Create target grasp critic
|
||||
self.grasp_critic_target = GraspCritic(
|
||||
encoder=encoder_critic,
|
||||
input_dim=encoder_critic.output_dim,
|
||||
output_dim=config.num_discrete_actions,
|
||||
softmax_temperature=1.0,
|
||||
**asdict(config.grasp_critic_network_kwargs),
|
||||
)
|
||||
|
||||
self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
|
||||
|
||||
self.grasp_critic = torch.compile(self.grasp_critic)
|
||||
self.grasp_critic_target = torch.compile(self.grasp_critic_target)
|
||||
|
||||
self.actor = Policy(
|
||||
encoder=encoder_actor,
|
||||
network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
|
||||
action_dim=continuous_action_dim,
|
||||
encoder_is_shared=config.shared_encoder,
|
||||
**asdict(config.policy_kwargs),
|
||||
)
|
||||
if config.target_entropy is None:
|
||||
config.target_entropy = -np.prod(continuous_action_dim) / 2 # (-dim(A)/2)
|
||||
|
||||
# TODO (azouitine): Handle the case where the temparameter is a fixed
|
||||
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
|
||||
# it triggers "can't optimize a non-leaf Tensor"
|
||||
|
||||
temperature_init = config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
self._init_normalization(dataset_stats)
|
||||
self._init_encoders()
|
||||
self._init_critics(continuous_action_dim)
|
||||
self._init_actor(continuous_action_dim)
|
||||
self._init_temperature()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
optim_params = {
|
||||
"actor": self.actor.parameters_to_optimize,
|
||||
"critic": self.critic_ensemble.parameters_to_optimize,
|
||||
"actor": [
|
||||
p
|
||||
for n, p in self.actor.named_parameters()
|
||||
if not n.startswith("encoder") or not self.shared_encoder
|
||||
],
|
||||
"critic": self.critic_ensemble.parameters(),
|
||||
"temperature": self.log_alpha,
|
||||
}
|
||||
if self.config.num_discrete_actions is not None:
|
||||
optim_params["grasp_critic"] = self.grasp_critic.parameters_to_optimize
|
||||
optim_params["grasp_critic"] = self.grasp_critic.parameters()
|
||||
return optim_params
|
||||
|
||||
def reset(self):
|
||||
"""Reset the policy"""
|
||||
pass
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""Override .to(device) method to involve moving the log_alpha fixed_std"""
|
||||
if self.actor.fixed_std is not None:
|
||||
self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
|
||||
# self.log_alpha = self.log_alpha.to(*args, **kwargs)
|
||||
super().to(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select action for inference/evaluation"""
|
||||
# We cached the encoder output to avoid recomputing it
|
||||
observations_features = None
|
||||
if self.shared_encoder:
|
||||
observations_features = self.actor.encoder.get_image_features(batch, normalize=True)
|
||||
# Cache and normalize image features
|
||||
observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)
|
||||
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
_, discrete_action_distribution = self.grasp_critic(batch, observations_features)
|
||||
discrete_action = discrete_action_distribution.sample().unsqueeze(-1).float()
|
||||
discrete_action_value = self.grasp_critic(batch, observations_features)
|
||||
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
|
||||
actions = torch.cat([actions, discrete_action], dim=-1)
|
||||
|
||||
return actions
|
||||
@@ -433,18 +320,19 @@ class SACPolicy(
|
||||
actions_discrete = torch.round(actions_discrete)
|
||||
actions_discrete = actions_discrete.long()
|
||||
|
||||
gripper_penalties: Tensor | None = None
|
||||
if complementary_info is not None:
|
||||
gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty")
|
||||
|
||||
with torch.no_grad():
|
||||
# For DQN, select actions using online network, evaluate with target network
|
||||
next_grasp_qs, next_grasp_distribution = self.grasp_critic_forward(
|
||||
next_grasp_qs = self.grasp_critic_forward(
|
||||
next_observations, use_target=False, observation_features=next_observation_features
|
||||
)
|
||||
best_next_grasp_action = next_grasp_distribution.sample().unsqueeze(-1)
|
||||
best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True)
|
||||
|
||||
# Get target Q-values from target network
|
||||
target_next_grasp_qs, _ = self.grasp_critic_forward(
|
||||
target_next_grasp_qs = self.grasp_critic_forward(
|
||||
observations=next_observations,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
@@ -455,14 +343,14 @@ class SACPolicy(
|
||||
target_next_grasp_qs, dim=1, index=best_next_grasp_action
|
||||
).squeeze(-1)
|
||||
|
||||
# Compute target Q-value with Bellman equation
|
||||
rewards_gripper = rewards
|
||||
if gripper_penalties is not None:
|
||||
rewards_gripper = rewards + gripper_penalties
|
||||
target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
|
||||
# Compute target Q-value with Bellman equation
|
||||
rewards_gripper = rewards
|
||||
if gripper_penalties is not None:
|
||||
rewards_gripper = rewards + gripper_penalties
|
||||
target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
|
||||
|
||||
# Get predicted Q-values for current observations
|
||||
predicted_grasp_qs, _ = self.grasp_critic_forward(
|
||||
predicted_grasp_qs = self.grasp_critic_forward(
|
||||
observations=observations, use_target=False, observation_features=observation_features
|
||||
)
|
||||
|
||||
@@ -502,109 +390,265 @@ class SACPolicy(
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
def _init_normalization(self, dataset_stats):
|
||||
"""Initialize input/output normalization modules."""
|
||||
self.normalize_inputs = nn.Identity()
|
||||
self.normalize_targets = nn.Identity()
|
||||
self.unnormalize_outputs = nn.Identity()
|
||||
if self.config.dataset_stats:
|
||||
params = _convert_normalization_params_to_tensor(self.config.dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
self.config.input_features, self.config.normalization_mapping, params
|
||||
)
|
||||
stats = dataset_stats or params
|
||||
self.normalize_targets = Normalize(
|
||||
self.config.output_features, self.config.normalization_mapping, stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
self.config.output_features, self.config.normalization_mapping, stats
|
||||
)
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize shared or separate encoders for actor and critic."""
|
||||
self.shared_encoder = self.config.shared_encoder
|
||||
self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs)
|
||||
self.encoder_actor = (
|
||||
self.encoder_critic
|
||||
if self.shared_encoder
|
||||
else SACObservationEncoder(self.config, self.normalize_inputs)
|
||||
)
|
||||
|
||||
def _init_critics(self, continuous_action_dim):
|
||||
"""Build critic ensemble, targets, and optional grasp critic."""
|
||||
heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(
|
||||
encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets
|
||||
)
|
||||
target_heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(
|
||||
encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets
|
||||
)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
self._init_grasp_critics()
|
||||
|
||||
def _init_grasp_critics(self):
|
||||
"""Build discrete grasp critic ensemble and target networks."""
|
||||
self.grasp_critic = GraspCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.grasp_critic_network_kwargs),
|
||||
)
|
||||
self.grasp_critic_target = GraspCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.grasp_critic_network_kwargs),
|
||||
)
|
||||
|
||||
# TODO: (maractingi, azouitine) Compile the grasp critic
|
||||
self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
|
||||
|
||||
def _init_actor(self, continuous_action_dim):
|
||||
"""Initialize policy actor network and default target entropy."""
|
||||
self.actor = Policy(
|
||||
encoder=self.encoder_actor,
|
||||
network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)),
|
||||
action_dim=continuous_action_dim,
|
||||
encoder_is_shared=self.shared_encoder,
|
||||
**asdict(self.config.policy_kwargs),
|
||||
)
|
||||
if self.config.target_entropy is None:
|
||||
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
||||
self.config.target_entropy = -np.prod(dim) / 2
|
||||
|
||||
def _init_temperature(self):
|
||||
"""Set up temperature parameter and initial log_alpha."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: SACConfig, input_normalizer: nn.Module):
|
||||
"""
|
||||
Creates encoders for pixel and/or state modalities.
|
||||
"""
|
||||
super().__init__()
|
||||
def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
|
||||
super(SACObservationEncoder, self).__init__()
|
||||
self.config = config
|
||||
self.input_normalization = input_normalizer
|
||||
self.has_pretrained_vision_encoder = False
|
||||
self.parameters_to_optimize = []
|
||||
self._init_image_layers()
|
||||
self._init_state_layers()
|
||||
self._compute_output_dim()
|
||||
|
||||
self.aggregation_size: int = 0
|
||||
if any("observation.image" in key for key in config.input_features):
|
||||
self.camera_number = config.camera_number
|
||||
def _init_image_layers(self) -> None:
|
||||
self.image_keys = [k for k in self.config.input_features if k.startswith("observation.image")]
|
||||
self.has_images = bool(self.image_keys)
|
||||
if not self.has_images:
|
||||
return
|
||||
|
||||
if self.config.vision_encoder_name is not None:
|
||||
self.image_enc_layers = PretrainedImageEncoder(config)
|
||||
self.has_pretrained_vision_encoder = True
|
||||
else:
|
||||
self.image_enc_layers = DefaultImageEncoder(config)
|
||||
if self.config.vision_encoder_name:
|
||||
self.image_encoder = PretrainedImageEncoder(self.config)
|
||||
else:
|
||||
self.image_encoder = DefaultImageEncoder(self.config)
|
||||
|
||||
self.aggregation_size += config.latent_dim * self.camera_number
|
||||
if self.config.freeze_vision_encoder:
|
||||
freeze_image_encoder(self.image_encoder)
|
||||
|
||||
if config.freeze_vision_encoder:
|
||||
freeze_image_encoder(self.image_enc_layers)
|
||||
else:
|
||||
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
||||
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
|
||||
dummy = torch.zeros(1, *self.config.input_features[self.image_keys[0]].shape)
|
||||
with torch.no_grad():
|
||||
_, channels, height, width = self.image_encoder(dummy).shape
|
||||
|
||||
if "observation.state" in config.input_features:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
self.spatial_embeddings = nn.ModuleDict()
|
||||
self.post_encoders = nn.ModuleDict()
|
||||
|
||||
for key in self.image_keys:
|
||||
name = key.replace(".", "_")
|
||||
self.spatial_embeddings[name] = SpatialLearnedEmbeddings(
|
||||
height=height,
|
||||
width=width,
|
||||
channel=channels,
|
||||
num_features=self.config.image_embedding_pooling_dim,
|
||||
)
|
||||
self.post_encoders[name] = nn.Sequential(
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(
|
||||
in_features=config.input_features["observation.state"].shape[0],
|
||||
out_features=config.latent_dim,
|
||||
in_features=channels * self.config.image_embedding_pooling_dim,
|
||||
out_features=self.config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.LayerNorm(normalized_shape=self.config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.aggregation_size += config.latent_dim
|
||||
|
||||
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
|
||||
|
||||
if "observation.environment_state" in config.input_features:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=config.input_features["observation.environment_state"].shape[0],
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
def _init_state_layers(self) -> None:
|
||||
self.has_env = "observation.environment_state" in self.config.input_features
|
||||
self.has_state = "observation.state" in self.config.input_features
|
||||
if self.has_env:
|
||||
dim = self.config.input_features["observation.environment_state"].shape[0]
|
||||
self.env_encoder = nn.Sequential(
|
||||
nn.Linear(dim, self.config.latent_dim),
|
||||
nn.LayerNorm(self.config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
if self.has_state:
|
||||
dim = self.config.input_features["observation.state"].shape[0]
|
||||
self.state_encoder = nn.Sequential(
|
||||
nn.Linear(dim, self.config.latent_dim),
|
||||
nn.LayerNorm(self.config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.aggregation_size += config.latent_dim
|
||||
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
|
||||
|
||||
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
|
||||
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
|
||||
def _compute_output_dim(self) -> None:
|
||||
out = 0
|
||||
if self.has_images:
|
||||
out += len(self.image_keys) * self.config.latent_dim
|
||||
if self.has_env:
|
||||
out += self.config.latent_dim
|
||||
if self.has_state:
|
||||
out += self.config.latent_dim
|
||||
self._out_dim = out
|
||||
|
||||
def forward(
|
||||
self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
|
||||
self, obs: dict[str, Tensor], cache: Optional[dict[str, Tensor]] = None, detach: bool = False
|
||||
) -> Tensor:
|
||||
"""Encode the image and/or state vector.
|
||||
obs = self.input_normalization(obs)
|
||||
parts = []
|
||||
if self.has_images:
|
||||
if cache is None:
|
||||
cache = self.get_cached_image_features(obs, normalize=False)
|
||||
parts.append(self._encode_images(cache, detach))
|
||||
if self.has_env:
|
||||
parts.append(self.env_encoder(obs["observation.environment_state"]))
|
||||
if self.has_state:
|
||||
parts.append(self.state_encoder(obs["observation.state"]))
|
||||
if parts:
|
||||
return torch.cat(parts, dim=-1)
|
||||
|
||||
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
|
||||
over all features.
|
||||
raise ValueError(
|
||||
"No parts to concatenate, you should have at least one image or environment state or state"
|
||||
)
|
||||
|
||||
def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]:
|
||||
"""Extract and optionally cache image features from observations.
|
||||
|
||||
This function processes image observations through the vision encoder once and returns
|
||||
the resulting features.
|
||||
When the image encoder is shared between actor and critics AND frozen, these features can be safely cached and
|
||||
reused across policy components (actor, critic, grasp_critic), avoiding redundant forward passes.
|
||||
|
||||
Performance impact:
|
||||
- The vision encoder forward pass is typically the main computational bottleneck during training and inference
|
||||
- Caching these features can provide 2-4x speedup in training and inference
|
||||
|
||||
Normalization behavior:
|
||||
- When called from inside forward(): set normalize=False since inputs are already normalized
|
||||
- When called from outside forward(): set normalize=True to ensure proper input normalization
|
||||
|
||||
Usage patterns:
|
||||
- Called in select_action() with normalize=True
|
||||
- Called in learner_server.py's get_observation_features() to pre-compute features for all policy components
|
||||
- Called internally by forward() with normalize=False
|
||||
|
||||
Args:
|
||||
obs: Dictionary of observation tensors containing image keys
|
||||
normalize: Whether to normalize observations before encoding
|
||||
Set to True when calling directly from outside the encoder's forward method
|
||||
Set to False when calling from within forward() where inputs are already normalized
|
||||
|
||||
Returns:
|
||||
Dictionary mapping image keys to their corresponding encoded features
|
||||
"""
|
||||
feat = []
|
||||
obs_dict = self.input_normalization(obs_dict)
|
||||
if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
|
||||
vision_encoder_cache = self.get_image_features(obs_dict, normalize=False)
|
||||
|
||||
if vision_encoder_cache is not None:
|
||||
feat.append(vision_encoder_cache)
|
||||
|
||||
if "observation.environment_state" in self.config.input_features:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_features:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
|
||||
features = torch.cat(tensors=feat, dim=-1)
|
||||
features = self.aggregation_layer(features)
|
||||
|
||||
return features
|
||||
|
||||
def get_image_features(self, batch: dict[str, Tensor], normalize: bool = True) -> torch.Tensor:
|
||||
# [N*B, C, H, W]
|
||||
if normalize:
|
||||
batch = self.input_normalization(batch)
|
||||
if len(self.all_image_keys) > 0:
|
||||
# Batch all images along the batch dimension, then encode them.
|
||||
images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
|
||||
images_batched = self.image_enc_layers(images_batched)
|
||||
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
|
||||
embeddings_image = torch.cat(embeddings_chunks, dim=-1)
|
||||
return embeddings_image
|
||||
return None
|
||||
obs = self.input_normalization(obs)
|
||||
batched = torch.cat([obs[k] for k in self.image_keys], dim=0)
|
||||
out = self.image_encoder(batched)
|
||||
chunks = torch.chunk(out, len(self.image_keys), dim=0)
|
||||
return dict(zip(self.image_keys, chunks, strict=False))
|
||||
|
||||
def _encode_images(self, cache: dict[str, Tensor], detach: bool) -> Tensor:
|
||||
"""Encode image features from cached observations.
|
||||
|
||||
This function takes pre-encoded image features from the cache and applies spatial embeddings and post-encoders.
|
||||
It also supports detaching the encoded features if specified.
|
||||
|
||||
Args:
|
||||
cache (dict[str, Tensor]): The cached image features.
|
||||
detach (bool): Usually when the encoder is shared between actor and critics,
|
||||
we want to detach the encoded features on the policy side to avoid backprop through the encoder.
|
||||
More detail here `https://cdn.aaai.org/ojs/17276/17276-13-20770-1-2-20210518.pdf`
|
||||
|
||||
Returns:
|
||||
Tensor: The encoded image features.
|
||||
"""
|
||||
feats = []
|
||||
for k, feat in cache.items():
|
||||
safe_key = k.replace(".", "_")
|
||||
x = self.spatial_embeddings[safe_key](feat)
|
||||
x = self.post_encoders[safe_key](x)
|
||||
if detach:
|
||||
x = x.detach()
|
||||
feats.append(x)
|
||||
return torch.cat(feats, dim=-1)
|
||||
|
||||
@property
|
||||
def output_dim(self) -> int:
|
||||
"""Returns the dimension of the encoder output"""
|
||||
return self.config.latent_dim
|
||||
return self._out_dim
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
@@ -740,12 +784,6 @@ class CriticEnsemble(nn.Module):
|
||||
self.output_normalization = output_normalization
|
||||
self.critics = nn.ModuleList(ensemble)
|
||||
|
||||
self.parameters_to_optimize = []
|
||||
# Handle the case where a part of the encoder if frozen
|
||||
if self.encoder is not None:
|
||||
self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
|
||||
self.parameters_to_optimize += list(self.critics.parameters())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
@@ -761,7 +799,7 @@ class CriticEnsemble(nn.Module):
|
||||
actions = self.output_normalization(actions)["action"]
|
||||
actions = actions.to(device)
|
||||
|
||||
obs_enc = self.encoder(observations, observation_features)
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
|
||||
@@ -787,7 +825,6 @@ class GraspCritic(nn.Module):
|
||||
dropout_rate: Optional[float] = None,
|
||||
init_final: Optional[float] = None,
|
||||
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
|
||||
softmax_temperature: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
@@ -809,20 +846,14 @@ class GraspCritic(nn.Module):
|
||||
else:
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
self.parameters_to_optimize = []
|
||||
self.parameters_to_optimize += list(self.net.parameters())
|
||||
self.parameters_to_optimize += list(self.output_layer.parameters())
|
||||
|
||||
def forward(
|
||||
self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device by cloning first to avoid inplace operations
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
|
||||
q_values = self.output_layer(self.net(obs_enc))
|
||||
distribution = Categorical(logits=q_values / self.softmax_temperature)
|
||||
return q_values, distribution
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
return self.output_layer(self.net(obs_enc))
|
||||
|
||||
|
||||
class Policy(nn.Module):
|
||||
@@ -846,12 +877,8 @@ class Policy(nn.Module):
|
||||
self.log_std_max = log_std_max
|
||||
self.fixed_std = fixed_std
|
||||
self.use_tanh_squash = use_tanh_squash
|
||||
self.parameters_to_optimize = []
|
||||
self.encoder_is_shared = encoder_is_shared
|
||||
|
||||
self.parameters_to_optimize += list(self.network.parameters())
|
||||
|
||||
if self.encoder is not None and not encoder_is_shared:
|
||||
self.parameters_to_optimize += list(self.encoder.parameters())
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
@@ -865,7 +892,6 @@ class Policy(nn.Module):
|
||||
else:
|
||||
orthogonal_init()(self.mean_layer.weight)
|
||||
|
||||
self.parameters_to_optimize += list(self.mean_layer.parameters())
|
||||
# Standard deviation layer or parameter
|
||||
if fixed_std is None:
|
||||
self.std_layer = nn.Linear(out_features, action_dim)
|
||||
@@ -874,15 +900,15 @@ class Policy(nn.Module):
|
||||
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.std_layer.weight)
|
||||
self.parameters_to_optimize += list(self.std_layer.parameters())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Encode observations if encoder exists
|
||||
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# We detach the encoder if it is shared to avoid backprop through it
|
||||
# This is important to avoid the encoder to be updated through the policy
|
||||
obs_enc = self.encoder(observations, cache=observation_features, detach=self.encoder_is_shared)
|
||||
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
@@ -891,29 +917,20 @@ class Policy(nn.Module):
|
||||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
log_std = self.std_layer(outputs)
|
||||
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
|
||||
|
||||
if self.use_tanh_squash:
|
||||
log_std = torch.tanh(log_std)
|
||||
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
|
||||
else:
|
||||
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
|
||||
std = torch.exp(log_std) # Match JAX "exp"
|
||||
std = torch.clamp(std, self.log_std_min, self.log_std_max) # Match JAX default clip
|
||||
else:
|
||||
log_std = self.fixed_std.expand_as(means)
|
||||
|
||||
# uses tanh activation function to squash the action to be in the range of [-1, 1]
|
||||
normal = torch.distributions.Normal(means, torch.exp(log_std))
|
||||
x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1))
|
||||
log_probs = normal.log_prob(x_t) # Base log probability before Tanh
|
||||
# Build transformed distribution
|
||||
dist = TanhMultivariateNormalDiag(loc=means, scale_diag=std)
|
||||
|
||||
if self.use_tanh_squash:
|
||||
actions = torch.tanh(x_t)
|
||||
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
|
||||
else:
|
||||
actions = x_t # No Tanh; raw Gaussian sample
|
||||
# Sample actions (reparameterized)
|
||||
actions = dist.rsample()
|
||||
|
||||
# Compute log_probs
|
||||
log_probs = dist.log_prob(actions)
|
||||
|
||||
log_probs = log_probs.sum(-1) # Sum over action dimensions
|
||||
means = torch.tanh(means) if self.use_tanh_squash else means
|
||||
return actions, log_probs, means
|
||||
|
||||
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
|
||||
@@ -961,21 +978,16 @@ class DefaultImageEncoder(nn.Module):
|
||||
nn.ReLU(),
|
||||
)
|
||||
# Get first image key from input features
|
||||
image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image")) # noqa: SIM118
|
||||
dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
|
||||
with torch.inference_mode():
|
||||
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||
self.image_enc_layers.extend(
|
||||
nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.image_enc_layers(x)
|
||||
x = self.image_enc_layers(x)
|
||||
return x
|
||||
|
||||
|
||||
def freeze_image_encoder(image_encoder: nn.Module):
|
||||
"""Freeze all parameters in the encoder"""
|
||||
for param in image_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
class PretrainedImageEncoder(nn.Module):
|
||||
@@ -983,18 +995,12 @@ class PretrainedImageEncoder(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||
self.image_enc_proj = nn.Sequential(
|
||||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
def _load_pretrained_vision_encoder(self, config: SACConfig):
|
||||
"""Set up CNN encoder"""
|
||||
from transformers import AutoModel
|
||||
|
||||
self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
|
||||
# self.image_enc_layers.pooler = Identity()
|
||||
|
||||
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
|
||||
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
|
||||
@@ -1005,19 +1011,10 @@ class PretrainedImageEncoder(nn.Module):
|
||||
return self.image_enc_layers, self.image_enc_out_shape
|
||||
|
||||
def forward(self, x):
|
||||
# TODO: (maractingi, azouitine) check the forward pass of the pretrained model
|
||||
# doesn't reach the classifier layer because we don't need it
|
||||
enc_feat = self.image_enc_layers(x).pooler_output
|
||||
enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
|
||||
enc_feat = self.image_enc_layers(x).last_hidden_state
|
||||
return enc_feat
|
||||
|
||||
|
||||
def freeze_image_encoder(image_encoder: nn.Module):
|
||||
"""Freeze all parameters in the encoder"""
|
||||
for param in image_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
def orthogonal_init():
|
||||
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
|
||||
|
||||
@@ -1030,6 +1027,112 @@ class Identity(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class SpatialLearnedEmbeddings(nn.Module):
|
||||
def __init__(self, height, width, channel, num_features=8):
|
||||
"""
|
||||
PyTorch implementation of learned spatial embeddings
|
||||
|
||||
Args:
|
||||
height: Spatial height of input features
|
||||
width: Spatial width of input features
|
||||
channel: Number of input channels
|
||||
num_features: Number of output embedding dimensions
|
||||
"""
|
||||
super().__init__()
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.channel = channel
|
||||
self.num_features = num_features
|
||||
|
||||
self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features))
|
||||
|
||||
nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear")
|
||||
|
||||
def forward(self, features):
|
||||
"""
|
||||
Forward pass for spatial embedding
|
||||
|
||||
Args:
|
||||
features: Input tensor of shape [B, C, H, W] where B is batch size,
|
||||
C is number of channels, H is height, and W is width
|
||||
Returns:
|
||||
Output tensor of shape [B, C*F] where F is the number of features
|
||||
"""
|
||||
|
||||
features_expanded = features.unsqueeze(-1) # [B, C, H, W, 1]
|
||||
kernel_expanded = self.kernel.unsqueeze(0) # [1, C, H, W, F]
|
||||
|
||||
# Element-wise multiplication and spatial reduction
|
||||
output = (features_expanded * kernel_expanded).sum(dim=(2, 3)) # Sum over H,W dimensions
|
||||
|
||||
# Reshape to combine channel and feature dimensions
|
||||
output = output.view(output.size(0), -1) # [B, C*F]
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class RescaleFromTanh(Transform):
|
||||
def __init__(self, low: float = -1, high: float = 1):
|
||||
super().__init__()
|
||||
|
||||
self.low = low
|
||||
|
||||
self.high = high
|
||||
|
||||
def _call(self, x):
|
||||
# Rescale from (-1, 1) to (low, high)
|
||||
|
||||
return 0.5 * (x + 1.0) * (self.high - self.low) + self.low
|
||||
|
||||
def _inverse(self, y):
|
||||
# Rescale from (low, high) back to (-1, 1)
|
||||
|
||||
return 2.0 * (y - self.low) / (self.high - self.low) - 1.0
|
||||
|
||||
def log_abs_det_jacobian(self, x, y):
|
||||
# log|d(rescale)/dx| = sum(log(0.5 * (high - low)))
|
||||
|
||||
scale = 0.5 * (self.high - self.low)
|
||||
|
||||
return torch.sum(torch.log(scale), dim=-1)
|
||||
|
||||
|
||||
class TanhMultivariateNormalDiag(TransformedDistribution):
|
||||
def __init__(self, loc, scale_diag, low=None, high=None):
|
||||
base_dist = MultivariateNormal(loc, torch.diag_embed(scale_diag))
|
||||
|
||||
transforms = [TanhTransform(cache_size=1)]
|
||||
|
||||
if low is not None and high is not None:
|
||||
low = torch.as_tensor(low)
|
||||
|
||||
high = torch.as_tensor(high)
|
||||
|
||||
transforms.insert(0, RescaleFromTanh(low, high))
|
||||
|
||||
super().__init__(base_dist, transforms)
|
||||
|
||||
def mode(self):
|
||||
# Mode is mean of base distribution, passed through transforms
|
||||
|
||||
x = self.base_dist.mean
|
||||
|
||||
for transform in self.transforms:
|
||||
x = transform(x)
|
||||
|
||||
return x
|
||||
|
||||
def stddev(self):
|
||||
std = self.base_dist.stddev
|
||||
|
||||
x = std
|
||||
|
||||
for transform in self.transforms:
|
||||
x = transform(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
||||
converted_params = {}
|
||||
for outer_key, inner_dict in normalization_params.items():
|
||||
@@ -1040,90 +1143,3 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
||||
converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
|
||||
|
||||
return converted_params
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# # Benchmark the CriticEnsemble performance
|
||||
# import time
|
||||
|
||||
# # Configuration
|
||||
# num_critics = 10
|
||||
# batch_size = 32
|
||||
# action_dim = 7
|
||||
# obs_dim = 64
|
||||
# hidden_dims = [256, 256]
|
||||
# num_iterations = 100
|
||||
|
||||
# print("Creating test environment...")
|
||||
|
||||
# # Create a simple dummy encoder
|
||||
# class DummyEncoder(nn.Module):
|
||||
# def __init__(self):
|
||||
# super().__init__()
|
||||
# self.output_dim = obs_dim
|
||||
# self.parameters_to_optimize = []
|
||||
|
||||
# def forward(self, obs):
|
||||
# # Just return a random tensor of the right shape
|
||||
# # In practice, this would encode the observations
|
||||
# return torch.randn(batch_size, obs_dim, device=device)
|
||||
|
||||
# # Create critic heads
|
||||
# print(f"Creating {num_critics} critic heads...")
|
||||
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
# critic_heads = [
|
||||
# CriticHead(
|
||||
# input_dim=obs_dim + action_dim,
|
||||
# hidden_dims=hidden_dims,
|
||||
# ).to(device)
|
||||
# for _ in range(num_critics)
|
||||
# ]
|
||||
|
||||
# # Create the critic ensemble
|
||||
# print("Creating CriticEnsemble...")
|
||||
# critic_ensemble = CriticEnsemble(
|
||||
# encoder=DummyEncoder().to(device),
|
||||
# ensemble=critic_heads,
|
||||
# output_normalization=nn.Identity(),
|
||||
# ).to(device)
|
||||
|
||||
# # Create random input data
|
||||
# print("Creating input data...")
|
||||
# obs_dict = {
|
||||
# "observation.state": torch.randn(batch_size, obs_dim, device=device),
|
||||
# }
|
||||
# actions = torch.randn(batch_size, action_dim, device=device)
|
||||
|
||||
# # Warmup run
|
||||
# print("Warming up...")
|
||||
# _ = critic_ensemble(obs_dict, actions)
|
||||
|
||||
# # Time the forward pass
|
||||
# print(f"Running benchmark with {num_iterations} iterations...")
|
||||
# start_time = time.perf_counter()
|
||||
# for _ in range(num_iterations):
|
||||
# q_values = critic_ensemble(obs_dict, actions)
|
||||
# end_time = time.perf_counter()
|
||||
|
||||
# # Print results
|
||||
# elapsed_time = end_time - start_time
|
||||
# print(f"Total time: {elapsed_time:.4f} seconds")
|
||||
# print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
|
||||
# print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size]
|
||||
|
||||
# Verify that all critic heads produce different outputs
|
||||
# This confirms each critic head is unique
|
||||
# print("\nVerifying critic outputs are different:")
|
||||
# for i in range(num_critics):
|
||||
# for j in range(i + 1, num_critics):
|
||||
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
|
||||
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
|
||||
|
||||
from lerobot.configs import parser
|
||||
|
||||
@parser.wrap()
|
||||
def main(config: SACConfig):
|
||||
policy = SACPolicy(config=config)
|
||||
print("yolo")
|
||||
|
||||
main()
|
||||
|
||||
@@ -221,7 +221,6 @@ def record_episode(
|
||||
events=events,
|
||||
policy=policy,
|
||||
fps=fps,
|
||||
# record_delta_actions=record_delta_actions,
|
||||
teleoperate=policy is None,
|
||||
single_task=single_task,
|
||||
)
|
||||
@@ -267,8 +266,6 @@ def control_loop(
|
||||
|
||||
if teleoperate:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
# if record_delta_actions:
|
||||
# action["action"] = action["action"] - current_joint_positions
|
||||
else:
|
||||
observation = robot.capture_observation()
|
||||
|
||||
|
||||
@@ -363,8 +363,6 @@ def replay(
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action = actions[idx]["action"]
|
||||
# if replay_delta_actions:
|
||||
# action = action + current_joint_positions
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
|
||||
@@ -231,6 +231,7 @@ def act_with_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy = policy.eval()
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
|
||||
@@ -78,9 +78,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
|
||||
if isinstance(val, torch.Tensor):
|
||||
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
|
||||
elif isinstance(val, (int, float, bool)):
|
||||
transition["complementary_info"][key] = torch.tensor(
|
||||
val, device=device, non_blocking=non_blocking
|
||||
)
|
||||
transition["complementary_info"][key] = torch.tensor(val, device=device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
|
||||
return transition
|
||||
@@ -505,7 +503,6 @@ class ReplayBuffer:
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
capacity: Optional[int] = None,
|
||||
action_mask: Optional[Sequence[int]] = None,
|
||||
action_delta: Optional[float] = None,
|
||||
image_augmentation_function: Optional[Callable] = None,
|
||||
use_drq: bool = True,
|
||||
storage_device: str = "cpu",
|
||||
@@ -520,7 +517,6 @@ class ReplayBuffer:
|
||||
state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`.
|
||||
capacity (Optional[int]): Buffer capacity. If None, uses dataset length.
|
||||
action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep.
|
||||
action_delta (Optional[float]): Factor to divide actions by.
|
||||
image_augmentation_function (Optional[Callable]): Function for image augmentation.
|
||||
If None, uses default random shift with pad=4.
|
||||
use_drq (bool): Whether to use DrQ image augmentation when sampling.
|
||||
@@ -565,9 +561,6 @@ class ReplayBuffer:
|
||||
else:
|
||||
first_action = first_action[:, action_mask]
|
||||
|
||||
if action_delta is not None:
|
||||
first_action = first_action / action_delta
|
||||
|
||||
# Get complementary info if available
|
||||
first_complementary_info = None
|
||||
if (
|
||||
@@ -598,9 +591,6 @@ class ReplayBuffer:
|
||||
else:
|
||||
action = action[:, action_mask]
|
||||
|
||||
if action_delta is not None:
|
||||
action = action / action_delta
|
||||
|
||||
replay_buffer.add(
|
||||
state=data["state"],
|
||||
action=action,
|
||||
|
||||
@@ -42,7 +42,6 @@ class HILSerlRobotEnv(gym.Env):
|
||||
self,
|
||||
robot,
|
||||
use_delta_action_space: bool = True,
|
||||
delta: float | None = None,
|
||||
display_cameras: bool = False,
|
||||
):
|
||||
"""
|
||||
@@ -55,8 +54,6 @@ class HILSerlRobotEnv(gym.Env):
|
||||
robot: The robot interface object used to connect and interact with the physical robot.
|
||||
use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute
|
||||
joint positions are used.
|
||||
delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between
|
||||
0 and 1 when using a delta action space.
|
||||
display_cameras (bool): If True, the robot's camera feeds will be displayed during execution.
|
||||
"""
|
||||
super().__init__()
|
||||
@@ -74,7 +71,6 @@ class HILSerlRobotEnv(gym.Env):
|
||||
self.current_step = 0
|
||||
self.episode_data = None
|
||||
|
||||
self.delta = delta
|
||||
self.use_delta_action_space = use_delta_action_space
|
||||
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
|
||||
|
||||
@@ -374,7 +370,7 @@ class RewardWrapper(gym.Wrapper):
|
||||
self.device = device
|
||||
|
||||
def step(self, action):
|
||||
observation, _, terminated, truncated, info = self.env.step(action)
|
||||
observation, reward, terminated, truncated, info = self.env.step(action)
|
||||
images = [
|
||||
observation[key].to(self.device, non_blocking=self.device.type == "cuda")
|
||||
for key in observation
|
||||
@@ -382,15 +378,17 @@ class RewardWrapper(gym.Wrapper):
|
||||
]
|
||||
start_time = time.perf_counter()
|
||||
with torch.inference_mode():
|
||||
reward = (
|
||||
success = (
|
||||
self.reward_classifier.predict_reward(images, threshold=0.8)
|
||||
if self.reward_classifier is not None
|
||||
else 0.0
|
||||
)
|
||||
info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
|
||||
|
||||
if reward == 1.0:
|
||||
if success == 1.0:
|
||||
terminated = True
|
||||
reward = 1.0
|
||||
|
||||
return observation, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
@@ -720,11 +718,13 @@ class ResetWrapper(gym.Wrapper):
|
||||
env: HILSerlRobotEnv,
|
||||
reset_pose: np.ndarray | None = None,
|
||||
reset_time_s: float = 5,
|
||||
open_gripper_on_reset: bool = False,
|
||||
):
|
||||
super().__init__(env)
|
||||
self.reset_time_s = reset_time_s
|
||||
self.reset_pose = reset_pose
|
||||
self.robot = self.unwrapped.robot
|
||||
self.open_gripper_on_reset = open_gripper_on_reset
|
||||
|
||||
def reset(self, *, seed=None, options=None):
|
||||
if self.reset_pose is not None:
|
||||
@@ -733,6 +733,14 @@ class ResetWrapper(gym.Wrapper):
|
||||
reset_follower_position(self.robot, self.reset_pose)
|
||||
busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
|
||||
log_say("Reset the environment done.", play_sounds=True)
|
||||
if self.open_gripper_on_reset:
|
||||
current_joint_pos = self.robot.follower_arms["main"].read("Present_Position")
|
||||
current_joint_pos[-1] = MAX_GRIPPER_COMMAND
|
||||
self.robot.send_action(torch.from_numpy(current_joint_pos))
|
||||
busy_wait(0.1)
|
||||
current_joint_pos[-1] = 0.0
|
||||
self.robot.send_action(torch.from_numpy(current_joint_pos))
|
||||
busy_wait(0.2)
|
||||
else:
|
||||
log_say(
|
||||
f"Manually reset the environment for {self.reset_time_s} seconds.",
|
||||
@@ -762,37 +770,48 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
|
||||
|
||||
|
||||
class GripperPenaltyWrapper(gym.RewardWrapper):
|
||||
def __init__(self, env, penalty: float = -0.1):
|
||||
def __init__(self, env, penalty: float = -0.1, gripper_penalty_in_reward: bool = True):
|
||||
super().__init__(env)
|
||||
self.penalty = penalty
|
||||
self.gripper_penalty_in_reward = gripper_penalty_in_reward
|
||||
self.last_gripper_state = None
|
||||
|
||||
def reward(self, reward, action):
|
||||
gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND
|
||||
|
||||
if isinstance(action, tuple):
|
||||
action = action[0]
|
||||
action_normalized = action[-1] / MAX_GRIPPER_COMMAND
|
||||
action_normalized = action - 1.0 # action / MAX_GRIPPER_COMMAND
|
||||
|
||||
gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or (
|
||||
gripper_state_normalized > 0.9 and action_normalized < 0.1
|
||||
gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or (
|
||||
gripper_state_normalized > 0.75 and action_normalized < -0.5
|
||||
)
|
||||
breakpoint()
|
||||
|
||||
return reward + self.penalty * gripper_penalty_bool
|
||||
return reward + self.penalty * int(gripper_penalty_bool)
|
||||
|
||||
def step(self, action):
|
||||
self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
|
||||
if isinstance(action, tuple):
|
||||
gripper_action = action[0][-1]
|
||||
else:
|
||||
gripper_action = action[-1]
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
reward = self.reward(reward, action)
|
||||
gripper_penalty = self.reward(reward, gripper_action)
|
||||
|
||||
if self.gripper_penalty_in_reward:
|
||||
reward += gripper_penalty
|
||||
else:
|
||||
info["gripper_penalty"] = gripper_penalty
|
||||
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, **kwargs):
|
||||
self.last_gripper_state = None
|
||||
return super().reset(**kwargs)
|
||||
obs, info = super().reset(**kwargs)
|
||||
if self.gripper_penalty_in_reward:
|
||||
info["gripper_penalty"] = 0.0
|
||||
return obs, info
|
||||
|
||||
|
||||
class GripperQuantizationWrapper(gym.ActionWrapper):
|
||||
class GripperActionWrapper(gym.ActionWrapper):
|
||||
def __init__(self, env, quantization_threshold: float = 0.2):
|
||||
super().__init__(env)
|
||||
self.quantization_threshold = quantization_threshold
|
||||
@@ -801,16 +820,18 @@ class GripperQuantizationWrapper(gym.ActionWrapper):
|
||||
is_intervention = False
|
||||
if isinstance(action, tuple):
|
||||
action, is_intervention = action
|
||||
|
||||
gripper_command = action[-1]
|
||||
# Quantize gripper command to -1, 0 or 1
|
||||
if gripper_command < -self.quantization_threshold:
|
||||
gripper_command = -MAX_GRIPPER_COMMAND
|
||||
elif gripper_command > self.quantization_threshold:
|
||||
gripper_command = MAX_GRIPPER_COMMAND
|
||||
else:
|
||||
gripper_command = 0.0
|
||||
|
||||
# Gripper actions are between 0, 2
|
||||
# we want to quantize them to -1, 0 or 1
|
||||
gripper_command = gripper_command - 1.0
|
||||
|
||||
if self.quantization_threshold is not None:
|
||||
# Quantize gripper command to -1, 0 or 1
|
||||
gripper_command = (
|
||||
np.sign(gripper_command) if abs(gripper_command) > self.quantization_threshold else 0.0
|
||||
)
|
||||
gripper_command = gripper_command * MAX_GRIPPER_COMMAND
|
||||
gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
|
||||
gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
|
||||
action[-1] = gripper_action.item()
|
||||
@@ -836,10 +857,12 @@ class EEActionWrapper(gym.ActionWrapper):
|
||||
]
|
||||
)
|
||||
if self.use_gripper:
|
||||
action_space_bounds = np.concatenate([action_space_bounds, [1.0]])
|
||||
# gripper actions open at 2.0, and closed at 0.0
|
||||
min_action_space_bounds = np.concatenate([-action_space_bounds, [0.0]])
|
||||
max_action_space_bounds = np.concatenate([action_space_bounds, [2.0]])
|
||||
ee_action_space = gym.spaces.Box(
|
||||
low=-action_space_bounds,
|
||||
high=action_space_bounds,
|
||||
low=min_action_space_bounds,
|
||||
high=max_action_space_bounds,
|
||||
shape=(3 + int(self.use_gripper),),
|
||||
dtype=np.float32,
|
||||
)
|
||||
@@ -997,11 +1020,11 @@ class GamepadControlWrapper(gym.Wrapper):
|
||||
if self.use_gripper:
|
||||
gripper_command = self.controller.gripper_command()
|
||||
if gripper_command == "open":
|
||||
gamepad_action = np.concatenate([gamepad_action, [1.0]])
|
||||
gamepad_action = np.concatenate([gamepad_action, [2.0]])
|
||||
elif gripper_command == "close":
|
||||
gamepad_action = np.concatenate([gamepad_action, [-1.0]])
|
||||
else:
|
||||
gamepad_action = np.concatenate([gamepad_action, [0.0]])
|
||||
else:
|
||||
gamepad_action = np.concatenate([gamepad_action, [1.0]])
|
||||
|
||||
# Check episode ending buttons
|
||||
# We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None
|
||||
@@ -1141,7 +1164,6 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
||||
env = HILSerlRobotEnv(
|
||||
robot=robot,
|
||||
display_cameras=cfg.wrapper.display_cameras,
|
||||
delta=cfg.wrapper.delta_action,
|
||||
use_delta_action_space=cfg.wrapper.use_relative_joint_positions
|
||||
and cfg.wrapper.ee_action_space_params is None,
|
||||
)
|
||||
@@ -1165,10 +1187,13 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
||||
# env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
|
||||
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
|
||||
if cfg.wrapper.use_gripper:
|
||||
env = GripperQuantizationWrapper(
|
||||
env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
|
||||
)
|
||||
# env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
|
||||
env = GripperActionWrapper(env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold)
|
||||
if cfg.wrapper.gripper_penalty is not None:
|
||||
env = GripperPenaltyWrapper(
|
||||
env=env,
|
||||
penalty=cfg.wrapper.gripper_penalty,
|
||||
gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward,
|
||||
)
|
||||
|
||||
if cfg.wrapper.ee_action_space_params is not None:
|
||||
env = EEActionWrapper(
|
||||
@@ -1176,6 +1201,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
||||
ee_action_space_params=cfg.wrapper.ee_action_space_params,
|
||||
use_gripper=cfg.wrapper.use_gripper,
|
||||
)
|
||||
|
||||
if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad:
|
||||
# env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
|
||||
env = GamepadControlWrapper(
|
||||
@@ -1192,6 +1218,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
||||
env=env,
|
||||
reset_pose=cfg.wrapper.fixed_reset_joint_positions,
|
||||
reset_time_s=cfg.wrapper.reset_time_s,
|
||||
open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset,
|
||||
)
|
||||
if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
|
||||
env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
|
||||
@@ -1341,11 +1368,10 @@ def record_dataset(env, policy, cfg):
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
def replay_episode(env, repo_id, root=None, episode=0):
|
||||
def replay_episode(env, cfg):
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
local_files_only = root is not None
|
||||
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode])
|
||||
env.reset()
|
||||
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
@@ -1353,7 +1379,7 @@ def replay_episode(env, repo_id, root=None, episode=0):
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action = actions[idx]["action"][:4]
|
||||
action = actions[idx]["action"]
|
||||
env.step((action, False))
|
||||
# env.step((action / env.unwrapped.delta, False))
|
||||
|
||||
@@ -1384,9 +1410,7 @@ def main(cfg: EnvConfig):
|
||||
if cfg.mode == "replay":
|
||||
replay_episode(
|
||||
env,
|
||||
cfg.replay_repo_id,
|
||||
root=cfg.dataset_root,
|
||||
episode=cfg.replay_episode,
|
||||
cfg=cfg,
|
||||
)
|
||||
exit()
|
||||
|
||||
|
||||
@@ -407,6 +407,7 @@ def add_actor_information_and_train(
|
||||
"done": done,
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
"complementary_info": batch["complementary_info"],
|
||||
}
|
||||
|
||||
# Use the forward method for critic loss (includes both main critic and grasp critic)
|
||||
@@ -428,7 +429,7 @@ def add_actor_information_and_train(
|
||||
optimizers["grasp_critic"].zero_grad()
|
||||
loss_grasp_critic.backward()
|
||||
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
|
||||
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
)
|
||||
optimizers["grasp_critic"].step()
|
||||
|
||||
@@ -492,7 +493,7 @@ def add_actor_information_and_train(
|
||||
optimizers["grasp_critic"].zero_grad()
|
||||
loss_grasp_critic.backward()
|
||||
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
|
||||
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["grasp_critic"].step()
|
||||
|
||||
@@ -509,7 +510,7 @@ def add_actor_information_and_train(
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
|
||||
parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["actor"].step()
|
||||
|
||||
@@ -771,17 +772,18 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
|
||||
|
||||
"""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
|
||||
params=policy.actor.parameters_to_optimize,
|
||||
params=[
|
||||
p
|
||||
for n, p in policy.actor.named_parameters()
|
||||
if not policy.config.shared_encoder or not n.startswith("encoder")
|
||||
],
|
||||
lr=cfg.policy.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
|
||||
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
optimizer_grasp_critic = torch.optim.Adam(
|
||||
params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr
|
||||
params=policy.grasp_critic.parameters(), lr=cfg.policy.critic_lr
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
|
||||
lr_scheduler = None
|
||||
@@ -992,7 +994,6 @@ def initialize_offline_replay_buffer(
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_features.keys(),
|
||||
action_mask=active_action_dims,
|
||||
action_delta=cfg.env.wrapper.delta_action,
|
||||
storage_device=storage_device,
|
||||
optimize_memory=True,
|
||||
capacity=cfg.policy.offline_buffer_capacity,
|
||||
@@ -1026,8 +1027,10 @@ def get_observation_features(
|
||||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = policy.actor.encoder.get_image_features(observations, normalize=True)
|
||||
next_observation_features = policy.actor.encoder.get_image_features(next_observations, normalize=True)
|
||||
observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True)
|
||||
next_observation_features = policy.actor.encoder.get_cached_image_features(
|
||||
next_observations, normalize=True
|
||||
)
|
||||
|
||||
return observation_features, next_observation_features
|
||||
|
||||
@@ -1089,6 +1092,44 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
|
||||
parameters_queue.put(state_bytes)
|
||||
|
||||
|
||||
def check_weight_gradients(module: nn.Module) -> dict[str, bool]:
|
||||
"""
|
||||
Checks whether each parameter in the module has a gradient.
|
||||
|
||||
Args:
|
||||
module (nn.Module): A PyTorch module whose parameters will be inspected.
|
||||
|
||||
Returns:
|
||||
dict[str, bool]: A dictionary where each key is the parameter name and the value is
|
||||
True if the parameter has an associated gradient (i.e. .grad is not None),
|
||||
otherwise False.
|
||||
"""
|
||||
grad_status = {}
|
||||
for name, param in module.named_parameters():
|
||||
grad_status[name] = param.grad is not None
|
||||
return grad_status
|
||||
|
||||
|
||||
def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]:
|
||||
"""
|
||||
Returns a dictionary of parameters (from actor) that also exist in the grad_status dictionary.
|
||||
|
||||
Args:
|
||||
actor (nn.Module): The actor model.
|
||||
grad_status (dict[str, bool]): A dictionary where keys are parameter names and values indicate
|
||||
whether each parameter has a gradient.
|
||||
|
||||
Returns:
|
||||
dict[str, bool]: A dictionary containing only the overlapping parameter names and their gradient status.
|
||||
"""
|
||||
# Get actor parameter names as a set.
|
||||
model_param_names = {name for name, _ in model.named_parameters()}
|
||||
|
||||
# Intersect parameter names between actor and grad_status.
|
||||
overlapping = {name: grad_status[name] for name in grad_status if name in model_param_names}
|
||||
return overlapping
|
||||
|
||||
|
||||
def process_interaction_message(
|
||||
message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user