forked from tangger/lerobot
Add review feedback
This commit is contained in:
@@ -22,6 +22,18 @@ from lerobot.configs.policies import PreTrainedConfig
|
|||||||
from lerobot.configs.types import NormalizationMode
|
from lerobot.configs.types import NormalizationMode
|
||||||
|
|
||||||
|
|
||||||
|
def is_image_feature(key: str) -> bool:
|
||||||
|
"""Check if a feature key represents an image feature.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The feature key to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the key represents an image feature, False otherwise
|
||||||
|
"""
|
||||||
|
return key.startswith("observation.image")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConcurrencyConfig:
|
class ConcurrencyConfig:
|
||||||
"""Configuration for the concurrency of the actor and learner.
|
"""Configuration for the concurrency of the actor and learner.
|
||||||
@@ -203,7 +215,7 @@ class SACConfig(PreTrainedConfig):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def validate_features(self) -> None:
|
def validate_features(self) -> None:
|
||||||
has_image = any(key.startswith("observation.image") for key in self.input_features)
|
has_image = any(is_image_feature(key) for key in self.input_features)
|
||||||
has_state = "observation.state" in self.input_features
|
has_state = "observation.state" in self.input_features
|
||||||
|
|
||||||
if not (has_state or has_image):
|
if not (has_state or has_image):
|
||||||
@@ -216,7 +228,7 @@ class SACConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def image_features(self) -> list[str]:
|
def image_features(self) -> list[str]:
|
||||||
return [key for key in self.input_features if "image" in key]
|
return [key for key in self.input_features if is_image_feature(key)]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def observation_delta_indices(self) -> list:
|
def observation_delta_indices(self) -> list:
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from torch.distributions import MultivariateNormal, TanhTransform, Transform, Tr
|
|||||||
|
|
||||||
from lerobot.common.policies.normalize import NormalizeBuffer, UnnormalizeBuffer
|
from lerobot.common.policies.normalize import NormalizeBuffer, UnnormalizeBuffer
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
from lerobot.common.policies.sac.configuration_sac import SACConfig, is_image_feature
|
||||||
from lerobot.common.policies.utils import get_device_from_parameters
|
from lerobot.common.policies.utils import get_device_from_parameters
|
||||||
|
|
||||||
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
|
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
|
||||||
@@ -264,6 +264,7 @@ class SACPolicy(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||||
|
# TODO: Get indices before forward pass to avoid unnecessary computation
|
||||||
if self.config.num_subsample_critics is not None:
|
if self.config.num_subsample_critics is not None:
|
||||||
indices = torch.randperm(self.config.num_critics)
|
indices = torch.randperm(self.config.num_critics)
|
||||||
indices = indices[: self.config.num_subsample_critics]
|
indices = indices[: self.config.num_subsample_critics]
|
||||||
@@ -468,6 +469,7 @@ class SACPolicy(
|
|||||||
|
|
||||||
def _init_actor(self, continuous_action_dim):
|
def _init_actor(self, continuous_action_dim):
|
||||||
"""Initialize policy actor network and default target entropy."""
|
"""Initialize policy actor network and default target entropy."""
|
||||||
|
# NOTE: The actor select only the continuous action part
|
||||||
self.actor = Policy(
|
self.actor = Policy(
|
||||||
encoder=self.encoder_actor,
|
encoder=self.encoder_actor,
|
||||||
network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)),
|
network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)),
|
||||||
@@ -500,7 +502,7 @@ class SACObservationEncoder(nn.Module):
|
|||||||
self._compute_output_dim()
|
self._compute_output_dim()
|
||||||
|
|
||||||
def _init_image_layers(self) -> None:
|
def _init_image_layers(self) -> None:
|
||||||
self.image_keys = [k for k in self.config.input_features if k.startswith("observation.image")]
|
self.image_keys = [k for k in self.config.input_features if is_image_feature(k)]
|
||||||
self.has_images = bool(self.image_keys)
|
self.has_images = bool(self.image_keys)
|
||||||
if not self.has_images:
|
if not self.has_images:
|
||||||
return
|
return
|
||||||
@@ -928,7 +930,7 @@ class Policy(nn.Module):
|
|||||||
class DefaultImageEncoder(nn.Module):
|
class DefaultImageEncoder(nn.Module):
|
||||||
def __init__(self, config: SACConfig):
|
def __init__(self, config: SACConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image")) # noqa: SIM118
|
image_key = next(key for key in config.input_features if is_image_feature(key))
|
||||||
self.image_enc_layers = nn.Sequential(
|
self.image_enc_layers = nn.Sequential(
|
||||||
nn.Conv2d(
|
nn.Conv2d(
|
||||||
in_channels=config.input_features[image_key].shape[0],
|
in_channels=config.input_features[image_key].shape[0],
|
||||||
@@ -959,7 +961,6 @@ class DefaultImageEncoder(nn.Module):
|
|||||||
),
|
),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
)
|
)
|
||||||
# Get first image key from input features
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.image_enc_layers(x)
|
x = self.image_enc_layers(x)
|
||||||
|
|||||||
@@ -271,9 +271,11 @@ class ReplayBuffer:
|
|||||||
|
|
||||||
# Split the augmented images back to their sources
|
# Split the augmented images back to their sources
|
||||||
for i, key in enumerate(image_keys):
|
for i, key in enumerate(image_keys):
|
||||||
# State images are at even indices (0, 2, 4...)
|
# Calculate offsets for the current image key:
|
||||||
|
# For each key, we have 2*batch_size images (batch_size for states, batch_size for next_states)
|
||||||
|
# States start at index i*2*batch_size and take up batch_size slots
|
||||||
batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
|
batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
|
||||||
# Next state images are at odd indices (1, 3, 5...)
|
# Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
|
||||||
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
|
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
|
||||||
|
|
||||||
# Sample other tensors
|
# Sample other tensors
|
||||||
|
|||||||
@@ -111,7 +111,9 @@ def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
|
|||||||
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
|
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
|
||||||
buffer = io.BytesIO(buffer)
|
buffer = io.BytesIO(buffer)
|
||||||
buffer.seek(0)
|
buffer.seek(0)
|
||||||
return torch.load(buffer, weights_only=False) # nosec B614: Safe usage of torch.load
|
return torch.load(buffer, weights_only=False) # nosec B614: Using weights_only=False relies on pickle which has security implications.
|
||||||
|
# This is currently safe as we only deserialize trusted internal data.
|
||||||
|
# TODO: Verify if weights_only=True would work for our use case (safer default in torch 2.6+)
|
||||||
|
|
||||||
|
|
||||||
def python_object_to_bytes(python_object: Any) -> bytes:
|
def python_object_to_bytes(python_object: Any) -> bytes:
|
||||||
|
|||||||
Reference in New Issue
Block a user