forked from tangger/lerobot
Compare commits
38 Commits
user/azoui...user/miche
| SHA1 |
|---|
| a22fe8a6de |
| 49b5f379a7 |
| 7a3d8756b4 |
| dc1548fe1a |
| 23c9441d5f |
| 870e3efb92 |
| bfd48a8b70 |
| 5dc7ff6d3c |
| ee4ebeac9b |
| fe7b47f459 |
| 044ca3b039 |
| bc36c69b71 |
| 2b9b05f1ba |
| 9eec7b8bb0 |
| a80a9cf379 |
| 7a42af835e |
| 9751328783 |
| 7225bc74a3 |
| 03b1644bf7 |
| 9b6e5a383f |
| 86466b025f |
| 54745f111d |
| 82584cca78 |
| d3a8c2c247 |
| 74c11c4a75 |
| 2d932b710c |
| a54baceabb |
| 077d18b439 |
| c6cd1475a7 |
| e35ee47b07 |
| c3f2487026 |
| c621077b62 |
| f5cfd9fd48 |
| 22da1739b1 |
| d38d5f988d |
| 8d1936ffe0 |
| cef944e1b1 |
| 384eb2cd07 |
@@ -171,7 +171,6 @@ class VideoRecordConfig:
 class WrapperConfig:
     """Configuration for environment wrappers."""

     delta_action: float | None = None
     joint_masking_action_space: list[bool] | None = None
@@ -191,7 +190,6 @@ class EnvWrapperConfig:
     """Configuration for environment wrappers."""

     display_cameras: bool = False
     delta_action: float = 0.1
     use_relative_joint_positions: bool = True
     add_joint_velocity_to_observation: bool = False
     add_ee_pose_to_observation: bool = False
@@ -203,6 +201,10 @@ class EnvWrapperConfig:
     joint_masking_action_space: Optional[Any] = None
     ee_action_space_params: Optional[EEActionSpaceConfig] = None
     use_gripper: bool = False
+    gripper_quantization_threshold: float | None = 0.8
+    gripper_penalty: float = 0.0
+    gripper_penalty_in_reward: bool = False
+    open_gripper_on_reset: bool = False


 @EnvConfig.register_subclass(name="gym_manipulator")
@@ -254,6 +256,7 @@ class ManiskillEnvConfig(EnvConfig):
     robot: str = "so100"  # This is a hack to make the robot config work
     video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
     wrapper: WrapperConfig = field(default_factory=WrapperConfig)
+    mock_gripper: bool = False
     features: dict[str, PolicyFeature] = field(
         default_factory=lambda: {
             "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
@@ -51,8 +51,8 @@ class ActorNetworkConfig:
 @dataclass
 class PolicyConfig:
     use_tanh_squash: bool = True
-    log_std_min: int = -5
-    log_std_max: int = 2
+    log_std_min: float = 1e-5
+    log_std_max: float = 10.0
     init_final: float = 0.05
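A minimal sketch of the two std-bounding conventions this change appears to swap between. The old integer fields read as bounds on the log of the standard deviation, while the new float fields read as direct std limits; the function names and the sigmoid rescaling below are illustrative assumptions, not the repository's actual code.

```python
import torch

def std_from_log_std(x: torch.Tensor, lo: float = -5.0, hi: float = 2.0) -> torch.Tensor:
    # Old-style: clamp in log space, then exponentiate -> std in [e^-5, e^2].
    return torch.exp(torch.clamp(x, lo, hi))

def std_direct(x: torch.Tensor, lo: float = 1e-5, hi: float = 10.0) -> torch.Tensor:
    # New-style: squash the raw head output directly into [lo, hi].
    return lo + (hi - lo) * torch.sigmoid(x)

x = torch.randn(4)
print(std_from_log_std(x))  # values in [e^-5, e^2]
print(std_direct(x))        # values in [1e-5, 10.0]
```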
@@ -85,12 +85,15 @@ class SACConfig(PreTrainedConfig):
     freeze_vision_encoder: Whether to freeze the vision encoder during training.
     image_encoder_hidden_dim: Hidden dimension size for the image encoder.
     shared_encoder: Whether to use a shared encoder for actor and critic.
+    num_discrete_actions: Number of discrete actions, eg for gripper actions.
+    image_embedding_pooling_dim: Dimension of the image embedding pooling.
     concurrency: Configuration for concurrency settings.
     actor_learner: Configuration for actor-learner architecture.
     online_steps: Number of steps for online training.
     online_env_seed: Seed for the online environment.
     online_buffer_capacity: Capacity of the online replay buffer.
+    offline_buffer_capacity: Capacity of the offline replay buffer.
     async_prefetch: Whether to use asynchronous prefetching for the buffers.
     online_step_before_learning: Number of steps before learning starts.
     policy_update_freq: Frequency of policy updates.
     discount: Discount factor for the SAC algorithm.
@@ -118,7 +121,7 @@ class SACConfig(PreTrainedConfig):
         }
     )

-    dataset_stats: dict[str, dict[str, list[float]]] = field(
+    dataset_stats: dict[str, dict[str, list[float]]] | None = field(
         default_factory=lambda: {
             "observation.image": {
                 "mean": [0.485, 0.456, 0.406],
@@ -144,12 +147,15 @@ class SACConfig(PreTrainedConfig):
     freeze_vision_encoder: bool = True
     image_encoder_hidden_dim: int = 32
     shared_encoder: bool = True
+    num_discrete_actions: int | None = None
+    image_embedding_pooling_dim: int = 8

     # Training parameter
     online_steps: int = 1000000
     online_env_seed: int = 10000
     online_buffer_capacity: int = 100000
+    offline_buffer_capacity: int = 100000
     async_prefetch: bool = False
     online_step_before_learning: int = 100
     policy_update_freq: int = 1
@@ -173,7 +179,7 @@ class SACConfig(PreTrainedConfig):
     critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
     actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
     policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)

+    grasp_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
     actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
     concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
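The new grasp_critic_network_kwargs field suggests a second, discrete-action critic alongside the continuous SAC critic. A hedged sketch of what such a head typically is — a small DQN-style Q-network over a few discrete gripper actions; the class below is illustrative, the real network is built from CriticNetworkConfig.

```python
import torch
import torch.nn as nn

class GraspCritic(nn.Module):
    def __init__(self, obs_dim: int, num_discrete_actions: int = 3, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_discrete_actions),  # one Q-value per discrete action
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)

q_values = GraspCritic(obs_dim=64)(torch.randn(8, 64))
print(q_values.shape)  # torch.Size([8, 3])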
File diff suppressed because it is too large
@@ -221,7 +221,6 @@ def record_episode(
         events=events,
         policy=policy,
         fps=fps,
-        # record_delta_actions=record_delta_actions,
         teleoperate=policy is None,
         single_task=single_task,
     )
@@ -267,8 +266,6 @@ def control_loop(

         if teleoperate:
             observation, action = robot.teleop_step(record_data=True)
-            # if record_delta_actions:
-            #     action["action"] = action["action"] - current_joint_positions
         else:
             observation = robot.capture_observation()
@@ -363,8 +363,6 @@ def replay(
         start_episode_t = time.perf_counter()

         action = actions[idx]["action"]
-        # if replay_delta_actions:
-        #     action = action + current_joint_positions
         robot.send_action(action)

         dt_s = time.perf_counter() - start_episode_t
@@ -231,6 +231,7 @@ def act_with_policy(
         cfg=cfg.policy,
         env_cfg=cfg.env,
     )
+    policy = policy.eval()
     assert isinstance(policy, nn.Module)

     obs, info = online_env.reset()
File diff suppressed because it is too large
@@ -42,7 +42,6 @@ class HILSerlRobotEnv(gym.Env):
         self,
         robot,
         use_delta_action_space: bool = True,
-        delta: float | None = None,
         display_cameras: bool = False,
     ):
         """
@@ -55,8 +54,6 @@ class HILSerlRobotEnv(gym.Env):
             robot: The robot interface object used to connect and interact with the physical robot.
             use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute
                                             joint positions are used.
-            delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between
-                                    0 and 1 when using a delta action space.
             display_cameras (bool): If True, the robot's camera feeds will be displayed during execution.
         """
         super().__init__()
@@ -74,7 +71,6 @@ class HILSerlRobotEnv(gym.Env):
         self.current_step = 0
         self.episode_data = None

-        self.delta = delta
        self.use_delta_action_space = use_delta_action_space
        self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
@@ -374,7 +370,7 @@ class RewardWrapper(gym.Wrapper):
         self.device = device

     def step(self, action):
-        observation, _, terminated, truncated, info = self.env.step(action)
+        observation, reward, terminated, truncated, info = self.env.step(action)
         images = [
             observation[key].to(self.device, non_blocking=self.device.type == "cuda")
             for key in observation
@@ -382,15 +378,17 @@ class RewardWrapper(gym.Wrapper):
         ]
         start_time = time.perf_counter()
         with torch.inference_mode():
-            reward = (
+            success = (
                 self.reward_classifier.predict_reward(images, threshold=0.8)
                 if self.reward_classifier is not None
                 else 0.0
             )
         info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)

-        if reward == 1.0:
+        if success == 1.0:
             terminated = True
+            reward = 1.0

         return observation, reward, terminated, truncated, info

     def reset(self, seed=None, options=None):
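This hunk stops the wrapper from overwriting the environment reward with the classifier output: the classifier now acts purely as a success detector that terminates the episode and grants a sparse bonus. A minimal sketch of the pattern, with a stand-in `classify_success` in place of the real reward classifier:

```python
import gymnasium as gym

class SuccessDetectionWrapper(gym.Wrapper):
    def __init__(self, env, classify_success):
        super().__init__(env)
        self.classify_success = classify_success

    def step(self, action):
        # Keep the underlying reward instead of discarding it.
        obs, reward, terminated, truncated, info = self.env.step(action)
        if self.classify_success(obs) == 1.0:
            terminated = True
            reward = 1.0  # sparse success bonus replaces the running reward
        return obs, reward, terminated, truncated, info

env = SuccessDetectionWrapper(gym.make("CartPole-v1"), classify_success=lambda obs: 0.0)
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
```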
@@ -720,11 +718,13 @@ class ResetWrapper(gym.Wrapper):
         env: HILSerlRobotEnv,
         reset_pose: np.ndarray | None = None,
         reset_time_s: float = 5,
+        open_gripper_on_reset: bool = False,
     ):
         super().__init__(env)
         self.reset_time_s = reset_time_s
         self.reset_pose = reset_pose
         self.robot = self.unwrapped.robot
+        self.open_gripper_on_reset = open_gripper_on_reset

     def reset(self, *, seed=None, options=None):
         if self.reset_pose is not None:
@@ -733,6 +733,14 @@ class ResetWrapper(gym.Wrapper):
                 reset_follower_position(self.robot, self.reset_pose)
                 busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
             log_say("Reset the environment done.", play_sounds=True)
+            if self.open_gripper_on_reset:
+                current_joint_pos = self.robot.follower_arms["main"].read("Present_Position")
+                current_joint_pos[-1] = MAX_GRIPPER_COMMAND
+                self.robot.send_action(torch.from_numpy(current_joint_pos))
+                busy_wait(0.1)
+                current_joint_pos[-1] = 0.0
+                self.robot.send_action(torch.from_numpy(current_joint_pos))
+                busy_wait(0.2)
         else:
             log_say(
                 f"Manually reset the environment for {self.reset_time_s} seconds.",
@@ -761,6 +769,75 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
         return observation


+class GripperPenaltyWrapper(gym.RewardWrapper):
+    def __init__(self, env, penalty: float = -0.1, gripper_penalty_in_reward: bool = True):
+        super().__init__(env)
+        self.penalty = penalty
+        self.gripper_penalty_in_reward = gripper_penalty_in_reward
+        self.last_gripper_state = None
+
+    def reward(self, reward, action):
+        gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND
+
+        action_normalized = action - 1.0  # action / MAX_GRIPPER_COMMAND
+
+        gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or (
+            gripper_state_normalized > 0.75 and action_normalized < -0.5
+        )
+
+        return reward + self.penalty * int(gripper_penalty_bool)
+
+    def step(self, action):
+        self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
+        if isinstance(action, tuple):
+            gripper_action = action[0][-1]
+        else:
+            gripper_action = action[-1]
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        gripper_penalty = self.reward(reward, gripper_action)
+
+        if self.gripper_penalty_in_reward:
+            reward += gripper_penalty
+        else:
+            info["gripper_penalty"] = gripper_penalty
+
+        return obs, reward, terminated, truncated, info
+
+    def reset(self, **kwargs):
+        self.last_gripper_state = None
+        obs, info = super().reset(**kwargs)
+        if self.gripper_penalty_in_reward:
+            info["gripper_penalty"] = 0.0
+        return obs, info
+
+
+class GripperActionWrapper(gym.ActionWrapper):
+    def __init__(self, env, quantization_threshold: float = 0.2):
+        super().__init__(env)
+        self.quantization_threshold = quantization_threshold
+
+    def action(self, action):
+        is_intervention = False
+        if isinstance(action, tuple):
+            action, is_intervention = action
+        gripper_command = action[-1]
+
+        # Gripper actions are between 0, 2
+        # we want to quantize them to -1, 0 or 1
+        gripper_command = gripper_command - 1.0
+
+        if self.quantization_threshold is not None:
+            # Quantize gripper command to -1, 0 or 1
+            gripper_command = (
+                np.sign(gripper_command) if abs(gripper_command) > self.quantization_threshold else 0.0
+            )
+        gripper_command = gripper_command * MAX_GRIPPER_COMMAND
+        gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
+        gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
+        action[-1] = gripper_action.item()
+        return action, is_intervention
+
+
 class EEActionWrapper(gym.ActionWrapper):
     def __init__(self, env, ee_action_space_params=None, use_gripper=False):
         super().__init__(env)
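A hedged sketch of the quantization rule GripperActionWrapper applies: a raw gripper command in [0, 2] is recentered to [-1, 1], snapped to {-1, 0, +1} by a dead-zone threshold, then integrated onto the current gripper position. The MAX_GRIPPER_COMMAND value below is a placeholder for illustration only.

```python
import numpy as np

MAX_GRIPPER_COMMAND = 30.0  # placeholder value; the real constant lives in the repo

def quantize_gripper(raw: float, state: float, threshold: float = 0.2) -> float:
    cmd = raw - 1.0  # [0, 2] -> [-1, 1]
    cmd = np.sign(cmd) if abs(cmd) > threshold else 0.0  # dead zone around "stay"
    # Relative command scaled to the full stroke, clipped to the valid range.
    return float(np.clip(state + cmd * MAX_GRIPPER_COMMAND, 0.0, MAX_GRIPPER_COMMAND))

print(quantize_gripper(1.9, state=0.0))   # strong open -> jumps toward MAX
print(quantize_gripper(1.1, state=15.0))  # inside dead zone -> holds at 15.0
```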
@@ -780,10 +857,12 @@ class EEActionWrapper(gym.ActionWrapper):
             ]
         )
         if self.use_gripper:
-            action_space_bounds = np.concatenate([action_space_bounds, [1.0]])
+            # gripper actions open at 2.0, and closed at 0.0
+            min_action_space_bounds = np.concatenate([-action_space_bounds, [0.0]])
+            max_action_space_bounds = np.concatenate([action_space_bounds, [2.0]])
         ee_action_space = gym.spaces.Box(
-            low=-action_space_bounds,
-            high=action_space_bounds,
+            low=min_action_space_bounds,
+            high=max_action_space_bounds,
             shape=(3 + int(self.use_gripper),),
             dtype=np.float32,
         )
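The reason the bounds become asymmetric: the XYZ deltas stay symmetric around zero, but the appended gripper dimension lives in [0, 2] (0 = closed, 2 = open), so low and high must be built separately. A small sketch with illustrative step sizes:

```python
import numpy as np
import gymnasium as gym

xyz_bounds = np.array([0.05, 0.05, 0.05], dtype=np.float32)  # illustrative EE step sizes
low = np.concatenate([-xyz_bounds, [0.0]]).astype(np.float32)
high = np.concatenate([xyz_bounds, [2.0]]).astype(np.float32)

space = gym.spaces.Box(low=low, high=high, shape=(4,), dtype=np.float32)
print(space.sample())  # last component is always in [0, 2]
```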
@@ -820,17 +899,7 @@ class EEActionWrapper(gym.ActionWrapper):
             fk_func=self.fk_function,
         )
         if self.use_gripper:
-            # Quantize gripper command to -1, 0 or 1
-            if gripper_command < -0.2:
-                gripper_command = -1.0
-            elif gripper_command > 0.2:
-                gripper_command = 1.0
-            else:
-                gripper_command = 0.0
-
-            gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
-            gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
-            target_joint_pos[-1] = gripper_action
+            target_joint_pos[-1] = gripper_command

         return target_joint_pos, is_intervention
@@ -951,11 +1020,11 @@ class GamepadControlWrapper(gym.Wrapper):
         if self.use_gripper:
             gripper_command = self.controller.gripper_command()
             if gripper_command == "open":
-                gamepad_action = np.concatenate([gamepad_action, [1.0]])
+                gamepad_action = np.concatenate([gamepad_action, [2.0]])
             elif gripper_command == "close":
                 gamepad_action = np.concatenate([gamepad_action, [-1.0]])
             else:
                 gamepad_action = np.concatenate([gamepad_action, [0.0]])
         else:
             gamepad_action = np.concatenate([gamepad_action, [1.0]])

         # Check episode ending buttons
         # We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None
@@ -1095,7 +1164,6 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     env = HILSerlRobotEnv(
         robot=robot,
         display_cameras=cfg.wrapper.display_cameras,
-        delta=cfg.wrapper.delta_action,
         use_delta_action_space=cfg.wrapper.use_relative_joint_positions
         and cfg.wrapper.ee_action_space_params is None,
     )
@@ -1118,12 +1186,22 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     # Add reward computation and control wrappers
     # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
     env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
+    if cfg.wrapper.use_gripper:
+        env = GripperActionWrapper(env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold)
+        if cfg.wrapper.gripper_penalty is not None:
+            env = GripperPenaltyWrapper(
+                env=env,
+                penalty=cfg.wrapper.gripper_penalty,
+                gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward,
+            )
+
     if cfg.wrapper.ee_action_space_params is not None:
         env = EEActionWrapper(
             env=env,
             ee_action_space_params=cfg.wrapper.ee_action_space_params,
             use_gripper=cfg.wrapper.use_gripper,
         )

     if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad:
         # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
         env = GamepadControlWrapper(
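Wrapper order matters here: gym wrappers nest like an onion, so GripperActionWrapper (quantization) sits inside GripperPenaltyWrapper, which sits inside EEActionWrapper. A tiny sketch of the composition semantics, with stand-in wrappers on a toy env:

```python
import gymnasium as gym

class A(gym.Wrapper):
    def step(self, action):
        print("A sees the action first")  # outermost wrapper runs first on actions
        return self.env.step(action)

class B(gym.Wrapper):
    def step(self, action):
        print("B sees the action second")
        return self.env.step(action)

env = A(B(gym.make("CartPole-v1")))
env.reset(seed=0)
env.step(env.action_space.sample())  # prints A then B
print(env.unwrapped)  # reaches the innermost env, as the diff's wrappers do via self.unwrapped
```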
@@ -1140,6 +1218,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
         env=env,
         reset_pose=cfg.wrapper.fixed_reset_joint_positions,
         reset_time_s=cfg.wrapper.reset_time_s,
+        open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset,
     )
     if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
         env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
@@ -1289,11 +1368,10 @@ def record_dataset(env, policy, cfg):
     dataset.push_to_hub()


-def replay_episode(env, repo_id, root=None, episode=0):
+def replay_episode(env, cfg):
     from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

-    local_files_only = root is not None
-    dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
+    dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode])
     env.reset()

     actions = dataset.hf_dataset.select_columns("action")
@@ -1301,7 +1379,7 @@ def replay_episode(env, repo_id, root=None, episode=0):
     for idx in range(dataset.num_frames):
         start_episode_t = time.perf_counter()

-        action = actions[idx]["action"][:4]
+        action = actions[idx]["action"]
         env.step((action, False))
         # env.step((action / env.unwrapped.delta, False))
@@ -1332,9 +1410,7 @@ def main(cfg: EnvConfig):
     if cfg.mode == "replay":
         replay_episode(
             env,
-            cfg.replay_repo_id,
-            root=cfg.dataset_root,
-            episode=cfg.replay_episode,
+            cfg=cfg,
         )
         exit()
@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+# !/usr/bin/env python

 # Copyright 2024 The HuggingFace Inc. team.
 # All rights reserved.
@@ -269,6 +269,7 @@ def add_actor_information_and_train(
     policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
     saving_checkpoint = cfg.save_checkpoint
     online_steps = cfg.policy.online_steps
+    async_prefetch = cfg.policy.async_prefetch

     # Initialize logging for multiprocessing
     if not use_threads(cfg):
@@ -326,6 +327,9 @@ def add_actor_information_and_train(
     if cfg.dataset is not None:
         dataset_repo_id = cfg.dataset.repo_id

+    # Initialize iterators
+    online_iterator = None
+    offline_iterator = None
     # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
     while True:
         # Exit the training loop if shutdown is requested
@@ -359,13 +363,26 @@ def add_actor_information_and_train(
         if len(replay_buffer) < online_step_before_learning:
             continue

+        if online_iterator is None:
+            logging.debug("[LEARNER] Initializing online replay buffer iterator")
+            online_iterator = replay_buffer.get_iterator(
+                batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
+            )
+
+        if offline_replay_buffer is not None and offline_iterator is None:
+            logging.debug("[LEARNER] Initializing offline replay buffer iterator")
+            offline_iterator = offline_replay_buffer.get_iterator(
+                batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
+            )
+
         logging.debug("[LEARNER] Starting optimization loop")
         time_for_one_optimization_step = time.time()
         for _ in range(utd_ratio - 1):
-            batch = replay_buffer.sample(batch_size=batch_size)
+            # Sample from the iterators
+            batch = next(online_iterator)

             if dataset_repo_id is not None:
-                batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
+                batch_offline = next(offline_iterator)
                 batch = concatenate_batch_transitions(
                     left_batch_transitions=batch, right_batch_transition=batch_offline
                 )
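The iterator change lets sampling overlap with optimization. A hedged sketch of what an async-prefetching get_iterator() can look like: a background thread keeps `queue_size` sampled batches ready so the learner never blocks on sampling. `sample_fn` stands in for replay_buffer.sample; the real lerobot implementation may differ.

```python
import queue
import threading
from typing import Iterator

def prefetch_iterator(sample_fn, batch_size: int, queue_size: int = 2) -> Iterator:
    q: queue.Queue = queue.Queue(maxsize=queue_size)

    def producer():
        while True:  # daemon thread: dies with the process
            q.put(sample_fn(batch_size))

    threading.Thread(target=producer, daemon=True).start()
    while True:
        yield q.get()  # consumer blocks only if the producer falls behind

# Usage with a toy sampler:
it = prefetch_iterator(lambda bs: list(range(bs)), batch_size=4)
print(next(it), next(it))
```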
@@ -390,26 +407,40 @@ def add_actor_information_and_train(
                 "done": done,
                 "observation_feature": observation_features,
                 "next_observation_feature": next_observation_features,
+                "complementary_info": batch["complementary_info"],
             }

-            # Use the forward method for critic loss
-            loss_critic = policy.forward(forward_batch, model="critic")
+            # Use the forward method for critic loss (includes both main critic and grasp critic)
+            critic_output = policy.forward(forward_batch, model="critic")

+            # Main critic optimization
+            loss_critic = critic_output["loss_critic"]
             optimizers["critic"].zero_grad()
             loss_critic.backward()

             # clip gradients
             critic_grad_norm = torch.nn.utils.clip_grad_norm_(
                 parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
             )

             optimizers["critic"].step()

+            # Grasp critic optimization (if available)
+            if policy.config.num_discrete_actions is not None:
+                discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
+                loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
+                optimizers["grasp_critic"].zero_grad()
+                loss_grasp_critic.backward()
+                grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+                    parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+                )
+                optimizers["grasp_critic"].step()
+
             # Update target networks
             policy.update_target_networks()

-        batch = replay_buffer.sample(batch_size=batch_size)
+        # Sample for the last update in the UTD ratio
+        batch = next(online_iterator)

         if dataset_repo_id is not None:
-            batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
+            batch_offline = next(offline_iterator)
             batch = concatenate_batch_transitions(
                 left_batch_transitions=batch, right_batch_transition=batch_offline
             )
@@ -437,63 +468,80 @@ def add_actor_information_and_train(
             "next_observation_feature": next_observation_features,
         }

-        # Use the forward method for critic loss
-        loss_critic = policy.forward(forward_batch, model="critic")
+        # Use the forward method for critic loss (includes both main critic and grasp critic)
+        critic_output = policy.forward(forward_batch, model="critic")

+        # Main critic optimization
+        loss_critic = critic_output["loss_critic"]
         optimizers["critic"].zero_grad()
         loss_critic.backward()

         # clip gradients
         critic_grad_norm = torch.nn.utils.clip_grad_norm_(
             parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
         ).item()

         optimizers["critic"].step()

-        training_infos = {}
-        training_infos["loss_critic"] = loss_critic.item()
-        training_infos["critic_grad_norm"] = critic_grad_norm
+        # Initialize training info dictionary
+        training_infos = {
+            "loss_critic": loss_critic.item(),
+            "critic_grad_norm": critic_grad_norm,
+        }

+        # Grasp critic optimization (if available)
+        if policy.config.num_discrete_actions is not None:
+            discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
+            loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
+            optimizers["grasp_critic"].zero_grad()
+            loss_grasp_critic.backward()
+            grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+                parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+            ).item()
+            optimizers["grasp_critic"].step()
+
+            # Add grasp critic info to training info
+            training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
+            training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm

+        # Actor and temperature optimization (at specified frequency)
         if optimization_step % policy_update_freq == 0:
             for _ in range(policy_update_freq):
-                # Use the forward method for actor loss
-                loss_actor = policy.forward(forward_batch, model="actor")
-
+                # Actor optimization
+                actor_output = policy.forward(forward_batch, model="actor")
+                loss_actor = actor_output["loss_actor"]
                 optimizers["actor"].zero_grad()
                 loss_actor.backward()

                 # clip gradients
                 actor_grad_norm = torch.nn.utils.clip_grad_norm_(
-                    parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
+                    parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
                 ).item()

                 optimizers["actor"].step()

+                # Add actor info to training info
                 training_infos["loss_actor"] = loss_actor.item()
                 training_infos["actor_grad_norm"] = actor_grad_norm

-                # Temperature optimization using forward method
-                loss_temperature = policy.forward(forward_batch, model="temperature")
+                # Temperature optimization
+                temperature_output = policy.forward(forward_batch, model="temperature")
+                loss_temperature = temperature_output["loss_temperature"]
                 optimizers["temperature"].zero_grad()
                 loss_temperature.backward()

                 # clip gradients
                 temp_grad_norm = torch.nn.utils.clip_grad_norm_(
                     parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
                 ).item()

                 optimizers["temperature"].step()

+                # Add temperature info to training info
                 training_infos["loss_temperature"] = loss_temperature.item()
                 training_infos["temperature_grad_norm"] = temp_grad_norm
                 training_infos["temperature"] = policy.temperature

                 # Update temperature
                 policy.update_temperature()

-        # Check if it's time to push updated policy to actors
+        # Push policy to actors if needed
         if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
             push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
             last_time_policy_pushed = time.time()

         # Update target networks
         policy.update_target_networks()

         # Log training metrics at specified intervals
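The recurring change in these hunks is an API shift: policy.forward(batch, model=...) used to return a bare loss tensor and now returns a dict of named outputs, so one call can carry the main critic loss, the grasp critic loss, and any diagnostics. A hedged toy sketch of the pattern (not the actual SACPolicy internals):

```python
import torch
import torch.nn as nn

class ToyPolicy(nn.Module):
    def __init__(self):
        super().__init__()
        self.critic = nn.Linear(4, 1)

    def forward(self, batch: dict, model: str = "critic") -> dict:
        if model == "critic":
            pred = self.critic(batch["obs"]).squeeze(-1)
            loss = nn.functional.mse_loss(pred, batch["target"])
            # Extra entries can ride along without changing the call site.
            return {"loss_critic": loss, "td_error_mean": loss.detach()}
        raise ValueError(model)

policy = ToyPolicy()
out = policy({"obs": torch.randn(8, 4), "target": torch.zeros(8)})
out["loss_critic"].backward()
```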
@@ -697,7 +745,7 @@ def save_training_checkpoint(
     logging.info("Resume training")


-def make_optimizers_and_scheduler(cfg, policy: nn.Module):
+def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
     """
     Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
@@ -724,11 +772,19 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
     """
     optimizer_actor = torch.optim.Adam(
-        params=policy.actor.parameters_to_optimize,
+        # NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
+        params=[
+            p
+            for n, p in policy.actor.named_parameters()
+            if not policy.config.shared_encoder or not n.startswith("encoder")
+        ],
         lr=cfg.policy.actor_lr,
     )
     optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)

+    if cfg.policy.num_discrete_actions is not None:
+        optimizer_grasp_critic = torch.optim.Adam(
+            params=policy.grasp_critic.parameters(), lr=cfg.policy.critic_lr
+        )
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
     optimizers = {
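A hedged sketch of the parameter filtering added here: with a shared encoder, the actor's optimizer must skip encoder weights so only critic gradients update them. Filtering by the "encoder" name prefix assumes the submodule is registered under that name, as the diff's named_parameters() check implies.

```python
import torch
import torch.nn as nn

class Actor(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 16)  # shared with the critic
        self.head = nn.Linear(16, 4)     # actor-only

actor = Actor()
shared_encoder = True
actor_params = [
    p for n, p in actor.named_parameters()
    if not shared_encoder or not n.startswith("encoder")
]
optimizer = torch.optim.Adam(actor_params, lr=3e-4)
print(len(actor_params))  # 2: head.weight and head.bias only
```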
@@ -736,6 +792,8 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
         "critic": optimizer_critic,
         "temperature": optimizer_temperature,
     }
+    if cfg.policy.num_discrete_actions is not None:
+        optimizers["grasp_critic"] = optimizer_grasp_critic
     return optimizers, lr_scheduler
@@ -936,7 +994,6 @@ def initialize_offline_replay_buffer(
         device=device,
         state_keys=cfg.policy.input_features.keys(),
         action_mask=active_action_dims,
-        action_delta=cfg.env.wrapper.delta_action,
         storage_device=storage_device,
         optimize_memory=True,
         capacity=cfg.policy.offline_buffer_capacity,
@@ -970,11 +1027,9 @@ def get_observation_features(
         return None, None

     with torch.no_grad():
-        observation_features = (
-            policy.actor.encoder(observations) if policy.actor.encoder is not None else None
-        )
-        next_observation_features = (
-            policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None
-        )
+        observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True)
+        next_observation_features = policy.actor.encoder.get_cached_image_features(
+            next_observations, normalize=True
+        )

     return observation_features, next_observation_features
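The motivation for precomputing features: with a frozen, shared image encoder, each observation can be embedded once under no_grad and the cached tensor reused by actor, critic, and grasp critic in the same update, saving repeated forward passes. get_cached_image_features is the diff's method name; the toy encoder below only illustrates the idea.

```python
import torch
import torch.nn as nn

encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 32)).eval()

def get_features(obs: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():  # encoder is frozen: no graph needed
        return encoder(obs)

obs = torch.randn(4, 3, 8, 8)
feats = get_features(obs)
print(feats.shape)  # torch.Size([4, 32]); pass `feats` to each head instead of `obs`
```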
@@ -1037,6 +1092,44 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
     parameters_queue.put(state_bytes)


+def check_weight_gradients(module: nn.Module) -> dict[str, bool]:
+    """
+    Checks whether each parameter in the module has a gradient.
+
+    Args:
+        module (nn.Module): A PyTorch module whose parameters will be inspected.
+
+    Returns:
+        dict[str, bool]: A dictionary where each key is the parameter name and the value is
+                         True if the parameter has an associated gradient (i.e. .grad is not None),
+                         otherwise False.
+    """
+    grad_status = {}
+    for name, param in module.named_parameters():
+        grad_status[name] = param.grad is not None
+    return grad_status
+
+
+def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]:
+    """
+    Returns a dictionary of parameters (from actor) that also exist in the grad_status dictionary.
+
+    Args:
+        actor (nn.Module): The actor model.
+        grad_status (dict[str, bool]): A dictionary where keys are parameter names and values indicate
+                                       whether each parameter has a gradient.
+
+    Returns:
+        dict[str, bool]: A dictionary containing only the overlapping parameter names and their gradient status.
+    """
+    # Get actor parameter names as a set.
+    model_param_names = {name for name, _ in model.named_parameters()}
+
+    # Intersect parameter names between actor and grad_status.
+    overlapping = {name: grad_status[name] for name in grad_status if name in model_param_names}
+    return overlapping
+
+
 def process_interaction_message(
     message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
 ):
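A hedged usage sketch for the two debug helpers added above: after a backward pass, a gradient-status map shows which parameters received gradients, and restricting it to a submodule verifies, for example, that a shared encoder got no gradient from the actor loss. The snippet re-implements the checks inline since the helpers live inside the learner script:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))
loss = model(torch.randn(2, 4)).sum()
loss.backward()

grad_status = {n: p.grad is not None for n, p in model.named_parameters()}
head_names = {f"1.{n}" for n, _ in model[1].named_parameters()}
head_report = {n: ok for n, ok in grad_status.items() if n in head_names}
print(grad_status)  # every parameter: True after backward
print(head_report)  # only the final layer's parameters
```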
@@ -1,5 +1,3 @@
-import logging
-import time
 from typing import Any

 import einops
@@ -10,7 +8,6 @@ from mani_skill.utils.wrappers.record import RecordEpisode
 from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv

 from lerobot.common.envs.configs import ManiskillEnvConfig
-from lerobot.configs import parser


 def preprocess_maniskill_observation(
@@ -153,6 +150,27 @@ class TimeLimitWrapper(gym.Wrapper):
         return super().reset(seed=seed, options=options)


+class ManiskillMockGripperWrapper(gym.Wrapper):
+    def __init__(self, env, nb_discrete_actions: int = 3):
+        super().__init__(env)
+        new_shape = env.action_space[0].shape[0] + 1
+        new_low = np.concatenate([env.action_space[0].low, [0]])
+        new_high = np.concatenate([env.action_space[0].high, [nb_discrete_actions - 1]])
+        action_space_agent = gym.spaces.Box(low=new_low, high=new_high, shape=(new_shape,))
+        self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1]))
+
+    def step(self, action):
+        if isinstance(action, tuple):
+            action_agent, telop_action = action
+        else:
+            telop_action = 0
+            action_agent = action
+        real_action = action_agent[:-1]
+        final_action = (real_action, telop_action)
+        obs, reward, terminated, truncated, info = self.env.step(final_action)
+        return obs, reward, terminated, truncated, info
+
+
 def make_maniskill(
     cfg: ManiskillEnvConfig,
     n_envs: int | None = None,
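A hedged sketch of what the mock-gripper wrapper does to the action space: append one extra dimension to a continuous Box whose bounds [0, N-1] encode N discrete gripper choices, so a policy with a discrete head can drive a gripper even when the simulator exposes none. Values below are illustrative.

```python
import numpy as np
import gymnasium as gym

arm_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(6,), dtype=np.float32)
nb_discrete_actions = 3

new_low = np.concatenate([arm_space.low, [0]]).astype(np.float32)
new_high = np.concatenate([arm_space.high, [nb_discrete_actions - 1]]).astype(np.float32)
extended = gym.spaces.Box(low=new_low, high=new_high, shape=(7,), dtype=np.float32)

action = extended.sample()
arm_action, gripper_choice = action[:-1], int(round(float(action[-1])))
print(extended.shape, gripper_choice in range(nb_discrete_actions))
```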
@@ -197,40 +215,42 @@ def make_maniskill(
     env = ManiSkillCompat(env)
     env = ManiSkillActionWrapper(env)
     env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03)  # Scale actions for better control
+    if cfg.mock_gripper:
+        env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3)

     return env


-@parser.wrap()
-def main(cfg: ManiskillEnvConfig):
-    """Main function to run the ManiSkill environment."""
-    # Create the ManiSkill environment
-    env = make_maniskill(cfg, n_envs=1)
+# @parser.wrap()
+# def main(cfg: TrainPipelineConfig):
+#     """Main function to run the ManiSkill environment."""
+#     # Create the ManiSkill environment
+#     env = make_maniskill(cfg.env, n_envs=1)

-    # Reset the environment
-    obs, info = env.reset()
+#     # Reset the environment
+#     obs, info = env.reset()

-    # Run a simple interaction loop
-    sum_reward = 0
-    for i in range(100):
-        # Sample a random action
-        action = env.action_space.sample()
+#     # Run a simple interaction loop
+#     sum_reward = 0
+#     for i in range(100):
+#         # Sample a random action
+#         action = env.action_space.sample()

-        # Step the environment
-        start_time = time.perf_counter()
-        obs, reward, terminated, truncated, info = env.step(action)
-        step_time = time.perf_counter() - start_time
-        sum_reward += reward
-        # Log information
+#         # Step the environment
+#         start_time = time.perf_counter()
+#         obs, reward, terminated, truncated, info = env.step(action)
+#         step_time = time.perf_counter() - start_time
+#         sum_reward += reward
+#         # Log information

-        # Reset if episode terminated
-        if terminated or truncated:
-            logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
-            sum_reward = 0
-            obs, info = env.reset()
+#         # Reset if episode terminated
+#         if terminated or truncated:
+#             logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
+#             sum_reward = 0
+#             obs, info = env.reset()

-    # Close the environment
-    env.close()
+#     # Close the environment
+#     env.close()

 # if __name__ == "__main__":