Compare commits

...

38 Commits

Author SHA1 Message Date
AdilZouitine
a22fe8a6de Refactor SACObservationEncoder to improve modularity and readability. Split initialization into dedicated methods for image and state layers, and enhance caching logic for image features. Update forward method to streamline feature encoding and ensure proper normalization handling. 2025-04-18 12:22:14 +00:00
AdilZouitine
49b5f379a7 Refactor SACPolicy initialization by breaking down the constructor into smaller methods for normalization, encoders, critics, actor, and temperature setup. This enhances readability and maintainability. 2025-04-17 16:37:43 +00:00
AdilZouitine
7a3d8756b4 Refactor input and output normalization handling in SACPolicy for improved clarity and efficiency. Consolidate encoder initialization logic and remove redundant else statements. 2025-04-17 16:05:11 +00:00
AdilZouitine
dc1548fe1a Fix init temp
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
2025-04-16 16:46:37 +02:00
AdilZouitine
23c9441d5f Update log_std_min type to float in PolicyConfig for consistency 2025-04-16 16:46:37 +02:00
AdilZouitine
870e3efb92 fix caching
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
2025-04-16 16:46:37 +02:00
AdilZouitine
bfd48a8b70 Handle caching
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
2025-04-16 16:46:37 +02:00
AdilZouitine
5dc7ff6d3c change the tanh distribution to match hil serl
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
2025-04-16 16:46:37 +02:00
AdilZouitine
ee4ebeac9b match target entropy hil serl
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
2025-04-16 16:46:37 +02:00
AdilZouitine
fe7b47f459 stick to hil serl nn architecture
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
2025-04-16 16:46:37 +02:00
AdilZouitine
044ca3b039 Refactor modeling_sac and parameter handling for clarity and reusability.
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
2025-04-16 16:46:37 +02:00
AdilZouitine
bc36c69b71 fix encoder training 2025-04-16 16:46:37 +02:00
pre-commit-ci[bot]
2b9b05f1ba [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-16 16:46:37 +02:00
Michel Aractingi
9eec7b8bb0 General fixes in code, removed delta action, fixed grasp penalty, added logic to put gripper reward in info 2025-04-16 16:46:37 +02:00
pre-commit-ci[bot]
a80a9cf379 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-16 16:46:37 +02:00
AdilZouitine
7a42af835e fix caching and dataset stats is optional 2025-04-16 16:46:37 +02:00
AdilZouitine
9751328783 Add rounding for safety 2025-04-16 16:46:37 +02:00
pre-commit-ci[bot]
7225bc74a3 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-16 16:46:37 +02:00
AdilZouitine
03b1644bf7 fix sign issue 2025-04-16 16:46:37 +02:00
AdilZouitine
9b6e5a383f Refactor complementary_info handling in ReplayBuffer 2025-04-16 16:46:37 +02:00
AdilZouitine
86466b025f Handle gripper penalty 2025-04-16 16:46:37 +02:00
AdilZouitine
54745f111d fix caching 2025-04-16 16:46:37 +02:00
pre-commit-ci[bot]
82584cca78 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-16 16:46:37 +02:00
AdilZouitine
d3a8c2c247 fix indentation issue 2025-04-16 16:46:37 +02:00
AdilZouitine
74c11c4a75 Enhance SAC configuration and replay buffer with asynchronous prefetching support
- Added async_prefetch parameter to SACConfig for improved buffer management.
- Implemented get_iterator method in ReplayBuffer to support asynchronous prefetching of batches.
- Updated learner_server to utilize the new iterator for online and offline sampling, enhancing training efficiency.
2025-04-16 16:46:37 +02:00
AdilZouitine
2d932b710c Enhance SACPolicy to support shared encoder and optimize action selection
- Cached encoder output in select_action method to reduce redundant computations.
- Updated action selection and grasp critic calls to utilize cached encoder features when available.
2025-04-16 16:46:37 +02:00
AdilZouitine
a54baceabb Enhance SACPolicy and learner server for improved grasp critic integration
- Updated SACPolicy to conditionally compute grasp critic losses based on the presence of discrete actions.
- Refactored the forward method to handle grasp critic model selection and loss computation more clearly.
- Adjusted learner server to utilize optimized parameters for grasp critic during training.
- Improved action handling in the ManiskillMockGripperWrapper to accommodate both tuple and single action inputs.
2025-04-16 16:46:37 +02:00
AdilZouitine
077d18b439 Refactor SACPolicy for improved readability and action dimension handling
- Cleaned up code formatting for better readability, including consistent spacing and removal of unnecessary blank lines.
- Consolidated continuous action dimension calculation to enhance clarity and maintainability.
- Simplified loss return statements in the forward method to improve code structure.
- Ensured grasp critic parameters are included conditionally based on configuration settings.
2025-04-16 16:46:37 +02:00
AdilZouitine
c6cd1475a7 Add mock gripper support and enhance SAC policy action handling
- Introduced mock_gripper parameter in ManiskillEnvConfig to enable gripper simulation.
- Added ManiskillMockGripperWrapper to adjust action space for environments with discrete actions.
- Updated SACPolicy to compute continuous action dimensions correctly, ensuring compatibility with the new gripper setup.
- Refactored action handling in the training loop to accommodate the changes in action dimensions.
2025-04-16 16:46:37 +02:00
AdilZouitine
e35ee47b07 Refactor SAC policy and training loop to enhance discrete action support
- Updated SACPolicy to conditionally compute losses for grasp critic based on num_discrete_actions.
- Simplified forward method to return loss outputs as a dictionary for better clarity.
- Adjusted learner_server to handle both main and grasp critic losses during training.
- Ensured optimizers are created conditionally for grasp critic based on configuration settings.
2025-04-16 16:46:37 +02:00
AdilZouitine
c3f2487026 Refactor SAC configuration and policy to support discrete actions
- Removed GraspCriticNetworkConfig class and integrated its parameters into SACConfig.
- Added num_discrete_actions parameter to SACConfig for better action handling.
- Updated SACPolicy to conditionally create grasp critic networks based on num_discrete_actions.
- Enhanced grasp critic forward pass to handle discrete actions and compute losses accordingly.
2025-04-16 16:46:37 +02:00
Michel Aractingi
c621077b62 Added Gripper quantization wrapper and grasp penalty
removed complementary info from buffer and learner server
removed get_gripper_action function
added gripper parameters to `common/envs/configs.py`
2025-04-16 16:46:37 +02:00
pre-commit-ci[bot]
f5cfd9fd48 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-16 16:46:37 +02:00
s1lent4gnt
22da1739b1 Add grasp critic to the training loop
- Integrated the grasp critic gradient update to the training loop in learner_server
- Added Adam optimizer and configured grasp critic learning rate in configuration_sac
- Added target critics networks update after the critics gradient step
2025-04-16 16:46:37 +02:00
s1lent4gnt
d38d5f988d Add get_gripper_action method to GamepadController 2025-04-16 16:46:37 +02:00
s1lent4gnt
8d1936ffe0 Add gripper penalty wrapper 2025-04-16 16:46:37 +02:00
s1lent4gnt
cef944e1b1 Add complementary info in the replay buffer
- Added complementary info in the add method
- Added complementary info in the sample method
2025-04-16 16:46:37 +02:00
s1lent4gnt
384eb2cd07 Add grasp critic
- Implemented grasp critic to evaluate gripper actions
- Added corresponding config parameters for tuning
2025-04-16 16:46:37 +02:00
10 changed files with 1226 additions and 958 deletions

View File

@@ -171,7 +171,6 @@ class VideoRecordConfig:
class WrapperConfig:
"""Configuration for environment wrappers."""
delta_action: float | None = None
joint_masking_action_space: list[bool] | None = None
@@ -191,7 +190,6 @@ class EnvWrapperConfig:
"""Configuration for environment wrappers."""
display_cameras: bool = False
delta_action: float = 0.1
use_relative_joint_positions: bool = True
add_joint_velocity_to_observation: bool = False
add_ee_pose_to_observation: bool = False
@@ -203,6 +201,10 @@ class EnvWrapperConfig:
joint_masking_action_space: Optional[Any] = None
ee_action_space_params: Optional[EEActionSpaceConfig] = None
use_gripper: bool = False
gripper_quantization_threshold: float | None = 0.8
gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
open_gripper_on_reset: bool = False
@EnvConfig.register_subclass(name="gym_manipulator")
@@ -254,6 +256,7 @@ class ManiskillEnvConfig(EnvConfig):
robot: str = "so100" # This is a hack to make the robot config work
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
wrapper: WrapperConfig = field(default_factory=WrapperConfig)
mock_gripper: bool = False
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),

