forked from tangger/lerobot
[HIL-SERL] Review feedback modifications (#1112)
This commit is contained in:
committed by
AdilZouitine
parent
5902f8fcc7
commit
a5f758d7c6
@@ -13,13 +13,67 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Actor server runner for distributed HILSerl robot policy training.
|
||||
|
||||
This script implements the actor component of the distributed HILSerl architecture.
|
||||
It executes the policy in the robot environment, collects experience,
|
||||
and sends transitions to the learner server for policy updates.
|
||||
|
||||
Examples of usage:
|
||||
|
||||
- Start an actor server for real robot training with human-in-the-loop intervention:
|
||||
```bash
|
||||
python lerobot/scripts/server/actor_server.py --config_path lerobot/configs/train_config_hilserl_so100.json
|
||||
```
|
||||
|
||||
- Run with a specific robot type for a pick and place task:
|
||||
```bash
|
||||
python lerobot/scripts/server/actor_server.py \
|
||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||
--robot.type=so100 \
|
||||
--task=pick_and_place
|
||||
```
|
||||
|
||||
- Set a custom workspace bound for the robot's end-effector:
|
||||
```bash
|
||||
python lerobot/scripts/server/actor_server.py \
|
||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||
--env.ee_action_space_params.bounds.max="[0.24, 0.20, 0.10]" \
|
||||
--env.ee_action_space_params.bounds.min="[0.16, -0.08, 0.03]"
|
||||
```
|
||||
|
||||
- Run with specific camera crop parameters:
|
||||
```bash
|
||||
python lerobot/scripts/server/actor_server.py \
|
||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||
--env.crop_params_dict="{'observation.images.side': [180, 207, 180, 200], 'observation.images.front': [180, 250, 120, 150]}"
|
||||
```
|
||||
|
||||
**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner
|
||||
server is started before launching the actor.
|
||||
|
||||
**NOTE**: Human intervention is key to HILSerl training. Press the upper right trigger button on the
|
||||
gamepad to take control of the robot during training. Initially intervene frequently, then gradually
|
||||
reduce interventions as the policy improves.
|
||||
|
||||
**WORKFLOW**:
|
||||
1. Determine robot workspace bounds using `find_joint_limits.py`
|
||||
2. Record demonstrations with `gym_manipulator.py` in record mode
|
||||
3. Process the dataset and determine camera crops with `crop_dataset_roi.py`
|
||||
4. Start the learner server with the training configuration
|
||||
5. Start this actor server with the same configuration
|
||||
6. Use human interventions to guide policy learning
|
||||
|
||||
For more details on the complete HILSerl training workflow, see:
|
||||
https://github.com/michel-aractingi/lerobot-hilserl-guide
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from queue import Empty
|
||||
from statistics import mean, quantiles
|
||||
|
||||
import grpc
|
||||
import torch
|
||||
@@ -65,10 +119,12 @@ ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||
@parser.wrap()
|
||||
def actor_cli(cfg: TrainPipelineConfig):
|
||||
cfg.validate()
|
||||
display_pid = False
|
||||
if not use_threads(cfg):
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
display_pid = True
|
||||
|
||||
# Create logs directory to ensure it exists
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
@@ -76,7 +132,7 @@ def actor_cli(cfg: TrainPipelineConfig):
|
||||
log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file)
|
||||
init_logging(log_file=log_file, display_pid=display_pid)
|
||||
logging.info(f"Actor logging initialized, writing to {log_file}")
|
||||
|
||||
shutdown_event = setup_process_handlers(use_threads(cfg))
|
||||
@@ -193,7 +249,7 @@ def act_with_policy(
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log")
|
||||
init_logging(log_file=log_file)
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor policy process logging initialized")
|
||||
|
||||
logging.info("make_env online")
|
||||
@@ -223,12 +279,13 @@ def act_with_policy(
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
list_transition_to_send_to_learner = []
|
||||
list_policy_time = []
|
||||
episode_intervention = False
|
||||
# Add counters for intervention rate calculation
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
|
||||
policy_timer = TimerManager("Policy inference", log=False)
|
||||
|
||||
for interaction_step in range(cfg.policy.online_steps):
|
||||
start_time = time.perf_counter()
|
||||
if shutdown_event.is_set():
|
||||
@@ -237,13 +294,9 @@ def act_with_policy(
|
||||
|
||||
if interaction_step >= cfg.policy.online_step_before_learning:
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with TimerManager(
|
||||
elapsed_time_list=list_policy_time,
|
||||
label="Policy inference time",
|
||||
log=False,
|
||||
) as timer: # noqa: F841
|
||||
with policy_timer:
|
||||
action = policy.select_action(batch=obs)
|
||||
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
|
||||
@@ -291,8 +344,8 @@ def act_with_policy(
|
||||
)
|
||||
list_transition_to_send_to_learner = []
|
||||
|
||||
stats = get_frequency_stats(list_policy_time)
|
||||
list_policy_time.clear()
|
||||
stats = get_frequency_stats(policy_timer)
|
||||
policy_timer.reset()
|
||||
|
||||
# Calculate intervention rate
|
||||
intervention_rate = 0.0
|
||||
@@ -429,7 +482,7 @@ def receive_policy(
|
||||
log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file)
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor receive policy process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
@@ -484,7 +537,7 @@ def send_transitions(
|
||||
log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file)
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor transitions process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
@@ -533,7 +586,7 @@ def send_interactions(
|
||||
log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file)
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor interactions process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
@@ -632,25 +685,24 @@ def push_transitions_to_transport_queue(transitions: list, transitions_queue):
|
||||
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
|
||||
|
||||
|
||||
def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
|
||||
def get_frequency_stats(timer: TimerManager) -> dict[str, float]:
|
||||
"""Get the frequency statistics of the policy.
|
||||
|
||||
Args:
|
||||
list_policy_time (list[float]): The list of policy times.
|
||||
timer (TimerManager): The timer with collected metrics.
|
||||
|
||||
Returns:
|
||||
dict[str, float]: The frequency statistics of the policy.
|
||||
"""
|
||||
stats = {}
|
||||
list_policy_fps = [1.0 / t for t in list_policy_time]
|
||||
if len(list_policy_fps) > 1:
|
||||
policy_fps = mean(list_policy_fps)
|
||||
quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
|
||||
logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
|
||||
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
|
||||
if timer.count > 1:
|
||||
avg_fps = timer.fps_avg
|
||||
p90_fps = timer.fps_percentile(90)
|
||||
logging.debug(f"[ACTOR] Average policy frame rate: {avg_fps}")
|
||||
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {p90_fps}")
|
||||
stats = {
|
||||
"Policy frequency [Hz]": policy_fps,
|
||||
"Policy frequency 90th-p [Hz]": quantiles_90,
|
||||
"Policy frequency [Hz]": avg_fps,
|
||||
"Policy frequency 90th-p [Hz]": p90_fps,
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
@@ -203,6 +203,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
if key in new_dataset.meta.info["features"]:
|
||||
new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
|
||||
|
||||
# TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset
|
||||
prev_episode_index = 0
|
||||
for frame_idx in tqdm(range(len(original_dataset))):
|
||||
frame = original_dataset[frame_idx]
|
||||
|
||||
@@ -23,10 +23,9 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class InputController:
|
||||
"""Base class for input controllers that generate motion deltas."""
|
||||
@@ -726,6 +725,8 @@ if __name__ == "__main__":
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
||||
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
||||
|
||||
init_logging()
|
||||
|
||||
parser = argparse.ArgumentParser(description="Test end-effector control")
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
|
||||
@@ -1588,19 +1588,20 @@ class GamepadControlWrapper(gym.Wrapper):
|
||||
input_threshold: Minimum movement delta to consider as active input.
|
||||
"""
|
||||
super().__init__(env)
|
||||
from lerobot.scripts.server.end_effector_control_utils import (
|
||||
GamepadController,
|
||||
GamepadControllerHID,
|
||||
)
|
||||
|
||||
# use HidApi for macos
|
||||
if sys.platform == "darwin":
|
||||
# NOTE: On macOS, pygame doesn’t reliably detect input from some controllers so we fall back to hidapi
|
||||
from lerobot.scripts.server.end_effector_control_utils import GamepadControllerHID
|
||||
|
||||
self.controller = GamepadControllerHID(
|
||||
x_step_size=x_step_size,
|
||||
y_step_size=y_step_size,
|
||||
z_step_size=z_step_size,
|
||||
)
|
||||
else:
|
||||
from lerobot.scripts.server.end_effector_control_utils import GamepadController
|
||||
|
||||
self.controller = GamepadController(
|
||||
x_step_size=x_step_size,
|
||||
y_step_size=y_step_size,
|
||||
@@ -1748,6 +1749,8 @@ class GymHilDeviceWrapper(gym.Wrapper):
|
||||
for k in obs:
|
||||
obs[k] = obs[k].to(self.device)
|
||||
if "action_intervention" in info:
|
||||
# NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device
|
||||
info["action_intervention"] = info["action_intervention"].astype(np.float32)
|
||||
info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
@@ -1756,6 +1759,8 @@ class GymHilDeviceWrapper(gym.Wrapper):
|
||||
for k in obs:
|
||||
obs[k] = obs[k].to(self.device)
|
||||
if "action_intervention" in info:
|
||||
# NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device
|
||||
info["action_intervention"] = info["action_intervention"].astype(np.float32)
|
||||
info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
|
||||
return obs, info
|
||||
|
||||
|
||||
@@ -14,6 +14,66 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Learner server runner for distributed HILSerl robot policy training.
|
||||
|
||||
This script implements the learner component of the distributed HILSerl architecture.
|
||||
It initializes the policy network, maintains replay buffers, and updates
|
||||
the policy based on transitions received from the actor server.
|
||||
|
||||
Examples of usage:
|
||||
|
||||
- Start a learner server for training:
|
||||
```bash
|
||||
python lerobot/scripts/server/learner_server.py --config_path lerobot/configs/train_config_hilserl_so100.json
|
||||
```
|
||||
|
||||
- Run with specific SAC hyperparameters:
|
||||
```bash
|
||||
python lerobot/scripts/server/learner_server.py \
|
||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||
--learner.sac.alpha=0.1 \
|
||||
--learner.sac.gamma=0.99
|
||||
```
|
||||
|
||||
- Run with a specific dataset and wandb logging:
|
||||
```bash
|
||||
python lerobot/scripts/server/learner_server.py \
|
||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||
--dataset.repo_id=username/pick_lift_cube \
|
||||
--wandb.enable=true \
|
||||
--wandb.project=hilserl_training
|
||||
```
|
||||
|
||||
- Run with a pretrained policy for fine-tuning:
|
||||
```bash
|
||||
python lerobot/scripts/server/learner_server.py \
|
||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||
--pretrained_policy_name_or_path=outputs/previous_training/checkpoints/080000/pretrained_model
|
||||
```
|
||||
|
||||
- Run with a reward classifier model:
|
||||
```bash
|
||||
python lerobot/scripts/server/learner_server.py \
|
||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||
--reward_classifier_pretrained_path=outputs/reward_model/best_model
|
||||
```
|
||||
|
||||
**NOTE**: Start the learner server before launching the actor server. The learner opens a gRPC server
|
||||
to communicate with actors.
|
||||
|
||||
**NOTE**: Training progress can be monitored through Weights & Biases if wandb.enable is set to true
|
||||
in your configuration.
|
||||
|
||||
**WORKFLOW**:
|
||||
1. Create training configuration with proper policy, dataset, and environment settings
|
||||
2. Start this learner server with the configuration
|
||||
3. Start an actor server with the same configuration
|
||||
4. Monitor training progress through wandb dashboard
|
||||
|
||||
For more details on the complete HILSerl training workflow, see:
|
||||
https://github.com/michel-aractingi/lerobot-hilserl-guide
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
@@ -73,7 +133,6 @@ from lerobot.scripts.server.utils import (
|
||||
|
||||
LOG_PREFIX = "[LEARNER]"
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
#################################################
|
||||
# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS #
|
||||
@@ -113,13 +172,17 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
|
||||
if job_name is None:
|
||||
raise ValueError("Job name must be specified either in config or as a parameter")
|
||||
|
||||
display_pid = False
|
||||
if not use_threads(cfg):
|
||||
display_pid = True
|
||||
|
||||
# Create logs directory to ensure it exists
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"learner_{job_name}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file)
|
||||
init_logging(log_file=log_file, display_pid=display_pid)
|
||||
logging.info(f"Learner logging initialized, writing to {log_file}")
|
||||
logging.info(pformat(cfg.to_dict()))
|
||||
|
||||
@@ -275,7 +338,7 @@ def add_actor_information_and_train(
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"learner_train_process_{os.getpid()}.log")
|
||||
init_logging(log_file=log_file)
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Initialized logging for actor information and training process")
|
||||
|
||||
logging.info("Initializing policy")
|
||||
@@ -604,7 +667,7 @@ def start_learner_server(
|
||||
log_file = os.path.join(log_dir, f"learner_server_process_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file)
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Learner server process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
|
||||
Reference in New Issue
Block a user