Added gripper control mechanism to gym_manipulator
Moved the HilSerl env config to configs/env/configs.py. Fixes in actor_server, modeling_sac, and configuration_sac. Added the possibility of ignoring missing keys in env_cfg in the get_features_from_env_config function.
This commit is contained in:
committed by
AdilZouitine
parent
79e0f6e06c
commit
02b9ea9446
@@ -14,10 +14,12 @@
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
|
||||
@@ -159,20 +161,84 @@ class XarmEnv(EnvConfig):
|
||||
@dataclass
class VideoRecordConfig:
    """Configuration for video recording in ManiSkill environments."""

    # Recording is off unless explicitly enabled.
    enabled: bool = False
    # Default value "videos"; presumably the output directory -- confirm against the consumer.
    record_dir: str = "videos"
    # Default value "trajectory"; presumably the name given to the recording -- confirm.
    trajectory_name: str = "trajectory"
|
||||
|
||||
|
||||
@dataclass
|
||||
class WrapperConfig:
|
||||
"""Configuration for environment wrappers."""
|
||||
|
||||
delta_action: float | None = None
|
||||
joint_masking_action_space: list[bool] | None = None
|
||||
|
||||
|
||||
@dataclass
class EEActionSpaceConfig:
    """Configuration parameters for end-effector action space.

    The three step sizes and ``bounds`` have no defaults and must be supplied
    by the caller; ``use_gamepad`` defaults to ``False``.
    """

    # Per-axis step sizes (required).
    x_step_size: float
    y_step_size: float
    z_step_size: float
    # Contains 'min' and 'max' keys with position bounds.
    bounds: Dict[str, Any]
    # Off by default.
    use_gamepad: bool = False
|
||||
|
||||
|
||||
@dataclass
class EnvWrapperConfig:
    """Wrapper options used by ``HILSerlRobotEnvConfig`` (its ``wrapper`` field).

    Every field has a default, so the config can be built with no arguments.
    """

    display_cameras: bool = False
    delta_action: float = 0.1
    use_relative_joint_positions: bool = True
    add_joint_velocity_to_observation: bool = False
    add_ee_pose_to_observation: bool = False
    # Maps a string key to a 4-int tuple.
    # NOTE(review): presumably camera key -> crop box; confirm the tuple order.
    crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
    # Two-int tuple; NOTE(review): confirm (height, width) vs (width, height).
    resize_size: Optional[Tuple[int, int]] = None
    # Presumably seconds, per the `_s` suffix.
    control_time_s: float = 20.0
    fixed_reset_joint_positions: Optional[Any] = None
    # Presumably seconds, per the `_s` suffix.
    reset_time_s: float = 5.0
    joint_masking_action_space: Optional[Any] = None
    # End-effector action-space parameters; None by default.
    ee_action_space_params: Optional[EEActionSpaceConfig] = None
    use_gripper: bool = False
|
||||
|
||||
|
||||
@EnvConfig.register_subclass(name="gym_manipulator")
@dataclass
class HILSerlRobotEnvConfig(EnvConfig):
    """Configuration for the HILSerlRobotEnv environment.

    Registered as the ``gym_manipulator`` subclass of ``EnvConfig``.
    """

    robot: Optional[RobotConfig] = None
    wrapper: Optional[EnvWrapperConfig] = None
    fps: int = 10
    name: str = "real_robot"
    # Either "record", "replay", or None. Annotated Optional so the type
    # matches the None default.
    mode: Optional[str] = None
    repo_id: Optional[str] = None
    dataset_root: Optional[str] = None
    task: str = ""
    num_episodes: int = 10  # only for record mode
    episode: int = 0
    device: str = "cuda"
    push_to_hub: bool = True
    pretrained_policy_name_or_path: Optional[str] = None
    # Both entries default to None; a fresh dict is built per instance.
    reward_classifier: dict[str, str | None] = field(
        default_factory=lambda: {
            "pretrained_path": None,
            "config_path": None,
        }
    )

    def gym_kwargs(self) -> dict:
        """Return an empty dict: this config supplies no gym keyword arguments."""
        return {}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("maniskill_push")