View File

@@ -51,8 +51,8 @@ class ActorNetworkConfig:
@dataclass
class PolicyConfig:
use_tanh_squash: bool = True
log_std_min: int = -5
log_std_max: int = 2
log_std_min: float = 1e-5
log_std_max: float = 10.0
init_final: float = 0.05
@@ -85,12 +85,15 @@ class SACConfig(PreTrainedConfig):
freeze_vision_encoder: Whether to freeze the vision encoder during training.
image_encoder_hidden_dim: Hidden dimension size for the image encoder.
shared_encoder: Whether to use a shared encoder for actor and critic.
num_discrete_actions: Number of discrete actions, eg for gripper actions.
image_embedding_pooling_dim: Dimension of the image embedding pooling.
concurrency: Configuration for concurrency settings.
actor_learner: Configuration for actor-learner architecture.
online_steps: Number of steps for online training.
online_env_seed: Seed for the online environment.
online_buffer_capacity: Capacity of the online replay buffer.
offline_buffer_capacity: Capacity of the offline replay buffer.
async_prefetch: Whether to use asynchronous prefetching for the buffers.
online_step_before_learning: Number of steps before learning starts.
policy_update_freq: Frequency of policy updates.
discount: Discount factor for the SAC algorithm.
@@ -118,7 +121,7 @@ class SACConfig(PreTrainedConfig):
}
)
dataset_stats: dict[str, dict[str, list[float]]] = field(
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
default_factory=lambda: {
"observation.image": {
"mean": [0.485, 0.456, 0.406],
@@ -144,12 +147,15 @@ class SACConfig(PreTrainedConfig):
freeze_vision_encoder: bool = True
image_encoder_hidden_dim: int = 32
shared_encoder: bool = True
num_discrete_actions: int | None = None
image_embedding_pooling_dim: int = 8
# Training parameter
online_steps: int = 1000000
online_env_seed: int = 10000
online_buffer_capacity: int = 100000
offline_buffer_capacity: int = 100000
async_prefetch: bool = False
online_step_before_learning: int = 100
policy_update_freq: int = 1
@@ -173,7 +179,7 @@ class SACConfig(PreTrainedConfig):
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
grasp_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)

File diff suppressed because it is too large Load Diff

View File

@@ -221,7 +221,6 @@ def record_episode(
events=events,
policy=policy,
fps=fps,
# record_delta_actions=record_delta_actions,
teleoperate=policy is None,
single_task=single_task,
)
@@ -267,8 +266,6 @@ def control_loop(
if teleoperate:
observation, action = robot.teleop_step(record_data=True)
# if record_delta_actions:
# action["action"] = action["action"] - current_joint_positions
else:
observation = robot.capture_observation()

View File

@@ -363,8 +363,6 @@ def replay(
start_episode_t = time.perf_counter()
action = actions[idx]["action"]
# if replay_delta_actions:
# action = action + current_joint_positions
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t

View File

@@ -231,6 +231,7 @@ def act_with_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
)
policy = policy.eval()
assert isinstance(policy, nn.Module)
obs, info = online_env.reset()

File diff suppressed because it is too large Load Diff

View File

@@ -42,7 +42,6 @@ class HILSerlRobotEnv(gym.Env):
self,
robot,
use_delta_action_space: bool = True,
delta: float | None = None,
display_cameras: bool = False,
):
"""
@@ -55,8 +54,6 @@ class HILSerlRobotEnv(gym.Env):
robot: The robot interface object used to connect and interact with the physical robot.
use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute
joint positions are used.
delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between
0 and 1 when using a delta action space.
display_cameras (bool): If True, the robot's camera feeds will be displayed during execution.
"""
super().__init__()
@@ -74,7 +71,6 @@ class HILSerlRobotEnv(gym.Env):
self.current_step = 0
self.episode_data = None
self.delta = delta
self.use_delta_action_space = use_delta_action_space
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
@@ -374,7 +370,7 @@ class RewardWrapper(gym.Wrapper):
self.device = device
def step(self, action):
observation, _, terminated, truncated, info = self.env.step(action)
observation, reward, terminated, truncated, info = self.env.step(action)
images = [
observation[key].to(self.device, non_blocking=self.device.type == "cuda")
for key in observation
@@ -382,15 +378,17 @@ class RewardWrapper(gym.Wrapper):
]
start_time = time.perf_counter()
with torch.inference_mode():
reward = (
success = (
self.reward_classifier.predict_reward(images, threshold=0.8)
if self.reward_classifier is not None
else 0.0
)
info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
if reward == 1.0:
if success == 1.0:
terminated = True
reward = 1.0
return observation, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
@@ -720,11 +718,13 @@ class ResetWrapper(gym.Wrapper):
env: HILSerlRobotEnv,
reset_pose: np.ndarray | None = None,
reset_time_s: float = 5,
open_gripper_on_reset: bool = False,
):
super().__init__(env)
self.reset_time_s = reset_time_s
self.reset_pose = reset_pose
self.robot = self.unwrapped.robot
self.open_gripper_on_reset = open_gripper_on_reset
def reset(self, *, seed=None, options=None):
if self.reset_pose is not None:
@@ -733,6 +733,14 @@ class ResetWrapper(gym.Wrapper):
reset_follower_position(self.robot, self.reset_pose)
busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
log_say("Reset the environment done.", play_sounds=True)
if self.open_gripper_on_reset:
current_joint_pos = self.robot.follower_arms["main"].read("Present_Position")
current_joint_pos[-1] = MAX_GRIPPER_COMMAND
self.robot.send_action(torch.from_numpy(current_joint_pos))
busy_wait(0.1)
current_joint_pos[-1] = 0.0
self.robot.send_action(torch.from_numpy(current_joint_pos))
busy_wait(0.2)
else:
log_say(
f"Manually reset the environment for {self.reset_time_s} seconds.",
@@ -761,6 +769,75 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
return observation
class GripperPenaltyWrapper(gym.RewardWrapper):
def __init__(self, env, penalty: float = -0.1, gripper_penalty_in_reward: bool = True):
super().__init__(env)
self.penalty = penalty
self.gripper_penalty_in_reward = gripper_penalty_in_reward
self.last_gripper_state = None
def reward(self, reward, action):
gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND
action_normalized = action - 1.0 # action / MAX_GRIPPER_COMMAND
gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or (
gripper_state_normalized > 0.75 and action_normalized < -0.5
)
return reward + self.penalty * int(gripper_penalty_bool)
def step(self, action):
self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
if isinstance(action, tuple):
gripper_action = action[0][-1]
else:
gripper_action = action[-1]
obs, reward, terminated, truncated, info = self.env.step(action)
gripper_penalty = self.reward(reward, gripper_action)
if self.gripper_penalty_in_reward:
reward += gripper_penalty
else:
info["gripper_penalty"] = gripper_penalty
return obs, reward, terminated, truncated, info
def reset(self, **kwargs):
self.last_gripper_state = None
obs, info = super().reset(**kwargs)
if self.gripper_penalty_in_reward:
info["gripper_penalty"] = 0.0
return obs, info
class GripperActionWrapper(gym.ActionWrapper):
def __init__(self, env, quantization_threshold: float = 0.2):
super().__init__(env)
self.quantization_threshold = quantization_threshold
def action(self, action):
is_intervention = False
if isinstance(action, tuple):
action, is_intervention = action
gripper_command = action[-1]
# Gripper actions are between 0, 2
# we want to quantize them to -1, 0 or 1
gripper_command = gripper_command - 1.0
if self.quantization_threshold is not None:
# Quantize gripper command to -1, 0 or 1
gripper_command = (
np.sign(gripper_command) if abs(gripper_command) > self.quantization_threshold else 0.0
)
gripper_command = gripper_command * MAX_GRIPPER_COMMAND
gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
action[-1] = gripper_action.item()
return action, is_intervention
class EEActionWrapper(gym.ActionWrapper):
def __init__(self, env, ee_action_space_params=None, use_gripper=False):
super().__init__(env)
@@ -780,10 +857,12 @@ class EEActionWrapper(gym.ActionWrapper):
]
)
if self.use_gripper:
action_space_bounds = np.concatenate([action_space_bounds, [1.0]])
# gripper actions open at 2.0, and closed at 0.0
min_action_space_bounds = np.concatenate([-action_space_bounds, [0.0]])
max_action_space_bounds = np.concatenate([action_space_bounds, [2.0]])
ee_action_space = gym.spaces.Box(
low=-action_space_bounds,
high=action_space_bounds,
low=min_action_space_bounds,
high=max_action_space_bounds,
shape=(3 + int(self.use_gripper),),
dtype=np.float32,
)
@@ -820,17 +899,7 @@ class EEActionWrapper(gym.ActionWrapper):
fk_func=self.fk_function,
)
if self.use_gripper:
# Quantize gripper command to -1, 0 or 1
if gripper_command < -0.2:
gripper_command = -1.0
elif gripper_command > 0.2:
gripper_command = 1.0
else:
gripper_command = 0.0
gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
target_joint_pos[-1] = gripper_action
target_joint_pos[-1] = gripper_command
return target_joint_pos, is_intervention
@@ -951,11 +1020,11 @@ class GamepadControlWrapper(gym.Wrapper):
if self.use_gripper:
gripper_command = self.controller.gripper_command()
if gripper_command == "open":
gamepad_action = np.concatenate([gamepad_action, [1.0]])
gamepad_action = np.concatenate([gamepad_action, [2.0]])
elif gripper_command == "close":
gamepad_action = np.concatenate([gamepad_action, [-1.0]])
else:
gamepad_action = np.concatenate([gamepad_action, [0.0]])
else:
gamepad_action = np.concatenate([gamepad_action, [1.0]])
# Check episode ending buttons
# We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None
@@ -1095,7 +1164,6 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
env = HILSerlRobotEnv(
robot=robot,
display_cameras=cfg.wrapper.display_cameras,
delta=cfg.wrapper.delta_action,
use_delta_action_space=cfg.wrapper.use_relative_joint_positions
and cfg.wrapper.ee_action_space_params is None,
)
@@ -1118,12 +1186,22 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
# Add reward computation and control wrappers
# env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
if cfg.wrapper.use_gripper:
env = GripperActionWrapper(env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold)
if cfg.wrapper.gripper_penalty is not None:
env = GripperPenaltyWrapper(
env=env,
penalty=cfg.wrapper.gripper_penalty,
gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward,
)
if cfg.wrapper.ee_action_space_params is not None:
env = EEActionWrapper(
env=env,
ee_action_space_params=cfg.wrapper.ee_action_space_params,
use_gripper=cfg.wrapper.use_gripper,
)
if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad:
# env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
env = GamepadControlWrapper(
@@ -1140,6 +1218,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
env=env,
reset_pose=cfg.wrapper.fixed_reset_joint_positions,
reset_time_s=cfg.wrapper.reset_time_s,
open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset,
)
if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
@@ -1289,11 +1368,10 @@ def record_dataset(env, policy, cfg):
dataset.push_to_hub()
def replay_episode(env, repo_id, root=None, episode=0):
def replay_episode(env, cfg):
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
local_files_only = root is not None
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode])
env.reset()
actions = dataset.hf_dataset.select_columns("action")
@@ -1301,7 +1379,7 @@ def replay_episode(env, repo_id, root=None, episode=0):
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
action = actions[idx]["action"][:4]
action = actions[idx]["action"]
env.step((action, False))
# env.step((action / env.unwrapped.delta, False))
@@ -1332,9 +1410,7 @@ def main(cfg: EnvConfig):
if cfg.mode == "replay":
replay_episode(
env,
cfg.replay_repo_id,
root=cfg.dataset_root,
episode=cfg.replay_episode,
cfg=cfg,
)
exit()

