[HIL-SERL] Review feedback modifications (#1112)

This commit is contained in:
Adil Zouitine
2025-05-15 15:24:41 +02:00
committed by AdilZouitine
parent 5902f8fcc7
commit a5f758d7c6
17 changed files with 504 additions and 180 deletions

View File

@@ -13,13 +13,67 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Actor server runner for distributed HILSerl robot policy training.
This script implements the actor component of the distributed HILSerl architecture.
It executes the policy in the robot environment, collects experience,
and sends transitions to the learner server for policy updates.
Examples of usage:
- Start an actor server for real robot training with human-in-the-loop intervention:
```bash
python lerobot/scripts/server/actor_server.py --config_path lerobot/configs/train_config_hilserl_so100.json
```
- Run with a specific robot type for a pick and place task:
```bash
python lerobot/scripts/server/actor_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--robot.type=so100 \
--task=pick_and_place
```
- Set a custom workspace bound for the robot's end-effector:
```bash
python lerobot/scripts/server/actor_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--env.ee_action_space_params.bounds.max="[0.24, 0.20, 0.10]" \
--env.ee_action_space_params.bounds.min="[0.16, -0.08, 0.03]"
```
- Run with specific camera crop parameters:
```bash
python lerobot/scripts/server/actor_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--env.crop_params_dict="{'observation.images.side': [180, 207, 180, 200], 'observation.images.front': [180, 250, 120, 150]}"
```
**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner
server is started before launching the actor.
**NOTE**: Human intervention is key to HILSerl training. Press the upper right trigger button on the
gamepad to take control of the robot during training. Initially intervene frequently, then gradually
reduce interventions as the policy improves.
**WORKFLOW**:
1. Determine robot workspace bounds using `find_joint_limits.py`
2. Record demonstrations with `gym_manipulator.py` in record mode
3. Process the dataset and determine camera crops with `crop_dataset_roi.py`
4. Start the learner server with the training configuration
5. Start this actor server with the same configuration
6. Use human interventions to guide policy learning
For more details on the complete HILSerl training workflow, see:
https://github.com/michel-aractingi/lerobot-hilserl-guide
"""
import logging
import os
import time
from functools import lru_cache
from queue import Empty
from statistics import mean, quantiles
import grpc
import torch
@@ -65,10 +119,12 @@ ACTOR_SHUTDOWN_TIMEOUT = 30
@parser.wrap()
def actor_cli(cfg: TrainPipelineConfig):
cfg.validate()
display_pid = False
if not use_threads(cfg):
import torch.multiprocessing as mp
mp.set_start_method("spawn")
display_pid = True
# Create logs directory to ensure it exists
log_dir = os.path.join(cfg.output_dir, "logs")
@@ -76,7 +132,7 @@ def actor_cli(cfg: TrainPipelineConfig):
log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file)
init_logging(log_file=log_file, display_pid=display_pid)
logging.info(f"Actor logging initialized, writing to {log_file}")
shutdown_event = setup_process_handlers(use_threads(cfg))
@@ -193,7 +249,7 @@ def act_with_policy(
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log")
init_logging(log_file=log_file)
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor policy process logging initialized")
logging.info("make_env online")
@@ -223,12 +279,13 @@ def act_with_policy(
# NOTE: For the moment we will solely handle the case of a single environment
sum_reward_episode = 0
list_transition_to_send_to_learner = []
list_policy_time = []
episode_intervention = False
# Add counters for intervention rate calculation
episode_intervention_steps = 0
episode_total_steps = 0
policy_timer = TimerManager("Policy inference", log=False)
for interaction_step in range(cfg.policy.online_steps):
start_time = time.perf_counter()
if shutdown_event.is_set():
@@ -237,13 +294,9 @@ def act_with_policy(
if interaction_step >= cfg.policy.online_step_before_learning:
# Time policy inference and check if it meets FPS requirement
with TimerManager(
elapsed_time_list=list_policy_time,
label="Policy inference time",
log=False,
) as timer: # noqa: F841
with policy_timer:
action = policy.select_action(batch=obs)
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
policy_fps = policy_timer.fps_last
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
@@ -291,8 +344,8 @@ def act_with_policy(
)
list_transition_to_send_to_learner = []
stats = get_frequency_stats(list_policy_time)
list_policy_time.clear()
stats = get_frequency_stats(policy_timer)
policy_timer.reset()
# Calculate intervention rate
intervention_rate = 0.0
@@ -429,7 +482,7 @@ def receive_policy(
log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file)
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor receive policy process logging initialized")
# Setup process handlers to handle shutdown signal
@@ -484,7 +537,7 @@ def send_transitions(
log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file)
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor transitions process logging initialized")
# Setup process handlers to handle shutdown signal
@@ -533,7 +586,7 @@ def send_interactions(
log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file)
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor interactions process logging initialized")
# Setup process handlers to handle shutdown signal
@@ -632,25 +685,24 @@ def push_transitions_to_transport_queue(transitions: list, transitions_queue):
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
def get_frequency_stats(timer: TimerManager) -> dict[str, float]:
"""Get the frequency statistics of the policy.
Args:
list_policy_time (list[float]): The list of policy times.
timer (TimerManager): The timer with collected metrics.
Returns:
dict[str, float]: The frequency statistics of the policy.
"""
stats = {}
list_policy_fps = [1.0 / t for t in list_policy_time]
if len(list_policy_fps) > 1:
policy_fps = mean(list_policy_fps)
quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
if timer.count > 1:
avg_fps = timer.fps_avg
p90_fps = timer.fps_percentile(90)
logging.debug(f"[ACTOR] Average policy frame rate: {avg_fps}")
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {p90_fps}")
stats = {
"Policy frequency [Hz]": policy_fps,
"Policy frequency 90th-p [Hz]": quantiles_90,
"Policy frequency [Hz]": avg_fps,
"Policy frequency 90th-p [Hz]": p90_fps,
}
return stats

View File

@@ -203,6 +203,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
if key in new_dataset.meta.info["features"]:
new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
# TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset
prev_episode_index = 0
for frame_idx in tqdm(range(len(original_dataset))):
frame = original_dataset[frame_idx]

View File

@@ -23,10 +23,9 @@ import numpy as np
import torch
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import init_logging
from lerobot.scripts.server.kinematics import RobotKinematics
logging.basicConfig(level=logging.INFO)
class InputController:
"""Base class for input controllers that generate motion deltas."""
@@ -726,6 +725,8 @@ if __name__ == "__main__":
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.scripts.server.gym_manipulator import make_robot_env
init_logging()
parser = argparse.ArgumentParser(description="Test end-effector control")
parser.add_argument(
"--mode",

View File

@@ -1588,19 +1588,20 @@ class GamepadControlWrapper(gym.Wrapper):
input_threshold: Minimum movement delta to consider as active input.
"""
super().__init__(env)
from lerobot.scripts.server.end_effector_control_utils import (
GamepadController,
GamepadControllerHID,
)
# use HidApi for macos
if sys.platform == "darwin":
# NOTE: On macOS, pygame doesn't reliably detect input from some controllers so we fall back to hidapi
from lerobot.scripts.server.end_effector_control_utils import GamepadControllerHID
self.controller = GamepadControllerHID(
x_step_size=x_step_size,
y_step_size=y_step_size,
z_step_size=z_step_size,
)
else:
from lerobot.scripts.server.end_effector_control_utils import GamepadController
self.controller = GamepadController(
x_step_size=x_step_size,
y_step_size=y_step_size,
@@ -1748,6 +1749,8 @@ class GymHilDeviceWrapper(gym.Wrapper):
for k in obs:
obs[k] = obs[k].to(self.device)
if "action_intervention" in info:
# NOTE: This is a hack to ensure the action intervention is a float32 tensor, which is supported on MPS devices
info["action_intervention"] = info["action_intervention"].astype(np.float32)
info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
return obs, reward, terminated, truncated, info
@@ -1756,6 +1759,8 @@ class GymHilDeviceWrapper(gym.Wrapper):
for k in obs:
obs[k] = obs[k].to(self.device)
if "action_intervention" in info:
# NOTE: This is a hack to ensure the action intervention is a float32 tensor, which is supported on MPS devices
info["action_intervention"] = info["action_intervention"].astype(np.float32)
info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
return obs, info

View File

@@ -14,6 +14,66 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Learner server runner for distributed HILSerl robot policy training.
This script implements the learner component of the distributed HILSerl architecture.
It initializes the policy network, maintains replay buffers, and updates
the policy based on transitions received from the actor server.
Examples of usage:
- Start a learner server for training:
```bash
python lerobot/scripts/server/learner_server.py --config_path lerobot/configs/train_config_hilserl_so100.json
```
- Run with specific SAC hyperparameters:
```bash
python lerobot/scripts/server/learner_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--learner.sac.alpha=0.1 \
--learner.sac.gamma=0.99
```
- Run with a specific dataset and wandb logging:
```bash
python lerobot/scripts/server/learner_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--dataset.repo_id=username/pick_lift_cube \
--wandb.enable=true \
--wandb.project=hilserl_training
```
- Run with a pretrained policy for fine-tuning:
```bash
python lerobot/scripts/server/learner_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--pretrained_policy_name_or_path=outputs/previous_training/checkpoints/080000/pretrained_model
```
- Run with a reward classifier model:
```bash
python lerobot/scripts/server/learner_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--reward_classifier_pretrained_path=outputs/reward_model/best_model
```
**NOTE**: Start the learner server before launching the actor server. The learner opens a gRPC server
to communicate with actors.
**NOTE**: Training progress can be monitored through Weights & Biases if wandb.enable is set to true
in your configuration.
**WORKFLOW**:
1. Create training configuration with proper policy, dataset, and environment settings
2. Start this learner server with the configuration
3. Start an actor server with the same configuration
4. Monitor training progress through wandb dashboard
For more details on the complete HILSerl training workflow, see:
https://github.com/michel-aractingi/lerobot-hilserl-guide
"""
import logging
import os
@@ -73,7 +133,6 @@ from lerobot.scripts.server.utils import (
LOG_PREFIX = "[LEARNER]"
logging.basicConfig(level=logging.INFO)
#################################################
# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS #
@@ -113,13 +172,17 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
if job_name is None:
raise ValueError("Job name must be specified either in config or as a parameter")
display_pid = False
if not use_threads(cfg):
display_pid = True
# Create logs directory to ensure it exists
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"learner_{job_name}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file)
init_logging(log_file=log_file, display_pid=display_pid)
logging.info(f"Learner logging initialized, writing to {log_file}")
logging.info(pformat(cfg.to_dict()))
@@ -275,7 +338,7 @@ def add_actor_information_and_train(
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"learner_train_process_{os.getpid()}.log")
init_logging(log_file=log_file)
init_logging(log_file=log_file, display_pid=True)
logging.info("Initialized logging for actor information and training process")
logging.info("Initializing policy")
@@ -604,7 +667,7 @@ def start_learner_server(
log_file = os.path.join(log_dir, f"learner_server_process_{os.getpid()}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file)
init_logging(log_file=log_file, display_pid=True)
logging.info("Learner server process logging initialized")
# Setup process handlers to handle shutdown signal