|
||||
@dataclass
|
||||
class ManiskillEnvConfig(EnvConfig):
|
||||
"""Configuration for the ManiSkill environment."""
|
||||
|
||||
name: str = "maniskill/pushcube"
|
||||
task: str = "PushCube-v1"
|
||||
image_size: int = 64
|
||||
@@ -185,7 +251,7 @@ class ManiskillEnvConfig(EnvConfig):
|
||||
render_mode: str = "rgb_array"
|
||||
render_size: int = 64
|
||||
device: str = "cuda"
|
||||
robot: str = "so100" # This is a hack to make the robot config work
|
||||
robot: str = "so100" # This is a hack to make the robot config work
|
||||
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
|
||||
wrapper: WrapperConfig = field(default_factory=WrapperConfig)
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
@@ -218,4 +284,4 @@ class ManiskillEnvConfig(EnvConfig):
|
||||
"control_mode": self.control_mode,
|
||||
"sensor_configs": {"width": self.image_size, "height": self.image_size},
|
||||
"num_envs": 1,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,9 +49,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, (
|
||||
f"expect channel last images, but instead got {img.shape=}"
|
||||
)
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
@@ -91,7 +89,7 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||
else:
|
||||
feature = ft
|
||||
|
||||
policy_key = env_cfg.features_map[key]
|
||||
policy_key = env_cfg.features_map.get(key, key)
|
||||
policy_features[policy_key] = feature
|
||||
|
||||
return policy_features
|
||||
|
||||
@@ -20,7 +20,7 @@ from typing import Any, Optional
|
||||
|
||||
from lerobot.common.optim.optimizers import MultiAdamConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode, PolicyFeature, FeatureType
|
||||
from lerobot.configs.types import NormalizationMode, PolicyFeature, FeatureType
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -29,7 +29,6 @@ class ConcurrencyConfig:
|
||||
learner: str = "threads"
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorLearnerConfig:
|
||||
learner_host: str = "127.0.0.1"
|
||||
@@ -110,6 +109,7 @@ class SACConfig(PreTrainedConfig):
|
||||
use_backup_entropy: Whether to use backup entropy for the SAC algorithm.
|
||||
grad_clip_norm: Gradient clipping norm for the SAC algorithm.
|
||||
"""
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
@@ -152,8 +152,8 @@ class SACConfig(PreTrainedConfig):
|
||||
camera_number: int = 1
|
||||
device: str = "cuda"
|
||||
storage_device: str = "cpu"
|
||||
# Set to "helper2424/resnet10" for hil serl
|
||||
vision_encoder_name: str | None = None
|
||||
# Set to "helper2424/resnet10" for hil serl
|
||||
vision_encoder_name: str | None = None
|
||||
freeze_vision_encoder: bool = True
|
||||
image_encoder_hidden_dim: int = 32
|
||||
shared_encoder: bool = True
|
||||
@@ -163,7 +163,7 @@ class SACConfig(PreTrainedConfig):
|
||||
online_env_seed: int = 10000
|
||||
online_buffer_capacity: int = 100000
|
||||
offline_buffer_capacity: int = 100000
|
||||
online_step_before_learning: int = 100
|
||||
online_step_before_learning: int = 100
|
||||
policy_update_freq: int = 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
@@ -181,24 +181,14 @@ class SACConfig(PreTrainedConfig):
|
||||
target_entropy: float | None = None
|
||||
use_backup_entropy: bool = True
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
|
||||
# Network configuration
|
||||
critic_network_kwargs: CriticNetworkConfig = field(
|
||||
default_factory=CriticNetworkConfig
|
||||
)
|
||||
actor_network_kwargs: ActorNetworkConfig = field(
|
||||
default_factory=ActorNetworkConfig
|
||||
)
|
||||
policy_kwargs: PolicyConfig = field(
|
||||
default_factory=PolicyConfig
|
||||
)
|
||||
|
||||
actor_learner_config: ActorLearnerConfig = field(
|
||||
default_factory=ActorLearnerConfig
|
||||
)
|
||||
concurrency: ConcurrencyConfig = field(
|
||||
default_factory=ConcurrencyConfig
|
||||
)
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||
|
||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
@@ -218,18 +208,20 @@ class SACConfig(PreTrainedConfig):
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
    """Check that the configured features are usable.

    At least one observation input is required: either ``observation.state``
    or any key starting with ``observation.image`` (which also matches
    ``observation.images.<camera>`` keys). The output features must contain
    ``action``.

    Raises:
        ValueError: if neither a state nor an image observation is present in
            ``self.input_features``, or if ``action`` is missing from
            ``self.output_features``.
    """
    # The earlier unconditional checks for 'observation.image' and
    # 'observation.state' are intentionally gone: they rejected state-only and
    # multi-camera configurations before this relaxed check could run.
    has_image = any(key.startswith("observation.image") for key in self.input_features)
    has_state = "observation.state" in self.input_features

    if not (has_state or has_image):
        raise ValueError(
            "You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
        )

    if "action" not in self.output_features:
        raise ValueError("You must provide 'action' in the output features")
|
||||
|
||||
@property
def image_features(self) -> list[str]:
    """Return the keys of ``self.input_features`` that contain ``"image"``."""
    # A single return: the second, identical return statement was unreachable.
    return [key for key in self.input_features if "image" in key]
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
@@ -243,9 +235,13 @@ class SACConfig(PreTrainedConfig):
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import draccus

    # Write a default SACConfig to run_config.json in JSON format.
    config = SACConfig()
    draccus.set_config_type("json")

    # Dump once, through a context manager, so the file is flushed and closed
    # (previously the config was dumped twice via handles that were never closed).
    with open(file="run_config.json", mode="w") as stream:
        draccus.dump(config=config, stream=stream)
|
||||
|
||||
@@ -39,7 +39,6 @@ from lerobot.common.policies.utils import get_device_from_parameters
|
||||
class SACPolicy(
|
||||
PreTrainedPolicy,
|
||||
):
|
||||
|
||||
config_class = SACConfig
|
||||
name = "sac"
|
||||
|
||||
@@ -53,9 +52,7 @@ class SACPolicy(
|
||||
self.config = config
|
||||
|
||||
if config.dataset_stats is not None:
|
||||
input_normalization_params = _convert_normalization_params_to_tensor(
|
||||
config.dataset_stats
|
||||
)
|
||||
input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_features,
|
||||
config.normalization_mapping,
|
||||
@@ -64,12 +61,10 @@ class SACPolicy(
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
|
||||
output_normalization_params = _convert_normalization_params_to_tensor(
|
||||
config.dataset_stats
|
||||
)
|
||||
output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
|
||||
|
||||
# HACK: This is hacky and should be removed
|
||||
dataset_stats = dataset_stats or output_normalization_params
|
||||
dataset_stats = dataset_stats or output_normalization_params
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
@@ -138,7 +133,6 @@ class SACPolicy(
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return {
|
||||
"actor": self.actor.parameters_to_optimize,
|
||||
@@ -655,9 +649,10 @@ class SACObservationEncoder(nn.Module):
|
||||
class DefaultImageEncoder(nn.Module):
|
||||
def __init__(self, config: SACConfig):
|
||||
super().__init__()
|
||||
image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image")) # noqa: SIM118
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=config.input_features["observation.image"].shape[0],
|
||||
in_channels=config.input_features[image_key].shape[0],
|
||||
out_channels=config.image_encoder_hidden_dim,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
@@ -685,7 +680,9 @@ class DefaultImageEncoder(nn.Module):
|
||||
),
|
||||
nn.ReLU(),
|
||||
)
|
||||
dummy_batch = torch.zeros(1, *config.input_features["observation.image"].shape)
|
||||
# Get first image key from input features
|
||||
image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image")) # noqa: SIM118
|
||||
dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
|
||||
with torch.inference_mode():
|
||||
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||
self.image_enc_layers.extend(
|
||||
@@ -844,8 +841,10 @@ if __name__ == "__main__":
|
||||
import draccus
|
||||
|
||||
from lerobot.configs import parser
|
||||
|
||||
@parser.wrap()
|
||||
def main(config: SACConfig):
|
||||
policy = SACPolicy(config=config)
|
||||
print("yolo")
|
||||
main()
|
||||
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user