View File

@@ -1,4 +1,4 @@
#!/usr/bin/env python
# !/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
@@ -269,6 +269,7 @@ def add_actor_information_and_train(
policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
saving_checkpoint = cfg.save_checkpoint
online_steps = cfg.policy.online_steps
async_prefetch = cfg.policy.async_prefetch
# Initialize logging for multiprocessing
if not use_threads(cfg):
@@ -326,6 +327,9 @@ def add_actor_information_and_train(
if cfg.dataset is not None:
dataset_repo_id = cfg.dataset.repo_id
# Initialize iterators
online_iterator = None
offline_iterator = None
# NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
while True:
# Exit the training loop if shutdown is requested
@@ -359,13 +363,26 @@ def add_actor_information_and_train(
if len(replay_buffer) < online_step_before_learning:
continue
if online_iterator is None:
logging.debug("[LEARNER] Initializing online replay buffer iterator")
online_iterator = replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
if offline_replay_buffer is not None and offline_iterator is None:
logging.debug("[LEARNER] Initializing offline replay buffer iterator")
offline_iterator = offline_replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
logging.debug("[LEARNER] Starting optimization loop")
time_for_one_optimization_step = time.time()
for _ in range(utd_ratio - 1):
batch = replay_buffer.sample(batch_size=batch_size)
# Sample from the iterators
batch = next(online_iterator)
if dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
batch_offline = next(offline_iterator)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
@@ -390,26 +407,40 @@ def add_actor_information_and_train(
"done": done,
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
"complementary_info": batch["complementary_info"],
}
# Use the forward method for critic loss
loss_critic = policy.forward(forward_batch, model="critic")
# Use the forward method for critic loss (includes both main critic and grasp critic)
critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
loss_critic.backward()
# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
)
optimizers["critic"].step()
# Grasp critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
)
optimizers["grasp_critic"].step()
# Update target networks
policy.update_target_networks()
batch = replay_buffer.sample(batch_size=batch_size)
# Sample for the last update in the UTD ratio
batch = next(online_iterator)
if dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
batch_offline = next(offline_iterator)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
@@ -437,63 +468,80 @@ def add_actor_information_and_train(
"next_observation_feature": next_observation_features,
}
# Use the forward method for critic loss
loss_critic = policy.forward(forward_batch, model="critic")
# Use the forward method for critic loss (includes both main critic and grasp critic)
critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
loss_critic.backward()
# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["critic"].step()
training_infos = {}
training_infos["loss_critic"] = loss_critic.item()
training_infos["critic_grad_norm"] = critic_grad_norm
# Initialize training info dictionary
training_infos = {
"loss_critic": loss_critic.item(),
"critic_grad_norm": critic_grad_norm,
}
# Grasp critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["grasp_critic"].step()
# Add grasp critic info to training info
training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
# Actor and temperature optimization (at specified frequency)
if optimization_step % policy_update_freq == 0:
for _ in range(policy_update_freq):
# Use the forward method for actor loss
loss_actor = policy.forward(forward_batch, model="actor")
# Actor optimization
actor_output = policy.forward(forward_batch, model="actor")
loss_actor = actor_output["loss_actor"]
optimizers["actor"].zero_grad()
loss_actor.backward()
# clip gradients
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["actor"].step()
# Add actor info to training info
training_infos["loss_actor"] = loss_actor.item()
training_infos["actor_grad_norm"] = actor_grad_norm
# Temperature optimization using forward method
loss_temperature = policy.forward(forward_batch, model="temperature")
# Temperature optimization
temperature_output = policy.forward(forward_batch, model="temperature")
loss_temperature = temperature_output["loss_temperature"]
optimizers["temperature"].zero_grad()
loss_temperature.backward()
# clip gradients
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
).item()
optimizers["temperature"].step()
# Add temperature info to training info
training_infos["loss_temperature"] = loss_temperature.item()
training_infos["temperature_grad_norm"] = temp_grad_norm
training_infos["temperature"] = policy.temperature
# Update temperature
policy.update_temperature()
# Check if it's time to push updated policy to actors
# Push policy to actors if needed
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
last_time_policy_pushed = time.time()
# Update target networks
policy.update_target_networks()
# Log training metrics at specified intervals
@@ -697,7 +745,7 @@ def save_training_checkpoint(
logging.info("Resume training")
def make_optimizers_and_scheduler(cfg, policy: nn.Module):
def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
"""
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
@@ -724,11 +772,19 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
"""
optimizer_actor = torch.optim.Adam(
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
params=policy.actor.parameters_to_optimize,
params=[
p
for n, p in policy.actor.named_parameters()
if not policy.config.shared_encoder or not n.startswith("encoder")
],
lr=cfg.policy.actor_lr,
)
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
if cfg.policy.num_discrete_actions is not None:
optimizer_grasp_critic = torch.optim.Adam(
params=policy.grasp_critic.parameters(), lr=cfg.policy.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None
optimizers = {
@@ -736,6 +792,8 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
"critic": optimizer_critic,
"temperature": optimizer_temperature,
}
if cfg.policy.num_discrete_actions is not None:
optimizers["grasp_critic"] = optimizer_grasp_critic
return optimizers, lr_scheduler
@@ -936,7 +994,6 @@ def initialize_offline_replay_buffer(
device=device,
state_keys=cfg.policy.input_features.keys(),
action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action,
storage_device=storage_device,
optimize_memory=True,
capacity=cfg.policy.offline_buffer_capacity,
@@ -970,11 +1027,9 @@ def get_observation_features(
return None, None
with torch.no_grad():
observation_features = (
policy.actor.encoder(observations) if policy.actor.encoder is not None else None
)
next_observation_features = (
policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None
observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True)
next_observation_features = policy.actor.encoder.get_cached_image_features(
next_observations, normalize=True
)
return observation_features, next_observation_features
@@ -1037,6 +1092,44 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
parameters_queue.put(state_bytes)
def check_weight_gradients(module: nn.Module) -> dict[str, bool]:
"""
Checks whether each parameter in the module has a gradient.
Args:
module (nn.Module): A PyTorch module whose parameters will be inspected.
Returns:
dict[str, bool]: A dictionary where each key is the parameter name and the value is
True if the parameter has an associated gradient (i.e. .grad is not None),
otherwise False.
"""
grad_status = {}
for name, param in module.named_parameters():
grad_status[name] = param.grad is not None
return grad_status
def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]:
"""
Returns a dictionary of parameters (from actor) that also exist in the grad_status dictionary.
Args:
actor (nn.Module): The actor model.
grad_status (dict[str, bool]): A dictionary where keys are parameter names and values indicate
whether each parameter has a gradient.
Returns:
dict[str, bool]: A dictionary containing only the overlapping parameter names and their gradient status.
"""
# Get actor parameter names as a set.
model_param_names = {name for name, _ in model.named_parameters()}
# Intersect parameter names between actor and grad_status.
overlapping = {name: grad_status[name] for name in grad_status if name in model_param_names}
return overlapping
def process_interaction_message(
message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
):

View File

@@ -1,5 +1,3 @@
import logging
import time
from typing import Any
import einops
@@ -10,7 +8,6 @@ from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from lerobot.common.envs.configs import ManiskillEnvConfig
from lerobot.configs import parser
def preprocess_maniskill_observation(
@@ -153,6 +150,27 @@ class TimeLimitWrapper(gym.Wrapper):
return super().reset(seed=seed, options=options)
class ManiskillMockGripperWrapper(gym.Wrapper):
def __init__(self, env, nb_discrete_actions: int = 3):
super().__init__(env)
new_shape = env.action_space[0].shape[0] + 1
new_low = np.concatenate([env.action_space[0].low, [0]])
new_high = np.concatenate([env.action_space[0].high, [nb_discrete_actions - 1]])
action_space_agent = gym.spaces.Box(low=new_low, high=new_high, shape=(new_shape,))
self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1]))
def step(self, action):
if isinstance(action, tuple):
action_agent, telop_action = action
else:
telop_action = 0
action_agent = action
real_action = action_agent[:-1]
final_action = (real_action, telop_action)
obs, reward, terminated, truncated, info = self.env.step(final_action)
return obs, reward, terminated, truncated, info
def make_maniskill(
cfg: ManiskillEnvConfig,
n_envs: int | None = None,
@@ -197,40 +215,42 @@ def make_maniskill(
env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env)
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control
if cfg.mock_gripper:
env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3)
return env
@parser.wrap()
def main(cfg: ManiskillEnvConfig):
"""Main function to run the ManiSkill environment."""
# Create the ManiSkill environment
env = make_maniskill(cfg, n_envs=1)
# @parser.wrap()
# def main(cfg: TrainPipelineConfig):
# """Main function to run the ManiSkill environment."""
# # Create the ManiSkill environment
# env = make_maniskill(cfg.env, n_envs=1)
# Reset the environment
obs, info = env.reset()
# # Reset the environment
# obs, info = env.reset()
# Run a simple interaction loop
sum_reward = 0
for i in range(100):
# Sample a random action
action = env.action_space.sample()
# # Run a simple interaction loop
# sum_reward = 0
# for i in range(100):
# # Sample a random action
# action = env.action_space.sample()
# Step the environment
start_time = time.perf_counter()
obs, reward, terminated, truncated, info = env.step(action)
step_time = time.perf_counter() - start_time
sum_reward += reward
# Log information
# # Step the environment
# start_time = time.perf_counter()
# obs, reward, terminated, truncated, info = env.step(action)
# step_time = time.perf_counter() - start_time
# sum_reward += reward
# # Log information
# Reset if episode terminated
if terminated or truncated:
logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
sum_reward = 0
obs, info = env.reset()
# # Reset if episode terminated
# if terminated or truncated:
# logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
# sum_reward = 0
# obs, info = env.reset()
# Close the environment
env.close()
# # Close the environment
# env.close()
# if __name__ == "__main__":