Compare commits

..

43 Commits

Author SHA1 Message Date
Michel Aractingi
daa1480a91 nit 2025-01-22 10:26:52 +01:00
Michel Aractingi
71ec721e48 cleaned eval_on_robot.py; readded policy; fixed doc strings 2025-01-22 10:26:52 +01:00
Michel Aractingi
bbb5ba0adf Extend reward classifier for multiple camera views (#626) 2025-01-22 10:26:52 +01:00
Eugene Mironov
844bfcf484 [Port HIL_SERL] Final fixes for the Reward Classifier (#598) 2025-01-22 10:26:52 +01:00
Michel Aractingi
13441f0d98 added temporary fix for missing task_index key in online environment 2025-01-22 10:26:50 +01:00
Michel Aractingi
41b377211c split encoder for critic and actor 2025-01-22 10:25:52 +01:00
KeWang1017
9ceb68ee90 Refine SAC configuration and policy for enhanced performance
- Updated standard deviation parameterization in SACConfig to 'softplus' with defined min and max values for improved stability.
- Modified action sampling in SACPolicy to use reparameterized sampling, ensuring better gradient flow and log probability calculations.
- Cleaned up log probability calculations in TanhMultivariateNormalDiag for clarity and efficiency.
- Increased evaluation frequency in YAML configuration to 50000 for more efficient training cycles.

These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
2025-01-22 10:23:33 +01:00
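A minimal sketch of the sampling scheme described in the commit message above, assuming a softplus-bounded std and a tanh squash; the helper name and the std bounds are illustrative, not the repo's exact code:

import torch
import torch.nn.functional as F

def sample_tanh_gaussian(means, raw_std, std_min=1e-3, std_max=5.0):
    # 'softplus' std parameterization with defined min and max values
    std = torch.clamp(F.softplus(raw_std), std_min, std_max)
    dist = torch.distributions.Normal(means, std)
    x_t = dist.rsample()                     # reparameterized sample: mean + std * N(0, 1)
    log_probs = dist.log_prob(x_t)           # base log probability before the tanh squash
    actions = torch.tanh(x_t)
    log_probs -= torch.log(1.0 - actions.pow(2) + 1e-6)  # tanh change-of-variables correction
    return actions, log_probs.sum(dim=-1)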
KeWang1017
d1baa5a82f trying to get sac running 2025-01-22 10:20:56 +01:00
Michel Aractingi
04da4dd3e3 Added normalization schemes and style checks 2025-01-22 10:19:19 +01:00
Michel Aractingi
b0e2fcdba7 added optimizer and sac to factory.py 2025-01-22 10:17:48 +01:00
Eugene Mironov
1e2a757cd3 [Port Hil-SERL] Add unit tests for the reward classifier & fix imports & check script (#578) 2025-01-22 10:14:06 +01:00
Michel Aractingi
ab842ba6ae nit in control_robot.py 2025-01-22 10:06:39 +01:00
Michel Aractingi
94a7221a94 Update lerobot/scripts/train_hilserl_classifier.py
Co-authored-by: Yoel <yoel.chornton@gmail.com>
2025-01-22 10:06:39 +01:00
Claudio Coppola
00dadcace0 LerobotDataset pushable to HF from any folder (#563) 2025-01-22 10:06:39 +01:00
berjaoui
81a2f2958d Update 7_get_started_with_real_robot.md (#559) 2025-01-22 10:06:39 +01:00
Michel Aractingi
68b4fb60ad Control simulated robot with real leader (#514)
Co-authored-by: Remi <remi.cadene@huggingface.co>
2025-01-22 10:06:39 +01:00
Remi
96b2b62377 Fix missing local_files_only in record/replay (#540)
Co-authored-by: Simon Alibert <alibert.sim@gmail.com>
2025-01-22 10:06:39 +01:00
Michel Aractingi
b5c98bbfd3 Refactor OpenX (#505) 2025-01-22 10:06:39 +01:00
Eugene Mironov
58e12cf2e8 Fixup 2025-01-22 10:06:39 +01:00
Michel Aractingi
d8b5fae622 Add human intervention mechanism and eval_robot script to evaluate policy on the robot (#541)
Co-authored-by: Yoel <yoel.chornton@gmail.com>
2025-01-22 10:06:39 +01:00
Yoel
67ac81d728 Reward classifier and training (#528)
Co-authored-by: Daniel Ritchie <daniel@brainwavecollective.ai>
Co-authored-by: resolver101757 <kelster101757@hotmail.com>
Co-authored-by: Jannik Grothusen <56967823+J4nn1K@users.noreply.github.com>
Co-authored-by: Remi <re.cadene@gmail.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-01-22 10:06:39 +01:00
Michel Aractingi
b5f1ea3140 nit 2025-01-22 10:06:39 +01:00
AdilZouitine
4d854a1513 Stable version of rlpd + drq 2025-01-22 09:00:16 +00:00
AdilZouitine
87da655eab Add type annotations and restructure SACConfig class fields 2025-01-21 09:51:12 +00:00
Adil Zouitine
a8fda9c61a Change SAC policy implementation with configuration and modeling classes 2025-01-17 09:39:04 +01:00
Adil Zouitine
55505ff817 Add rlpd tricks 2025-01-16 11:53:36 +01:00
Adil Zouitine
20d31ab8e0 SAC works 2025-01-16 11:53:27 +01:00
Adil Zouitine
e5b83aab5e remove breakpoint 2025-01-16 11:52:03 +01:00
Adil Zouitine
a9d5f62304 [WIP] correct sac implementation 2025-01-16 11:51:18 +01:00
Adil Zouitine
72e1ed7058 Add rlpd tricks 2025-01-16 11:42:24 +01:00
Adil Zouitine
d8e67a2609 SAC works 2025-01-16 11:42:24 +01:00
Adil Zouitine
50e12376de remove breakpoint 2025-01-16 11:42:23 +01:00
Adil Zouitine
73aa6c25f3 [WIP] correct sac implementation 2025-01-16 11:42:14 +01:00
Pradeep Kadubandi
380b836eee Fix for the issue https://github.com/huggingface/lerobot/issues/638 (#639) 2025-01-15 10:50:38 +01:00
Philip Fung
eec6796cb8 fixes to SO-100 readme (#600)
Co-authored-by: Philip Fung <no@one>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-01-10 11:30:01 +01:00
Mishig
25a8597680 [viz] Fixes & updates to html visualizer (#617) 2025-01-09 11:39:54 +01:00
CharlesCNorton
b8b368310c typo fix: batch_convert_dataset_v1_to_v2.py (#615)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-01-09 09:57:45 +01:00
Ville Kuosmanen
5097cd900e fix(visualise): use correct language description for each episode id (#604)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-01-09 09:39:48 +01:00
CharlesCNorton
bc16e1b497 fix(docs): typos in benchmark readme.md (#614)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-01-09 09:35:27 +01:00
Simon Alibert
8f821ecad0 Fix Quality workflow (#622) 2025-01-08 13:35:11 +01:00
CharlesCNorton
4519016e67 Update README.md (#612) 2025-01-03 16:19:37 +01:00
Eugene Mironov
59e2757434 Fix broken create_lerobot_dataset_card (#590) 2024-12-23 15:05:59 +01:00
Mishig
73b64c3089 [vizualizer] for LeRobodDataset V2 (#576) 2024-12-20 16:26:23 +01:00
37 changed files with 386 additions and 4429 deletions

View File

@@ -1,18 +0,0 @@
import socket


def check_port(host, port):
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.connect((host, port))
        print(f"Connection successful to {host}:{port}!")
    except Exception as e:
        print(f"Connection failed to {host}:{port}: {e}")
    finally:
        s.close()


if __name__ == "__main__":
    host = "127.0.0.1"  # or "localhost"
    port = 51350
    check_port(host, port)

View File

@@ -74,23 +74,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
image_transforms = None
if cfg.training.image_transforms.enable:
default_tf = OmegaConf.create(
{
"brightness": {"weight": 0.0, "min_max": None},
"contrast": {"weight": 0.0, "min_max": None},
"saturation": {"weight": 0.0, "min_max": None},
"hue": {"weight": 0.0, "min_max": None},
"sharpness": {"weight": 0.0, "min_max": None},
"max_num_transforms": None,
"random_order": False,
"image_size": None,
"interpolation": None,
"image_mean": None,
"image_std": None,
}
)
cfg_tf = OmegaConf.merge(OmegaConf.create(default_tf), cfg.training.image_transforms)
cfg_tf = cfg.training.image_transforms
image_transforms = get_image_transforms(
brightness_weight=cfg_tf.brightness.weight,
brightness_min_max=cfg_tf.brightness.min_max,
@@ -104,10 +88,6 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
sharpness_min_max=cfg_tf.sharpness.min_max,
max_num_transforms=cfg_tf.max_num_transforms,
random_order=cfg_tf.random_order,
image_size=(cfg_tf.image_size.height, cfg_tf.image_size.width) if cfg_tf.image_size else None,
interpolation=cfg_tf.interpolation,
image_mean=cfg_tf.image_mean,
image_std=cfg_tf.image_std,
)
if isinstance(cfg.dataset_repo_id, str):

View File

@@ -84,8 +84,7 @@ class LeRobotDatasetMetadata:
# Load metadata
(self.root / "meta").mkdir(exist_ok=True, parents=True)
if not self.local_files_only:
self.pull_from_repo(allow_patterns="meta/")
self.pull_from_repo(allow_patterns="meta/")
self.info = load_info(self.root)
self.stats = load_stats(self.root)
self.tasks = load_tasks(self.root)
@@ -538,8 +537,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
]
files += video_files
if not self.local_files_only:
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""

View File

@@ -150,10 +150,6 @@ def get_image_transforms(
sharpness_min_max: tuple[float, float] | None = None,
max_num_transforms: int | None = None,
random_order: bool = False,
interpolation: str | None = None,
image_size: tuple[int, int] | None = None,
image_mean: list[float] | None = None,
image_std: list[float] | None = None,
):
def check_value(name, weight, min_max):
if min_max is not None:
@@ -174,18 +170,6 @@ def get_image_transforms(
weights = []
transforms = []
if image_size is not None:
interpolations = [interpolation.value for interpolation in v2.InterpolationMode]
if interpolation is None:
# Use BICUBIC as default interpolation
interpolation_mode = v2.InterpolationMode.BICUBIC
elif interpolation in interpolations:
interpolation_mode = v2.InterpolationMode(interpolation)
else:
raise ValueError("The interpolation passed is not supported")
# Weight for resizing is always 1
weights.append(1.0)
transforms.append(v2.Resize(size=(image_size[0], image_size[1]), interpolation=interpolation_mode))
if brightness_min_max is not None and brightness_weight > 0.0:
weights.append(brightness_weight)
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
@@ -201,15 +185,6 @@ def get_image_transforms(
if sharpness_min_max is not None and sharpness_weight > 0.0:
weights.append(sharpness_weight)
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
if image_mean is not None and image_std is not None:
# Weight for normalization is always 1
weights.append(1.0)
transforms.append(
v2.Normalize(
mean=image_mean,
std=image_std,
)
)
n_subset = len(transforms)
if max_num_transforms is not None:

View File

@@ -275,7 +275,6 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
hf_features[key] = datasets.Sequence(
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
)
# TODO: (alibers, azouitine) Add support for ft["shap"] == 0 as Value
return datasets.Features(hf_features)

View File

@@ -20,7 +20,7 @@ import gymnasium as gym
import numpy as np
import torch
from omegaconf import DictConfig
# from mani_skill.utils import common
from mani_skill.utils import common
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
@@ -122,30 +122,28 @@ class PixelWrapper(gym.Wrapper):
obs, reward, terminated, truncated, info = self.env.step(action)
return self._get_obs(obs), reward, terminated, truncated, info
class ConvertToLeRobotEnv(gym.Wrapper):
def __init__(self, env, num_envs):
super().__init__(env)
def reset(self, seed=None, options=None):
obs, info = self.env.reset(seed=seed, options={})
return self._get_obs(obs), info
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
return self._get_obs(obs), reward, terminated, truncated, info
def _get_obs(self, observation):
sensor_data = observation.pop("sensor_data")
del observation["sensor_param"]
images = []
for cam_data in sensor_data.values():
images.append(cam_data["rgb"])
images.append(cam_data["rgb"])
images = torch.concat(images, axis=-1)
# flatten the rest of the data which should just be state data
observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device)
observation = common.flatten_state_dict(
observation, use_torch=True, device=self.base_env.device
)
ret = dict()
ret["state"] = observation
ret["pixels"] = images
return ret
return ret
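A rough sketch of what the _get_obs conversion above produces, with dummy tensors standing in for ManiSkill sensor data (shapes are illustrative):

import torch

front = torch.zeros(1, 128, 128, 3, dtype=torch.uint8)   # rgb from camera 1
side = torch.zeros(1, 128, 128, 3, dtype=torch.uint8)     # rgb from camera 2
pixels = torch.cat([front, side], dim=-1)                  # (1, 128, 128, 6): cameras stacked on channels
state = torch.randn(1, 25)                                 # flattened state (qpos, qvel, extra keys)
obs = {"state": state, "pixels": pixels}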

View File

@@ -31,25 +31,28 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# TODO: merge all tensors from the agent key and the extra key.
# The sensor param key is not kept in the observation,
# but the sensor data rgb is kept.
for key, img in observations.items():
if "images" not in key:
continue
if "pixels" in observations:
if isinstance(observations["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
else:
imgs = {"observation.image": observations["pixels"]}
if img.ndim == 3:
img = img.unsqueeze(0)
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
for imgkey, img in imgs.items():
img = torch.from_numpy(img)
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
return_observations[key] = img
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255
return_observations[imgkey] = img
# obs state agent qpos and qvel
# image
@@ -60,8 +63,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
# requirement for "agent_pos"
# return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
return_observations["observation.state"] = observations["observation.state"].float()
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
return return_observations
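A minimal standalone sketch of the image conversion performed above (the helper name is illustrative; the input is assumed to be a channel-last uint8 numpy array):

import einops
import numpy as np
import torch

def to_channel_first_float(img_np: np.ndarray) -> torch.Tensor:
    img = torch.from_numpy(img_np)
    if img.ndim == 3:
        img = img.unsqueeze(0)                 # add a batch dim: (h, w, c) -> (1, h, w, c)
    assert img.dtype == torch.uint8, f"expect torch.uint8, got {img.dtype=}"
    _, h, w, c = img.shape
    assert c < h and c < w, f"expect channel-last images, got {img.shape=}"
    img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
    return img.to(torch.float32) / 255         # channel-first float32 in [0, 1]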

View File

@@ -127,8 +127,6 @@ class Logger:
job_type="train_eval",
resume="must" if cfg.resume else None,
)
# Handle custom step key for rl asynchronous training.
self._wandb_custom_step_key: set[str] | None = None
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
@@ -174,32 +172,18 @@ class Logger:
self,
save_dir: Path,
train_step: int,
optimizer: Optimizer | dict,
optimizer: Optimizer,
scheduler: LRScheduler | None,
interaction_step: int | None = None,
):
"""Checkpoint the global training_step, optimizer state, scheduler state, and random state.
All of these are saved as "training_state.pth" under the checkpoint directory.
"""
# In Sac, for example, we have a dictionary of torch.optim.Optimizer
if type(optimizer) is dict:
optimizer_state_dict = {}
for k in optimizer:
optimizer_state_dict[k] = optimizer[k].state_dict()
else:
optimizer_state_dict = optimizer.state_dict()
training_state = {
"step": train_step,
"optimizer": optimizer_state_dict,
"optimizer": optimizer.state_dict(),
**get_global_random_state(),
}
# The interaction step is specific to the distributed training setup.
# In that setup, there are two kinds of steps: the online step of the env and the optimization step.
# We need to save both in order to resume the optimization properly and not break the logs that depend on the interaction step.
if interaction_step is not None:
training_state["interaction_step"] = interaction_step
if scheduler is not None:
training_state["scheduler"] = scheduler.state_dict()
torch.save(training_state, save_dir / self.training_state_file_name)
@@ -211,7 +195,6 @@ class Logger:
optimizer: Optimizer,
scheduler: LRScheduler | None,
identifier: str,
interaction_step: int | None = None,
):
"""Checkpoint the model weights and the training state."""
checkpoint_dir = self.checkpoints_dir / str(identifier)
@@ -223,24 +206,16 @@ class Logger:
self.save_model(
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
)
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler, interaction_step)
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
def load_last_training_state(self, optimizer: Optimizer | dict, scheduler: LRScheduler | None) -> int:
def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
"""
Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
random state, and return the global training step.
"""
training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
# For the case where the optimizer is a dictionary of optimizers (e.g., sac)
if type(training_state["optimizer"]) is dict:
assert set(training_state["optimizer"].keys()) == set(optimizer.keys()), (
"Optimizer dictionaries do not have the same keys during resume!"
)
for k, v in training_state["optimizer"].items():
optimizer[k].load_state_dict(v)
else:
optimizer.load_state_dict(training_state["optimizer"])
optimizer.load_state_dict(training_state["optimizer"])
if scheduler is not None:
scheduler.load_state_dict(training_state["scheduler"])
elif "scheduler" in training_state:
@@ -251,44 +226,17 @@ class Logger:
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
return training_state["step"]
def log_dict(self, d, step: int | None = None, mode="train", custom_step_key: str | None = None):
"""Log a dictionary of metrics to WandB."""
def log_dict(self, d, step, mode="train"):
assert mode in {"train", "eval"}
# TODO(alexander-soare): Add local text log.
if step is None and custom_step_key is None:
raise ValueError("Either step or custom_step_key must be provided.")
if self._wandb is not None:
# NOTE: This is not trivial. The wandb step must increase monotonically and it
# increases with each wandb.log call, but in asynchronous RL several step counters
# coexist: the interaction step with the environment, the training step, the
# evaluation step, etc. So we need to define a custom step key to log the correct
# step for each metric.
if custom_step_key is not None:
if self._wandb_custom_step_key is None:
self._wandb_custom_step_key = set()
new_custom_key = f"{mode}/{custom_step_key}"
if new_custom_key not in self._wandb_custom_step_key:
self._wandb_custom_step_key.add(new_custom_key)
self._wandb.define_metric(new_custom_key, hidden=True)
for k, v in d.items():
if not isinstance(v, (int, float, str, wandb.Table)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
)
continue
# Do not log the custom step key itself.
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
continue
if custom_step_key is not None:
value_custom_step = d[custom_step_key]
self._wandb.log({f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step})
continue
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
self._wandb.log({f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
assert mode in {"train", "eval"}
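A hedged sketch of the custom-step pattern from the NOTE above: the custom step key is registered as a hidden metric and then logged alongside each value (the project and key names here are illustrative):

import wandb

run = wandb.init(project="example-project")                  # any project name works here
wandb.define_metric("train/interaction_step", hidden=True)   # do not plot the step key itself

# later, e.g. from the actor/interaction loop:
interaction_step = 1234
wandb.log({"train/reward": 0.7, "train/interaction_step": interaction_step})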

View File

@@ -76,11 +76,7 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
def make_policy(
hydra_cfg: DictConfig,
pretrained_policy_name_or_path: str | None = None,
dataset_stats=None,
*args,
**kwargs,
hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None
) -> Policy:
"""Make an instance of a policy class.
@@ -104,9 +100,7 @@ def make_policy(
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
if pretrained_policy_name_or_path is None:
# Make a fresh policy.
# HACK: We pass *args and **kwargs to the policy constructor to allow for additional arguments,
# for example the device for the SAC policy.
policy = policy_cls(config=policy_cfg, dataset_stats=dataset_stats)
policy = policy_cls(policy_cfg, dataset_stats)
else:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary).

View File

@@ -10,7 +10,7 @@ class ClassifierConfig:
num_classes: int = 2
hidden_dim: int = 256
dropout_rate: float = 0.1
model_name: str = "helper2424/resnet10"
model_name: str = "microsoft/resnet-50"
device: str = "cpu"
model_type: str = "cnn" # "transformer" or "cnn"
num_cameras: int = 2

View File

@@ -47,7 +47,7 @@ class Classifier(
super().__init__()
self.config = config
# self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
# Extract vision model if we're given a multimodal model
if hasattr(encoder, "vision_model"):
@@ -108,12 +108,11 @@ class Classifier(
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
"""Extract the appropriate output from the encoder."""
# Process images with the processor (handles resizing and normalization)
# processed = self.processor(
# images=x, # LeRobotDataset already provides proper tensor format
# return_tensors="pt",
# )
# processed = processed["pixel_values"].to(x.device)
processed = x
processed = self.processor(
images=x, # LeRobotDataset already provides proper tensor format
return_tensors="pt",
)
processed = processed["pixel_values"].to(x.device)
with torch.no_grad():
if self.is_cnn:
@@ -145,10 +144,8 @@ class Classifier(
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
def predict_reward(self, x, threshold=0.6):
def predict_reward(self, x):
if self.config.num_classes == 2:
probs = self.forward(x).probabilities
logging.debug(f"Predicted reward images: {probs}")
return (probs > threshold).float()
return (self.forward(x).probabilities > 0.5).float()
else:
return torch.argmax(self.forward(x).probabilities, dim=1)
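The thresholding logic above, reduced to a standalone sketch (probabilities is assumed to already hold P(reward = 1) in the binary case):

import torch

def reward_from_probabilities(probabilities: torch.Tensor, num_classes: int = 2,
                              threshold: float = 0.6) -> torch.Tensor:
    if num_classes == 2:
        return (probabilities > threshold).float()   # binary: reward fires above the threshold
    return torch.argmax(probabilities, dim=1)         # multi-class: most likely class index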

View File

@@ -39,32 +39,16 @@ class SACConfig:
"observation.environment_state": "min_max",
}
)
input_normalization_params: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {
"observation.image": {"mean": [[0.485, 0.456, 0.406]], "std": [[0.229, 0.224, 0.225]]},
"observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
}
)
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
output_normalization_params: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {
"action": {"min": [-1, -1], "max": [1, 1]},
}
)
# TODO: Move it outside of the config
actor_learner_config: dict[str, str | int] = field(
default_factory=lambda: {
"actor_ip": "127.0.0.1",
"port": 50051,
"learner_ip": "127.0.0.1",
}
)
camera_number: int = 1
# Add type annotations for these fields:
vision_encoder_name: str | None = field(default="helper2424/resnet10")
freeze_vision_encoder: bool = True
image_encoder_hidden_dim: int = 32
shared_encoder: bool = True
shared_encoder: bool = False
discount: float = 0.99
temperature_init: float = 1.0
num_critics: int = 2
@@ -95,6 +79,5 @@ class SACConfig:
"use_tanh_squash": True,
"log_std_min": -5,
"log_std_max": 2,
"init_final": 0.005,
}
)

View File

@@ -17,7 +17,8 @@
# TODO: (1) better device management
from typing import Callable, Optional, Tuple
from collections import deque
from typing import Callable, Optional, Sequence, Tuple, Union
import einops
import numpy as np
@@ -29,7 +30,6 @@ from torch import Tensor
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.utils import get_device_from_parameters
class SACPolicy(
@@ -45,26 +45,25 @@ class SACPolicy(
self,
config: SACConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
device: str = "cpu",
):
super().__init__()
if config is None:
config = SACConfig()
self.config = config
if config.input_normalization_modes is not None:
input_normalization_params = _convert_normalization_params_to_tensor(
config.input_normalization_params
)
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, input_normalization_params
config.input_shapes, config.input_normalization_modes, dataset_stats
)
else:
self.normalize_inputs = nn.Identity()
output_normalization_params = _convert_normalization_params_to_tensor(
config.output_normalization_params
)
output_normalization_params = {}
for outer_key, inner_dict in config.output_normalization_params.items():
output_normalization_params[outer_key] = {}
for key, value in inner_dict.items():
output_normalization_params[outer_key][key] = torch.tensor(value)
# HACK: This is hacky and should be removed
dataset_stats = dataset_stats or output_normalization_params
@@ -75,80 +74,88 @@ class SACPolicy(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
# NOTE: For images the encoder should be shared between the actor and critic
if config.shared_encoder:
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
encoder_critic = SACObservationEncoder(config)
encoder_actor: SACObservationEncoder = encoder_critic
else:
encoder_critic = SACObservationEncoder(config)
encoder_actor = SACObservationEncoder(config)
# Define networks
critic_nets = []
for _ in range(config.num_critics):
critic_net = Critic(
encoder=encoder_critic,
network=MLP(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
),
device=device,
)
critic_nets.append(critic_net)
self.critic_ensemble = CriticEnsemble(
encoder=encoder_critic,
network_list=nn.ModuleList(
[
MLP(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
),
output_normalization=self.normalize_targets,
target_critic_nets = []
for _ in range(config.num_critics):
target_critic_net = Critic(
encoder=encoder_critic,
network=MLP(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
),
device=device,
)
target_critic_nets.append(target_critic_net)
self.critic_ensemble = create_critic_ensemble(
critics=critic_nets, num_critics=config.num_critics, device=device
)
self.critic_target = CriticEnsemble(
encoder=encoder_critic,
network_list=nn.ModuleList(
[
MLP(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
),
output_normalization=self.normalize_targets,
self.critic_target = create_critic_ensemble(
critics=target_critic_nets, num_critics=config.num_critics, device=device
)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
self.actor = Policy(
encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0],
device=device,
encoder_is_shared=config.shared_encoder,
**config.policy_kwargs,
)
if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temperature parameter is fixed
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
# it triggers "can't optimize a non-leaf Tensor"
self.log_alpha = nn.Parameter(torch.tensor([0.0]))
# TODO: Handle the case where the temperature parameter is fixed
self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
self.temperature = self.log_alpha.exp().item()
def reset(self):
"""Reset the policy"""
pass
def to(self, *args, **kwargs):
"""Override .to(device) method to involve moving the log_alpha fixed_std"""
if self.actor.fixed_std is not None:
self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
# self.log_alpha = self.log_alpha.to(*args, **kwargs)
super().to(*args, **kwargs)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation"""
actions, _, _ = self.actor(batch)
actions = self.unnormalize_outputs({"action": actions})["action"]
return actions
def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor:
"""Forward pass through a critic network ensemble
Args:
observations: Dictionary of observations
actions: Action tensor
use_target: If True, use target critics, otherwise use ensemble critics
Returns:
Tensor of Q-values from all critics
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = torch.stack([critic(observations, actions) for critic in critics])
return q_values
def critic_forward(
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False, features: Optional[Tensor] = None
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
) -> Tensor:
"""Forward pass through a critic network ensemble
@@ -161,28 +168,27 @@ class SACPolicy(
Tensor of Q-values from all critics
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions, features)
q_values = torch.stack([critic(observations, actions) for critic in critics])
return q_values
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
def update_target_networks(self):
"""Update target networks with exponential moving average"""
for target_param, param in zip(
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=False
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
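An equivalent standalone sketch of the exponential-moving-average target update above, where critic_target_update_weight plays the usual role of tau:

import torch

@torch.no_grad()
def polyak_update(target: torch.nn.Module, source: torch.nn.Module, tau: float = 0.01):
    for t_param, s_param in zip(target.parameters(), source.parameters()):
        t_param.data.copy_(tau * s_param.data + (1.0 - tau) * t_param.data)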
def compute_loss_critic(self, observations, actions, rewards, next_observations, done, obs_features=None, next_obs_features=None) -> Tensor:
def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
temperature = self.log_alpha.exp().item()
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_obs_features)
next_action_preds, next_log_probs, _ = self.actor(next_observations)
# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations, actions=next_action_preds, use_target=True, features=next_obs_features
observations=next_observations, actions=next_action_preds, use_target=True
)
# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
@@ -199,7 +205,7 @@ class SACPolicy(
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False, features=obs_features)
q_preds = self.critic_forward(observations, actions, use_target=False)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
@@ -214,18 +220,18 @@ class SACPolicy(
).sum()
return critics_loss
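A hedged sketch of the TD target assembled in compute_loss_critic above: take the minimum over the critic ensemble, subtract the entropy term as in standard SAC, and bootstrap only on non-terminal transitions (the shapes and the placement of the entropy term are assumptions, not copied from the hunk):

import torch

def soft_td_target(rewards, done, next_q_values, next_log_probs,
                   discount=0.99, temperature=0.1):
    # next_q_values: (num_critics, batch) stacked ensemble predictions for the next action
    min_q, _ = next_q_values.min(dim=0)               # pessimistic estimate across the ensemble
    next_v = min_q - temperature * next_log_probs      # soft value with the entropy bonus
    return rewards + (1.0 - done) * discount * next_v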
def compute_loss_temperature(self, observations, obs_features=None) -> Tensor:
def compute_loss_temperature(self, observations) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations, obs_features)
_, log_probs, _ = self.actor(observations)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
return temperature_loss
def compute_loss_actor(self, observations, obs_features=None) -> Tensor:
def compute_loss_actor(self, observations) -> Tensor:
temperature = self.log_alpha.exp().item()
actions_pi, log_probs, _ = self.actor(observations, obs_features)
actions_pi, log_probs, _ = self.actor(observations)
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
min_q_preds = q_preds.min(dim=0)[0]
@@ -272,115 +278,54 @@ class MLP(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class CriticEnsemble(nn.Module):
"""
┌──────────────────┬─────────────────────────────────────────────────────────┐
│ Critic Ensemble │ │
├──────────────────┘ │
│ │
│ ┌────┐ ┌────┐ ┌────┐ │
│ │ Q1 │ │ Q2 │ │ Qn │ │
│ └────┘ └────┘ └────┘ │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ │ │ │ │ │ │
│ │ MLP 1 │ │ MLP 2 │ │ MLP │ │
│ │ │ │ │ ... │ num_critics │ │
│ │ │ │ │ │ │ │
│ └──────────────┘ └──────────────┘ └──────────────┘ │
│ ▲ ▲ ▲ │
│ └───────────────────┴───────┬────────────────────────────┘ │
│ │ │
│ │ │
│ ┌───────────────────┐ │
│ │ Embedding │ │
│ │ │ │
│ └───────────────────┘ │
│ ▲ │
│ │ │
│ ┌─────────────┴────────────┐ │
│ │ │ │
│ │ SACObservationEncoder │ │
│ │ │ │
│ └──────────────────────────┘ │
│ ▲ │
│ │ │
│ │ │
│ │ │
└───────────────────────────┬────────────────────┬───────────────────────────┘
│ Observation │
└────────────────────┘
"""
class Critic(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
network_list: nn.ModuleList,
output_normalization: nn.Module,
network: nn.Module,
init_final: Optional[float] = None,
device: str = "cpu",
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network_list = network_list
self.network = network
self.init_final = init_final
self.output_normalization = output_normalization
self.parameters_to_optimize = []
# Handle the case where part of the encoder is frozen
if self.encoder is not None:
self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
self.parameters_to_optimize += list(self.network_list.parameters())
# Find the last Linear layer's output dimension
for layer in reversed(network_list[0].net):
for layer in reversed(network.net):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break
# Output layer
self.output_layers = []
if init_final is not None:
for _ in network_list:
output_layer = nn.Linear(out_features, 1)
nn.init.uniform_(output_layer.weight, -init_final, init_final)
nn.init.uniform_(output_layer.bias, -init_final, init_final)
self.output_layers.append(output_layer)
self.output_layer = nn.Linear(out_features, 1)
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
self.output_layers = []
for _ in network_list:
output_layer = nn.Linear(out_features, 1)
orthogonal_init()(output_layer.weight)
self.output_layers.append(output_layer)
self.output_layers = nn.ModuleList(self.output_layers)
self.parameters_to_optimize += list(self.output_layers.parameters())
self.output_layer = nn.Linear(out_features, 1)
orthogonal_init()(self.output_layer.weight)
self.to(self.device)
def forward(
self,
observations: dict[str, torch.Tensor],
actions: torch.Tensor,
features: Optional[torch.Tensor] = None,
) -> torch.Tensor:
device = get_device_from_parameters(self)
# Move observations to the correct device
observations = {k: v.to(device) for k, v in observations.items()}
# Normalize actions for sample efficiency
actions: dict[str, torch.Tensor] = {"action": actions}
actions = self.output_normalization(actions)["action"]
actions = actions.to(device)
# Use precomputed features if provided; otherwise, encode observations.
obs_enc = features if features is not None else (observations if self.encoder is None else self.encoder(observations))
# Move each tensor in observations to device
observations = {k: v.to(self.device) for k, v in observations.items()}
actions = actions.to(self.device)
obs_enc = observations if self.encoder is None else self.encoder(observations)
inputs = torch.cat([obs_enc, actions], dim=-1)
list_q_values = []
for network, output_layer in zip(self.network_list, self.output_layers, strict=False):
x = network(inputs)
value = output_layer(x)
list_q_values.append(value.squeeze(-1))
return torch.stack(list_q_values)
x = self.network(inputs)
value = self.output_layer(x)
return value.squeeze(-1)
class Policy(nn.Module):
def __init__(
@@ -393,15 +338,17 @@ class Policy(nn.Module):
fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None,
use_tanh_squash: bool = False,
device: str = "cpu",
encoder_is_shared: bool = False,
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network = network
self.action_dim = action_dim
self.log_std_min = log_std_min
self.log_std_max = log_std_max
self.fixed_std = fixed_std
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
self.use_tanh_squash = use_tanh_squash
self.parameters_to_optimize = []
@@ -433,14 +380,16 @@ class Policy(nn.Module):
orthogonal_init()(self.std_layer.weight)
self.parameters_to_optimize += list(self.std_layer.parameters())
self.to(self.device)
def forward(
self,
self,
observations: torch.Tensor,
features: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Use precomputed features if provided; otherwise compute encoder representations.
obs_enc = features if features is not None else (observations if self.encoder is None else self.encoder(observations))
) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
obs_enc = observations if self.encoder is None else self.encoder(observations)
# Get network outputs
outputs = self.network(obs_enc)
means = self.mean_layer(outputs)
@@ -449,7 +398,7 @@ class Policy(nn.Module):
if self.fixed_std is None:
log_std = self.std_layer(outputs)
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
if self.use_tanh_squash:
log_std = torch.tanh(log_std)
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
@@ -457,8 +406,8 @@ class Policy(nn.Module):
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
else:
log_std = self.fixed_std.expand_as(means)
# Get distribution and sample actions
# uses tanh activation function to squash the action to be in the range of [-1, 1]
normal = torch.distributions.Normal(means, torch.exp(log_std))
x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1))
log_probs = normal.log_prob(x_t) # Base log probability before Tanh
@@ -475,8 +424,7 @@ class Policy(nn.Module):
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
"""Get encoded features from observations"""
device = get_device_from_parameters(self)
observations = observations.to(device)
observations = observations.to(self.device)
if self.encoder is not None:
with torch.inference_mode():
return self.encoder(observations)
@@ -484,35 +432,65 @@ class Policy(nn.Module):
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
"""Encode image and/or state vector observations.
TODO(ke-wang): The original work allows for (1) stacking multiple history frames and (2) using pretrained resnet encoders.
"""
def __init__(self, config: SACConfig, input_normalizer: nn.Module):
def __init__(self, config: SACConfig):
"""
Creates encoders for pixel and/or state modalities.
"""
super().__init__()
self.config = config
self.input_normalization = input_normalizer
self.has_pretrained_vision_encoder = False
self.parameters_to_optimize = []
self.aggregation_size: int = 0
if any("observation.image" in key for key in config.input_shapes):
if "observation.image" in config.input_shapes:
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
in_channels=config.input_shapes["observation.image"][0],
out_channels=config.image_encoder_hidden_dim,
kernel_size=7,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=5,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(),
)
self.camera_number = config.camera_number
self.aggregation_size: int = 0
if self.config.vision_encoder_name is not None:
self.image_enc_layers = PretrainedImageEncoder(config)
self.has_pretrained_vision_encoder = True
else:
self.image_enc_layers = DefaultImageEncoder(config)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.inference_mode():
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend(
sequential=nn.Sequential(
nn.Flatten(),
nn.Linear(
in_features=np.prod(out_shape) * self.camera_number, out_features=config.latent_dim
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
)
self.aggregation_size += config.latent_dim * self.camera_number
if config.freeze_vision_encoder:
freeze_image_encoder(self.image_enc_layers)
else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
nn.Linear(
@@ -523,8 +501,6 @@ class SACObservationEncoder(nn.Module):
)
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
if "observation.environment_state" in config.input_shapes:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
@@ -534,11 +510,9 @@ class SACObservationEncoder(nn.Module):
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
self.aggregation_size += config.latent_dim
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector.
@@ -547,20 +521,16 @@ class SACObservationEncoder(nn.Module):
over all features.
"""
feat = []
obs_dict = self.input_normalization(obs_dict)
# Concatenate all images along the channel dimension.
image_keys = [k for k in obs_dict if k.startswith("observation.image")]
image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
for image_key in image_keys:
enc_feat = self.image_enc_layers(obs_dict[image_key])
# if not self.has_pretrained_vision_encoder:
# enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
feat.append(enc_feat)
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]))
if "observation.environment_state" in self.config.input_shapes:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
# TODO(ke-wang): currently we average over all features; concatenating all features may be a better way
# return torch.stack(feat, dim=0).mean(0)
features = torch.cat(tensors=feat, dim=-1)
features = self.aggregation_layer(features)
@@ -572,109 +542,15 @@ class SACObservationEncoder(nn.Module):
return self.config.latent_dim
class DefaultImageEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
in_channels=config.input_shapes["observation.image"][0],
out_channels=config.image_encoder_hidden_dim,
kernel_size=7,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=5,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.inference_mode():
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend(
nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
)
def forward(self, x):
return self.image_enc_layers(x)
class PretrainedImageEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
self.image_enc_proj = nn.Sequential(
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
def _load_pretrained_vision_encoder(self, config):
"""Set up CNN encoder"""
from transformers import AutoModel
self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
# self.image_enc_layers.pooler = Identity()
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
elif hasattr(self.image_enc_layers, "fc"):
self.image_enc_out_shape = self.image_enc_layers.fc.in_features
else:
raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
return self.image_enc_layers, self.image_enc_out_shape
def forward(self, x):
# TODO: (maractingi, azouitine) check that the forward pass of the pretrained model
# doesn't reach the classifier layer, since we don't need it
enc_feat = self.image_enc_layers(x).pooler_output
enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
return enc_feat
def freeze_image_encoder(image_encoder: nn.Module):
"""Freeze all parameters in the encoder"""
for param in image_encoder.parameters():
param.requires_grad = False
def orthogonal_init():
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cpu") -> nn.ModuleList:
"""Creates an ensemble of critic networks"""
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
return nn.ModuleList(critics).to(device)
def forward(self, x):
return x
# TODO (azouitine): I think in our case this function is not useful; we should remove it
# after some investigation
# borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
@@ -682,7 +558,7 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
Args:
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
(B, *), where * is any number of dimensions.
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
can be more than 1 dimensions, generally different from *.
Returns:
A return value from the callable reshaped to (**, *).
@@ -692,67 +568,4 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
start_dims = image_tensor.shape[:-3]
inp = torch.flatten(image_tensor, end_dim=-4)
flat_out = fn(inp)
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
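An illustrative usage of flatten_forward_unflatten as defined above: the extra leading dimensions are folded into the batch before calling fn, then restored on the output.

import torch

images = torch.randn(5, 2, 3, 64, 64)           # (**, C, H, W) with ** = (5, 2)
conv = torch.nn.Conv2d(3, 8, kernel_size=3)      # expects (B, C, H, W)
out = flatten_forward_unflatten(conv, images)    # -> (5, 2, 8, 62, 62)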
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
converted_params = {}
for outer_key, inner_dict in normalization_params.items():
converted_params[outer_key] = {}
for key, value in inner_dict.items():
converted_params[outer_key][key] = torch.tensor(value)
if "image" in outer_key:
converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
return converted_params
if __name__ == "__main__":
# Test the SACObservationEncoder
import time
config = SACConfig()
config.num_critics = 10
encoder = SACObservationEncoder(config)
actor_encoder = SACObservationEncoder(config)
encoder = torch.compile(encoder)
critic_ensemble = CriticEnsemble(
encoder=encoder,
network_list=nn.ModuleList(
[
MLP(
input_dim=encoder.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
),
)
actor = Policy(
encoder=actor_encoder,
network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0],
encoder_is_shared=config.shared_encoder,
**config.policy_kwargs,
)
encoder = encoder.to("cuda:0")
critic_ensemble = torch.compile(critic_ensemble)
critic_ensemble = critic_ensemble.to("cuda:0")
actor = torch.compile(actor)
actor = actor.to("cuda:0")
obs_dict = {
"observation.image": torch.randn(1, 3, 84, 84),
"observation.state": torch.randn(1, 4),
}
actions = torch.randn(1, 2).to("cuda:0")
obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
print("compiling...")
# q_value = critic_ensemble(obs_dict, actions)
action = actor(obs_dict)
print("compiled")
start = time.perf_counter()
for _ in range(1000):
# features = encoder(obs_dict)
action = actor(obs_dict)
# q_value = critic_ensemble(obs_dict, actions)
print("Time taken:", time.perf_counter() - start)
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))

View File

@@ -36,7 +36,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
def log_dt(shortname, dt_val_s):
nonlocal log_items, fps
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
if fps is not None:
actual_fps = 1 / dt_val_s
if actual_fps < fps - 1:
@@ -225,7 +225,6 @@ def record_episode(
device,
use_amp,
fps,
record_delta_actions,
):
control_loop(
robot=robot,
@@ -237,7 +236,6 @@ def record_episode(
device=device,
use_amp=use_amp,
fps=fps,
record_delta_actions=record_delta_actions,
teleoperate=policy is None,
)
@@ -254,7 +252,6 @@ def control_loop(
device=None,
use_amp=None,
fps=None,
record_delta_actions=False,
):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
@@ -277,12 +274,8 @@ def control_loop(
while timestamp < control_time_s:
start_loop_t = time.perf_counter()
current_joint_positions = robot.follower_arms["main"].read("Present_Position")
if teleoperate:
observation, action = robot.teleop_step(record_data=True)
if record_delta_actions:
action["action"] = action["action"] - current_joint_positions
else:
observation = robot.capture_observation()
@@ -297,12 +290,8 @@ def control_loop(
frame = {**observation, **action}
if "next.reward" in events:
frame["next.reward"] = events["next.reward"]
frame["next.done"] = (events["next.reward"] == 1) or (events["exit_early"])
dataset.add_frame(frame)
# if frame["next.done"]:
# break
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
@@ -346,9 +335,7 @@ def reset_environment(robot, events, reset_time_s):
def reset_follower_position(robot: Robot, target_position):
current_position = robot.follower_arms["main"].read("Present_Position")
trajectory = torch.from_numpy(
np.linspace(current_position, target_position, 30)
) # NOTE: 30 is just an arbitrary number
trajectory = torch.from_numpy(np.linspace(current_position, target_position, 30)) # NOTE: 30 is just an arbitrary number
for pose in trajectory:
robot.send_action(pose)
busy_wait(0.015)

View File

@@ -32,7 +32,7 @@ def ensure_safe_goal_position(
safe_goal_pos = present_pos + safe_diff
if not torch.allclose(goal_pos, safe_goal_pos):
logging.debug(
logging.warning(
"Relative goal position magnitude had to be clamped to be safe.\n"
f" requested relative goal position target: {diff}\n"
f" clamped relative goal position target: {safe_diff}"
@@ -67,8 +67,6 @@ class ManipulatorRobotConfig:
# gripper is not put in torque mode.
gripper_open_degree: float | None = None
joint_position_relative_bounds: dict[np.ndarray] | None = None
def __setattr__(self, prop: str, val):
if prop == "max_relative_target" and val is not None and isinstance(val, Sequence):
for name in self.follower_arms:
@@ -80,9 +78,6 @@ class ManipulatorRobotConfig:
"Note: This feature does not yet work with robots where different follower arms have "
"different numbers of motors."
)
if prop == "joint_position_relative_bounds" and val is not None:
for key in val:
val[key] = torch.tensor(val[key])
super().__setattr__(prop, val)
def __post_init__(self):
@@ -528,14 +523,6 @@ class ManipulatorRobot:
before_fwrite_t = time.perf_counter()
goal_pos = leader_pos[name]
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
if self.config.joint_position_relative_bounds is not None:
goal_pos = torch.clamp(
goal_pos,
self.config.joint_position_relative_bounds["min"],
self.config.joint_position_relative_bounds["max"],
)
# Cap goal position when too far away from present position.
# Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
@@ -657,14 +644,6 @@ class ManipulatorRobot:
goal_pos = action[from_idx:to_idx]
from_idx = to_idx
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
if self.config.joint_position_relative_bounds is not None:
goal_pos = torch.clamp(
goal_pos,
self.config.joint_position_relative_bounds["min"],
self.config.joint_position_relative_bounds["max"],
)
# Cap goal position when too far away from present position.
# Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
@@ -677,7 +656,6 @@ class ManipulatorRobot:
# Send goal position to each follower
goal_pos = goal_pos.numpy().astype(np.int32)
self.follower_arms[name].write("Goal_Position", goal_pos)
return torch.cat(action_sent)

View File

@@ -18,7 +18,6 @@ import os
import os.path as osp
import platform
import random
import time
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
@@ -218,28 +217,3 @@ def log_say(text, play_sounds, blocking=False):
if play_sounds:
say(text, blocking)
class TimerManager:
def __init__(self, elapsed_time_list: list[float] | None = None, label="Elapsed time", log=True):
self.label = label
self.elapsed_time_list = elapsed_time_list
self.log = log
self.elapsed = 0.0
def __enter__(self):
self.start = time.perf_counter()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.elapsed: float = time.perf_counter() - self.start
if self.elapsed_time_list is not None:
self.elapsed_time_list.append(self.elapsed)
if self.log:
print(f"{self.label}: {self.elapsed:.6f} seconds")
@property
def elapsed_seconds(self):
return self.elapsed
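A short usage sketch for the TimerManager removed above (assuming the class from this hunk is in scope); each measured duration is appended to the provided list:

import time

durations: list[float] = []
with TimerManager(elapsed_time_list=durations, label="policy inference"):
    time.sleep(0.01)   # stand-in for the timed workload
print(f"collected {len(durations)} timing(s), last = {durations[-1]:.4f} s")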

View File

@@ -2,7 +2,6 @@ defaults:
- _self_
- env: pusht
- policy: diffusion
- robot: so100
hydra:
run:

View File

@@ -1,20 +0,0 @@
# @package _global_
fps: 20
env:
name: maniskill/pushcube
task: PushCube-v1
image_size: 128
control_mode: pd_ee_delta_pose
state_dim: 25
action_dim: 7
fps: ${fps}
obs: rgb
render_mode: rgb_array
render_size: 128
device: cuda
reward_classifier:
pretrained_path: null
config_path: null

View File

@@ -1,6 +1,6 @@
# @package _global_
fps: 10
fps: 30
env:
name: real_world
@@ -8,24 +8,3 @@ env:
state_dim: 6
action_dim: 6
fps: ${fps}
device: mps
wrapper:
crop_params_dict:
observation.images.front: [102, 43, 358, 523]
observation.images.side: [92, 123, 379, 349]
# observation.images.front: [109, 37, 361, 557]
# observation.images.side: [94, 161, 372, 315]
resize_size: [128, 128]
control_time_s: 20
reset_follower_pos: true
use_relative_joint_positions: true
reset_time_s: 5
display_cameras: false
delta_action: 0.1
joint_masking_action_space: [1, 1, 1, 1, 0, 0] # disable wrist and gripper
reward_classifier:
pretrained_path: outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
config_path: lerobot/configs/policy/hilserl_classifier.yaml

View File

@@ -4,10 +4,7 @@ defaults:
- _self_
seed: 13
dataset_repo_id: aractingi/push_cube_square_light_reward_cropped_resized
# aractingi/push_cube_square_reward_1_cropped_resized
dataset_root: data/aractingi/push_cube_square_light_reward_cropped_resized
local_files_only: true
dataset_repo_id: aractingi/pick_place_lego_cube_1
train_split_proportion: 0.8
# Required by logger
@@ -17,7 +14,7 @@ env:
training:
num_epochs: 6
num_epochs: 5
batch_size: 16
learning_rate: 1e-4
num_workers: 4
@@ -27,18 +24,16 @@ training:
eval_freq: 1 # How often to run validation (in epochs)
save_freq: 1 # How often to save checkpoints (in epochs)
save_checkpoint: true
image_keys: ["observation.images.front", "observation.images.side"]
image_keys: ["observation.images.top", "observation.images.wrist"]
label_key: "next.reward"
profile_inference_time: false
profile_inference_time_iters: 20
eval:
batch_size: 16
num_samples_to_log: 30 # Number of validation samples to log in the table
policy:
name: "hilserl/classifier"
model_name: "helper2424/resnet10" # "facebook/convnext-base-224
name: "hilserl/classifier/pick_place_lego_cube_1"
model_name: "facebook/convnext-base-224"
model_type: "cnn"
num_cameras: 2 # Has to be len(training.image_keys)
@@ -50,4 +45,4 @@ wandb:
device: "mps"
resume: false
output_dir: "outputs/classifier/old_trainer_resnet10_frozen"
output_dir: "outputs/classifier"

View File

@@ -8,7 +8,8 @@
# env.gym.obs_type=environment_state_agent_pos \
seed: 1
dataset_repo_id: null
dataset_repo_id: null
training:
# Offline training dataloader
@@ -20,8 +21,8 @@ training:
lr: 3e-4
eval_freq: 2500
log_freq: 10
save_freq: 2000000
log_freq: 500
save_freq: 50000
online_steps: 1000000
online_rollout_n_episodes: 10
@@ -52,16 +53,12 @@ policy:
n_action_steps: 1
shared_encoder: true
# vision_encoder_name: null
vision_encoder_name: "helper2424/resnet10"
freeze_vision_encoder: true
# freeze_vision_encoder: false
input_shapes:
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.state: ["${env.state_dim}"]
observation.image: [3, 128, 128]
observation.image: [3, 64, 64]
output_shapes:
action: [7]
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes: null
@@ -69,8 +66,8 @@ policy:
action: min_max
output_normalization_params:
action:
min: [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0]
max: [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
# Architecture / modeling.
# Neural networks.
@@ -78,15 +75,23 @@ policy:
# discount: 0.99
discount: 0.80
temperature_init: 1.0
num_critics: 2 #10
num_critics: 2
num_subsample_critics: null
critic_lr: 3e-4
actor_lr: 3e-4
temperature_lr: 3e-4
# critic_target_update_weight: 0.005
critic_target_update_weight: 0.01
utd_ratio: 2 # 10
utd_ratio: 1
actor_learner_config:
actor_ip: "127.0.0.1"
port: 50051
# # Loss coefficients.
# reward_coeff: 0.5
# expectile_weight: 0.9
# value_coeff: 0.1
# consistency_coeff: 20.0
# advantage_scaling: 3.0
# pi_coeff: 0.5
# temporal_decay_coeff: 0.5
# # Target model.
# target_model_momentum: 0.995

View File

@@ -1,127 +0,0 @@
# @package _global_
# Train with:
#
# python lerobot/scripts/train.py \
# +dataset=lerobot/pusht_keypoints
# env=pusht \
# env.gym.obs_type=environment_state_agent_pos \
seed: 1
dataset_repo_id: aractingi/push_cube_overfit_cropped_resized
#aractingi/push_cube_square_offline_demo_cropped_resized
training:
# Offline training dataloader
num_workers: 4
# batch_size: 256
batch_size: 512
grad_clip_norm: 10.0
lr: 3e-4
eval_freq: 2500
log_freq: 1
save_freq: 2000000
online_steps: 1000000
online_rollout_n_episodes: 10
online_rollout_batch_size: 10
online_steps_between_rollouts: 1000
online_sampling_ratio: 1.0
online_env_seed: 10000
online_buffer_capacity: 1000000
online_buffer_seed_size: 0
online_step_before_learning: 100 #5000
do_online_rollout_async: false
policy_update_freq: 1
# delta_timestamps:
# observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
# observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
# action: "[i / ${fps} for i in range(${policy.horizon})]"
# next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
policy:
name: sac
pretrained_model_path:
# Input / output structure.
n_action_repeats: 1
horizon: 1
n_action_steps: 1
shared_encoder: true
vision_encoder_name: "helper2424/resnet10"
freeze_vision_encoder: true
input_shapes:
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.state: ["${env.state_dim}"]
observation.images.front: [3, 128, 128]
observation.images.side: [3, 128, 128]
# observation.image: [3, 128, 128]
output_shapes:
action: [4] # ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.front: mean_std
observation.images.side: mean_std
observation.state: min_max
input_normalization_params:
observation.images.front:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
observation.images.side:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
observation.state:
min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
max: [ 7.215820e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]
# min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375,0.53691274]
# max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266, 0.60402685]
# min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
# max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]
output_normalization_modes:
action: min_max
output_normalization_params:
# action:
# min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
# max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
action:
min: [-149.23828125, -97.734375, -100.1953125, -73.740234375]
max: [149.23828125, 97.734375, 100.1953125, 73.740234375]
# Architecture / modeling.
# Neural networks.
image_encoder_hidden_dim: 32
# discount: 0.99
discount: 0.97
temperature_init: 1.0
num_critics: 2 #10
camera_number: 2
num_subsample_critics: null
critic_lr: 3e-4
actor_lr: 3e-4
temperature_lr: 3e-4
# critic_target_update_weight: 0.005
critic_target_update_weight: 0.01
utd_ratio: 2 # 10
actor_learner_config:
actor_ip: "127.0.0.1"
port: 50051
# # Loss coefficients.
# reward_coeff: 0.5
# expectile_weight: 0.9
# value_coeff: 0.1
# consistency_coeff: 20.0
# advantage_scaling: 3.0
# pi_coeff: 0.5
# temporal_decay_coeff: 0.5
# # Target model.
# target_model_momentum: 0.995

View File

@@ -14,9 +14,6 @@ calibration_dir: .cache/calibration/so100
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: null
joint_position_relative_bounds:
max: [ 7.2158203e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]
min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
leader_arms:
main:
@@ -34,7 +31,7 @@ leader_arms:
follower_arms:
main:
_target_: lerobot.common.robot_devices.motors.feetech.FeetechMotorsBus
port: /dev/tty.usbmodem58760431631
port: /dev/tty.usbmodem585A0080971
motors:
# name: (index, model)
shoulder_pan: [1, "sts3215"]
@@ -45,13 +42,13 @@ follower_arms:
gripper: [6, "sts3215"]
cameras:
front:
laptop:
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
camera_index: 0
fps: 30
width: 640
height: 480
side:
phone:
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
camera_index: 1
fps: 30

View File

@@ -206,8 +206,7 @@ def record(
num_image_writer_threads_per_camera: int = 4,
display_cameras: bool = True,
play_sounds: bool = True,
reset_follower: bool = False,
record_delta_actions: bool = False,
reset_follower: bool = False,
resume: bool = False,
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
local_files_only: bool = False,
@@ -219,12 +218,7 @@ def record(
device = None
use_amp = None
extra_features = (
{
"next.reward": {"dtype": "int64", "shape": (1,), "names": None},
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
}
if assign_rewards
else None
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
)
if single_task:
@@ -275,7 +269,7 @@ def record(
if reset_follower:
initial_position = robot.follower_arms["main"].read("Present_Position")
# Execute a few seconds without recording to:
# 1. teleoperate the robot to move it into its starting position if no policy is provided,
# 2. give time to the robot devices to connect and start synchronizing,
@@ -308,7 +302,6 @@ def record(
device=device,
use_amp=use_amp,
fps=fps,
record_delta_actions=record_delta_actions,
)
# Execute a few seconds without recording to give time to manually reset the environment
@@ -360,24 +353,21 @@ def replay(
fps: int | None = None,
play_sounds: bool = True,
local_files_only: bool = False,
replay_delta_actions: bool = False,
):
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
# TODO(rcadene): Add option to record logs
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
actions = dataset.hf_dataset.select_columns("action")
if not robot.is_connected:
robot.connect()
log_say("Replaying episode", play_sounds, blocking=True)
for idx in range(dataset.num_frames):
current_joint_positions = robot.follower_arms["main"].read("Present_Position")
start_episode_t = time.perf_counter()
action = actions[idx]["action"]
if replay_delta_actions:
action = action + current_joint_positions
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t
@@ -544,12 +534,6 @@ if __name__ == "__main__":
default=0,
help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.",
)
parser_record.add_argument(
"--record-delta-actions",
type=int,
default=0,
help="Enables the recording of delta actions instead of absolute actions.",
)
parser_record.add_argument(
"--reset-follower",
type=int,
@@ -579,12 +563,6 @@ if __name__ == "__main__":
default=0,
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
)
parser_replay.add_argument(
"--replay-delta-actions",
type=int,
default=0,
help="Enables the replay of delta actions instead of absolute actions.",
)
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
args = parser.parse_args()

View File

@@ -15,25 +15,34 @@
# limitations under the License.
"""Evaluate a policy by running rollouts on the real robot and computing metrics.
Usage examples: evaluate a checkpoint from the LeRobot training script for 10 episodes.
This script supports performing human interventions during rollouts.
Human interventions allow the user to take control of the robot from the policy
and correct its behavior. It is specifically designed for reinforcement learning
experiments and HIL-SERL (human-in-the-loop reinforcement learning) methods.
```
python lerobot/scripts/eval_on_robot.py \
-p outputs/train/model/checkpoints/005000/pretrained_model \
eval.n_episodes=10
```
### How to Use
Test reward classifier with teleoperation (you need to press space to take over)
To roll out a policy on the robot:
```
python lerobot/scripts/eval_on_robot.py \
--robot-path lerobot/configs/robot/so100.yaml \
--pretrained-policy-path-or-name path/to/pretrained_model \
--policy-config path/to/policy/config.yaml \
--display-cameras 1
```
If you trained a reward classifier on your task, you can also evaluate it using this script.
You can annotate the collection with a pre-trained reward classifier by running:
```
python lerobot/scripts/eval_on_robot.py \
--robot-path lerobot/configs/robot/so100.yaml \
--pretrained-policy-path-or-name path/to/pretrained_model \
--policy-config path/to/policy/config.yaml \
--reward-classifier-pretrained-path outputs/classifier/checkpoints/best/pretrained_model \
--reward-classifier-config-file lerobot/configs/policy/hilserl_classifier.yaml \
--display-cameras 1
```
**NOTE** (michel-aractingi): This script is incomplete and it is being prepared
for running training on the real robot.
"""
import argparse
@@ -46,7 +55,8 @@ import torch
from tqdm import trange
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position, predict_action
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
from lerobot.common.utils.utils import (
init_hydra_config,
@@ -69,7 +79,6 @@ def get_classifier(pretrained_path, config_path):
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to("mps")
return model
@@ -81,48 +90,45 @@ def rollout(
control_time_s: float = 20,
use_amp: bool = True,
display_cameras: bool = False,
device: str = "cpu"
) -> dict:
"""Run a batched policy rollout on the real robot.
The return dictionary contains:
"robot": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation
keys. NOTE the that this has an extra sequence element relative to the other keys in the
dictionary. This is because an extra observation is included for after the environment is
terminated or truncated.
"action": A (batch, sequence, action_dim) tensor of actions applied based on the observations (not
including the last observations).
"reward": A (batch, sequence) tensor of rewards received for applying the actions.
"success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
environment termination/truncation).
"done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
the first True is followed by True's all the way till the end. This can be used for masking
extraneous elements from the sequences above.
This function executes a rollout using the provided policy and robot interface,
simulating batched interactions for a fixed control duration.
The returned dictionary contains rollout statistics, which can be used for analysis and debugging.
Args:
robot: The robot class that defines the interface with the real robot.
policy: The policy. Must be a PyTorch nn module.
"robot": The robot interface for interacting with the real robot hardware.
"policy": The policy to execute. Must be a PyTorch `nn.Module` object.
"reward_classifier": A module to classify rewards during the rollout.
"fps": The control frequency at which the policy is executed.
"control_time_s": The total control duration of the rollout in seconds.
"use_amp": Whether to use automatic mixed precision (AMP) for policy evaluation.
"display_cameras": If True, displays camera streams during the rollout.
"device": The device to use for computations (e.g., "cpu", "cuda" or "mps").
Returns:
The dictionary described above.
Dictionary of the statistics collected during rollouts.
"""
# TODO (michel-aractingi): Infer the device from policy parameters when policy is added
# assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
# device = get_device_from_parameters(policy)
# define keyboard listener
listener, events = init_keyboard_listener()
# Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
# policy.reset()
if policy is not None:
policy.reset()
# NOTE: sorting to make sure the key sequence is the same during training and testing.
observation = robot.capture_observation()
image_keys = [key for key in observation if "image" in key]
image_keys.sort()
image_keys.sort()
all_actions = []
all_rewards = []
all_successes = []
indices_from_policy = []
start_episode_t = time.perf_counter()
init_pos = robot.follower_arms["main"].read("Present_Position")
@@ -141,27 +147,32 @@ def rollout(
else:
# explore with policy
with torch.inference_mode():
# TODO (michel-aractingi) replace this part with policy (predict_action)
action = robot.follower_arms["main"].read("Present_Position")
action = torch.from_numpy(action)
# TODO (michel-aractingi) in place temporarily for testing purposes
if policy is None:
action = robot.follower_arms["main"].read("Present_Position")
action = torch.from_numpy(action)
indices_from_policy.append(False)
else:
action = predict_action(observation, policy, device, use_amp)
indices_from_policy.append(True)
robot.send_action(action)
# action = predict_action(observation, policy, device, use_amp)
observation = robot.capture_observation()
observation = robot.capture_observation()
images = []
for key in image_keys:
if display_cameras:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
images.append(observation[key].to("mps"))
images.append(observation[key].to(device))
reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
# TODO send data through the server as soon as you have it
all_rewards.append(reward)
# print("REWARD : ", reward)
all_actions.append(action)
all_successes.append(torch.tensor([False]))
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
@@ -180,7 +191,6 @@ def rollout(
ret = {
"action": torch.stack(all_actions, dim=1),
"next.reward": torch.stack(all_rewards, dim=1),
"next.success": torch.stack(all_successes, dim=1),
"done": dones,
}
@@ -199,14 +209,32 @@ def eval_policy(
display_cameras: bool = False,
reward_classifier_pretrained_path: str | None = None,
reward_classifier_config_file: str | None = None,
device: str | None = None,
) -> dict:
"""
Evaluate a policy on a real robot by running multiple episodes and collecting metrics.
This function executes rollouts of the specified policy on the robot, computes metrics
for the rollouts, and optionally evaluates a reward classifier if provided.
Args:
env: The batch of environments.
policy: The policy.
n_episodes: The number of episodes to evaluate.
"robot": The robot interface used to interact with the real robot hardware.
"policy": The policy to be evaluated. Must be a PyTorch neural network module.
"fps": Frames per second (control frequency) for running the policy.
"n_episodes": The number of episodes to evaluate the policy.
"control_time_s": The max duration for each episode in seconds.
"use_amp": Whether to use automatic mixed precision (AMP) for policy evaluation.
"display_cameras": Whether to display camera streams during rollouts.
"reward_classifier_pretrained_path": Path to the pretrained reward classifier.
If provided, the reward classifier will be evaluated during rollouts.
"reward_classifier_config_file": Path to the configuration file for the reward classifier.
Required if `reward_classifier_pretrained_path` is provided.
"device": The device for computations (e.g., "cpu", "cuda" or "mps").
Returns:
Dictionary with metrics and data regarding the rollouts.
"dict": A dictionary containing the following rollout metrics and data:
- "metrics": Evaluation metrics such as cumulative rewards, success rates, etc.
- "rollout_data": Detailed data from the rollouts, including observations, actions, rewards, and done flags.
"""
# TODO (michel-aractingi) comment this out for testing with a fixed policy
# assert isinstance(policy, Policy)
@@ -214,22 +242,22 @@ def eval_policy(
sum_rewards = []
max_rewards = []
successes = []
rollouts = []
start_eval = time.perf_counter()
progbar = trange(n_episodes, desc="Evaluating policy on real robot")
reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file)
reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file).to(device)
device = get_device_from_parameters(policy) if device is None else device
for _ in progbar:
rollout_data = rollout(
robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras
robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras, device
)
rollouts.append(rollout_data)
sum_rewards.append(sum(rollout_data["next.reward"]))
max_rewards.append(max(rollout_data["next.reward"]))
successes.append(rollout_data["next.success"][-1])
info = {
"per_episode": [
@@ -237,21 +265,18 @@ def eval_policy(
"episode_ix": i,
"sum_reward": sum_reward,
"max_reward": max_reward,
"pc_success": success * 100,
}
for i, (sum_reward, max_reward, success) in enumerate(
for i, (sum_reward, max_reward) in enumerate(
zip(
sum_rewards[:n_episodes],
max_rewards[:n_episodes],
successes[:n_episodes],
strict=False,
)
)
],
"aggregated": {
"avg_sum_reward": float(np.nanmean(torch.cat(sum_rewards[:n_episodes]))),
"avg_sum_reward": float(np.nanmean(torch.cat(sum_rewards[:n_episodes]))),
"avg_max_reward": float(np.nanmean(torch.cat(max_rewards[:n_episodes]))),
"pc_success": float(np.nanmean(torch.cat(successes[:n_episodes])) * 100),
"eval_s": time.time() - start_eval,
"eval_ep_s": (time.time() - start_eval) / n_episodes,
},
@@ -264,9 +289,18 @@ def eval_policy(
def init_keyboard_listener():
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
"""
Initialize a keyboard listener for controlling the recording and human intervention process.
Keyboard controls: (Note that this might require sudo permissions to monitor keyboard events)
- Right Arrow Key ('->'): Stops the current recording and exits early, useful for ending an episode
and moving on to recording the next episode.
- Left Arrow Key ('<-'): Re-records the current episode, allowing the user to start over.
- Space Bar: Controls the human intervention process in three steps:
1. First press pauses the policy and prompts the user to position the leader to match the follower.
2. Second press initiates human interventions, allowing teleop control of the robot.
3. Third press resumes the policy rollout.
"""
events = {}
events["exit_early"] = False
events["rerecord_episode"] = False
@@ -302,10 +336,15 @@ def init_keyboard_listener():
)
events["pause_policy"] = True
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
else:
elif events["pause_policy"] and not events["human_intervention_step"]:
events["human_intervention_step"] = True
print("Space key pressed. Human intervention starting.")
log_say("Starting human intervention.", play_sounds=True)
elif events["human_intervention_step"]:
events["human_intervention_step"] = False
events["pause_policy"] = False
print("Space key pressed. Human intervention ending, policy resumes control.")
log_say("Policy resuming.", play_sounds=True)
except Exception as e:
print(f"Error handling key press: {e}")

View File

@@ -1,375 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import pickle
import queue
import time
from concurrent import futures
from statistics import mean, quantiles
# from lerobot.scripts.eval import eval_policy
from threading import Thread
import grpc
import hydra
import torch
from omegaconf import DictConfig
from torch import nn
# TODO: Remove the import of maniskill
# from lerobot.common.envs.factory import make_maniskill_env
# from lerobot.common.envs.utils import preprocess_maniskill_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.control_utils import busy_wait
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import (
TimerManager,
get_safe_torch_device,
set_global_seed,
)
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
from lerobot.scripts.server.buffer import Transition, move_state_dict_to_device, move_transition_to_device
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
logging.basicConfig(level=logging.INFO)
parameters_queue = queue.Queue(maxsize=1)
message_queue = queue.Queue(maxsize=1_000_000)
class ActorInformation:
"""
This helper class is used to differentiate between two types of messages that are placed in the same queue during streaming:
- **Transition Data:** Contains experience tuples (observation, action, reward, next observation) collected during interaction.
- **Interaction Messages:** Encapsulates statistics related to the interaction process.
Attributes:
transition (Optional): Transition data to be sent to the learner.
interaction_message (Optional): Interaction message providing additional statistics for logging.
"""
def __init__(self, transition=None, interaction_message=None):
self.transition = transition
self.interaction_message = interaction_message
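For illustration, the actor wraps either payload in this class before putting it on the shared queue (the transition list here is hypothetical):
```
# Stream a chunk of collected experience to the learner
message_queue.put(ActorInformation(transition=list_of_transitions))
# Or stream logging statistics about the interaction
message_queue.put(ActorInformation(interaction_message={"Episodic reward": 3.0}))
```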
class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
"""
gRPC service for actor-learner communication in reinforcement learning.
This service is responsible for:
1. Streaming batches of transition data and statistical metrics from the actor to the learner.
2. Receiving updated network parameters from the learner.
"""
def StreamTransition(self, request, context): # noqa: N802
"""
Streams data from the actor to the learner.
This function continuously retrieves messages from the queue and processes them based on their type:
- **Transition Data:**
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
- **Interaction Messages:**
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
Yields:
hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
"""
while True:
message = message_queue.get(block=True)
if message.transition is not None:
transition_to_send_to_learner: list[Transition] = [
move_transition_to_device(transition=T, device="cpu") for T in message.transition
]
# Check for NaNs in transitions before sending to learner
for transition in transition_to_send_to_learner:
for key, value in transition["state"].items():
if torch.isnan(value).any():
logging.warning(f"Found NaN values in transition {key}")
buf = io.BytesIO()
torch.save(transition_to_send_to_learner, buf)
transition_bytes = buf.getvalue()
transition_message = hilserl_pb2.Transition(transition_bytes=transition_bytes)
response = hilserl_pb2.ActorInformation(transition=transition_message)
elif message.interaction_message is not None:
content = hilserl_pb2.InteractionMessage(
interaction_message_bytes=pickle.dumps(message.interaction_message)
)
response = hilserl_pb2.ActorInformation(interaction_message=content)
yield response
def SendParameters(self, request, context): # noqa: N802
"""
Receives updated parameters from the learner and updates the actor.
The learner calls this method to send new model parameters. The received parameters are deserialized
and placed in a queue to be consumed by the actor.
Args:
request (hilserl_pb2.ParameterUpdate): The request containing serialized network parameters.
context (grpc.ServicerContext): The gRPC context.
Returns:
hilserl_pb2.Empty: An empty response to acknowledge receipt.
"""
buffer = io.BytesIO(request.parameter_bytes)
params = torch.load(buffer)
parameters_queue.put(params)
return hilserl_pb2.Empty()
def serve_actor_service(port=50052):
"""
Runs a gRPC server to start streaming the data from the actor to the learner.
Through this server the learner can also push parameters to the Actor.
"""
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=20),
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
)
hilserl_pb2_grpc.add_ActorServiceServicer_to_server(ActorServiceServicer(), server)
server.add_insecure_port(f"[::]:{port}")
server.start()
logging.info(f"[ACTOR] gRPC server listening on port {port}")
server.wait_for_termination()
def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device):
if not parameters_queue.empty():
logging.info("[ACTOR] Load new parameters from Learner.")
state_dict = parameters_queue.get()
state_dict = move_state_dict_to_device(state_dict, device=device)
policy.load_state_dict(state_dict, strict=False)
def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module):
"""
Executes policy interaction within the environment.
This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.
Args:
cfg (DictConfig): Configuration settings for the interaction process.
"""
logging.info("make_env online")
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg.env)
set_global_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("make_policy")
# HACK: This is an ugly hack to pass the normalization parameters to the policy
# Because the action space is dynamic so we override the output normalization parameters
# it's ugly, we know ... and we will fix it
min_action_space: list = online_env.action_space.spaces[0].low.tolist()
max_action_space: list = online_env.action_space.spaces[0].high.tolist()
output_normalization_params: dict[dict[str, list]] = {
"action": {"min": min_action_space, "max": max_action_space}
}
cfg.policy.output_normalization_params = output_normalization_params
cfg.policy.output_shapes["action"] = online_env.action_space.spaces[0].shape
### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy instance
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
# TODO: At some point we should just need to make the SAC policy
policy: SACPolicy = make_policy(
hydra_cfg=cfg,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online training, we do not need dataset_stats
dataset_stats=None,
# TODO: Handle resume training
device=device,
)
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
obs, info = online_env.reset()
# NOTE: For the moment we will solely handle the case of a single environment
sum_reward_episode = 0
list_transition_to_send_to_learner = []
list_policy_time = []
episode_intervention = False
for interaction_step in range(cfg.training.online_steps):
if interaction_step >= cfg.training.online_step_before_learning:
# Time policy inference and check if it meets FPS requirement
with TimerManager(
elapsed_time_list=list_policy_time, label="Policy inference time", log=False
) as timer: # noqa: F841
action = policy.select_action(batch=obs)
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
else:
# TODO (azouitine): Make a custom space for torch tensor
action = online_env.action_space.sample()
next_obs, reward, done, truncated, info = online_env.step(action)
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
action = (
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
)
sum_reward_episode += float(reward)
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
if "is_intervention" in info and info["is_intervention"]:
# TODO: Check the shape
# NOTE: Demonstrations are recorded beforehand with the full action space,
# but sometimes, for example, we want to deactivate the gripper
action = info["action_intervention"]
episode_intervention = True
# Check for NaN values in observations
for key, tensor in obs.items():
if torch.isnan(tensor).any():
logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
list_transition_to_send_to_learner.append(
Transition(
state=obs,
action=action,
reward=reward,
next_state=next_obs,
done=done,
complementary_info=info, # TODO Handle information for the transition, is_demonstration: bool
)
)
# assign obs to the next obs and continue the rollout
obs = next_obs
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
# Because we are using a single environment we can index at zero
if done or truncated:
# TODO: Handle logging for episode information
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
if len(list_transition_to_send_to_learner) > 0:
send_transitions_in_chunks(
transitions=list_transition_to_send_to_learner, message_queue=message_queue, chunk_size=4
)
list_transition_to_send_to_learner = []
stats = get_frequency_stats(list_policy_time)
list_policy_time.clear()
# Send episodic reward to the learner
message_queue.put(
ActorInformation(
interaction_message={
"Episodic reward": sum_reward_episode,
"Interaction step": interaction_step,
"Episode intervention": int(episode_intervention),
**stats,
}
)
)
sum_reward_episode = 0.0
episode_intervention = False
obs, info = online_env.reset()
def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int = 100):
"""Send transitions to learner in smaller chunks to avoid network issues.
Args:
transitions: List of transitions to send
message_queue: Queue to send messages to learner
chunk_size: Size of each chunk to send
"""
for i in range(0, len(transitions), chunk_size):
chunk = transitions[i : i + chunk_size]
logging.debug(f"[ACTOR] Sending chunk of {len(chunk)} transitions to Learner.")
message_queue.put(ActorInformation(transition=chunk))
def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
stats = {}
list_policy_fps = [1.0 / t for t in list_policy_time]
if len(list_policy_fps) > 1:
policy_fps = mean(list_policy_fps)
quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90}
return stats
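As a quick illustration of the helper above, with made-up inference times (in seconds):
```
list_policy_time = [0.033, 0.031, 0.035, 0.030]
stats = get_frequency_stats(list_policy_time)
# e.g. {"Policy frequency [Hz]": ..., "Policy frequency 90th-p [Hz]": ...}
```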
def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
if policy_fps < cfg.fps:
logging.warning(
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
)
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def actor_cli(cfg: dict):
robot = make_robot(cfg=cfg.robot)
server_thread = Thread(target=serve_actor_service, args=(cfg.actor_learner_config.port,), daemon=True)
# HACK: FOR MANISKILL we do not have a reward classifier
# TODO: Remove this once we merge into main
reward_classifier = None
if (
cfg.env.reward_classifier.pretrained_path is not None
and cfg.env.reward_classifier.config_path is not None
):
reward_classifier = get_classifier(
pretrained_path=cfg.env.reward_classifier.pretrained_path,
config_path=cfg.env.reward_classifier.config_path,
)
policy_thread = Thread(
target=act_with_policy,
daemon=True,
args=(cfg, robot, reward_classifier),
)
server_thread.start()
policy_thread.start()
policy_thread.join()
server_thread.join()
if __name__ == "__main__":
actor_cli()

View File

@@ -1,593 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import random
from typing import Any, Callable, Optional, Sequence, TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
class Transition(TypedDict):
state: dict[str, torch.Tensor]
action: torch.Tensor
reward: float
next_state: dict[str, torch.Tensor]
done: bool
complementary_info: dict[str, Any] = None
class BatchTransition(TypedDict):
state: dict[str, torch.Tensor]
action: torch.Tensor
reward: torch.Tensor
next_state: dict[str, torch.Tensor]
done: torch.Tensor
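For illustration, a single transition with a one-dimensional state could be built as follows (the key, shapes and values are hypothetical):
```
dummy_transition = Transition(
    state={"observation.state": torch.zeros(1, 6)},
    action=torch.zeros(1, 6),
    reward=0.0,
    next_state={"observation.state": torch.zeros(1, 6)},
    done=False,
    complementary_info=None,
)
```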
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
# Move state tensors to the target device
device = torch.device(device)
transition["state"] = {
key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
}
# Move the action tensor to the target device
transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")
# Reward and done only need to be moved if they are tensors (plain floats/bools require no transfer)
if isinstance(transition["reward"], torch.Tensor):
transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")
if isinstance(transition["done"], torch.Tensor):
transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")
# Move next_state tensors to the target device
transition["next_state"] = {
key: val.to(device, non_blocking=device.type == "cuda")
for key, val in transition["next_state"].items()
}
# If complementary_info is present, move its tensors to the target device
# if transition["complementary_info"] is not None:
# transition["complementary_info"] = {
# key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
# }
return transition
def move_state_dict_to_device(state_dict, device):
"""
Recursively move all tensors in a (potentially) nested
dict/list/tuple structure to the given device.
"""
if isinstance(state_dict, torch.Tensor):
return state_dict.to(device)
elif isinstance(state_dict, dict):
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
elif isinstance(state_dict, list):
return [move_state_dict_to_device(v, device=device) for v in state_dict]
elif isinstance(state_dict, tuple):
return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
else:
return state_dict
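A small sketch of this helper, e.g. when preparing parameters to send between processes (values are dummies):
```
state_dict = {"actor": {"weight": torch.randn(4, 4)}, "step": 10}
cpu_state_dict = move_state_dict_to_device(state_dict, device="cpu")
# Tensors are moved; non-tensor leaves such as the integer step counter pass through unchanged.
```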
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
"""
Perform a per-image random crop over a batch of images in a vectorized way.
(Same as shown previously.)
"""
B, C, H, W = images.shape # noqa: N806
crop_h, crop_w = output_size
if crop_h > H or crop_w > W:
raise ValueError(
f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
)
tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)
rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)
rows = rows.unsqueeze(2).expand(-1, -1, crop_w) # (B, crop_h, crop_w)
cols = cols.unsqueeze(1).expand(-1, crop_h, -1) # (B, crop_h, crop_w)
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
# Gather pixels
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
# cropped_hwcn => (B, crop_h, crop_w, C)
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
return cropped
def random_shift(images: torch.Tensor, pad: int = 4):
"""Vectorized random shift, imgs: (B,C,H,W), pad: #pixels"""
_, _, h, w = images.shape
images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
return random_crop_vectorized(images=images, output_size=(h, w))
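As an example of the DrQ-style augmentation above (batch shape is illustrative):
```
images = torch.rand(8, 3, 128, 128)  # (B, C, H, W) batch in [0, 1]
shifted = random_shift(images, pad=4)
assert shifted.shape == images.shape  # replicate padding + random crop preserves the size
```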
class ReplayBuffer:
def __init__(
self,
capacity: int,
device: str = "cuda:0",
state_keys: Optional[Sequence[str]] = None,
image_augmentation_function: Optional[Callable] = None,
use_drq: bool = True,
storage_device: str = "cpu",
):
"""
Args:
capacity (int): Maximum number of transitions to store in the buffer.
device (str): The device where the tensors will be moved ("cuda:0" or "cpu").
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
and returns a batch of augmented images. If None, a default augmentation function is used.
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored when adding transitions to the buffer.
Using "cpu" can help save GPU memory.
"""
self.capacity = capacity
self.device = device
self.storage_device = storage_device
self.memory: list[Transition] = []
self.position = 0
# If no state_keys provided, default to an empty list
# (you can handle this differently if needed)
self.state_keys = state_keys if state_keys is not None else []
if image_augmentation_function is None:
self.image_augmentation_function = functools.partial(random_shift, pad=4)
self.use_drq = use_drq
def __len__(self):
return len(self.memory)
def add(
self,
state: dict[str, torch.Tensor],
action: torch.Tensor,
reward: float,
next_state: dict[str, torch.Tensor],
done: bool,
complementary_info: Optional[dict[str, torch.Tensor]] = None,
):
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
# Move tensors to the storage device
state = {key: tensor.to(self.storage_device) for key, tensor in state.items()}
next_state = {key: tensor.to(self.storage_device) for key, tensor in next_state.items()}
action = action.to(self.storage_device)
# if complementary_info is not None:
# complementary_info = {
# key: tensor.to(self.storage_device) for key, tensor in complementary_info.items()
# }
if len(self.memory) < self.capacity:
self.memory.append(None)
# Create and store the Transition
self.memory[self.position] = Transition(
state=state,
action=action,
reward=reward,
next_state=next_state,
done=done,
complementary_info=complementary_info,
)
self.position = (self.position + 1) % self.capacity
# TODO: ADD image_augmentation and use_drq arguments in this function in order to instantiate the class with them
@classmethod
def from_lerobot_dataset(
cls,
lerobot_dataset: LeRobotDataset,
device: str = "cuda:0",
state_keys: Optional[Sequence[str]] = None,
capacity: Optional[int] = None,
action_mask: Optional[Sequence[int]] = None,
action_delta: Optional[float] = None,
) -> "ReplayBuffer":
"""
Convert a LeRobotDataset into a ReplayBuffer.
Args:
lerobot_dataset (LeRobotDataset): The dataset to convert.
device (str): The device to move the transitions to. Defaults to "cuda:0".
state_keys (Optional[Sequence[str]], optional): The list of keys that appear in `state` and `next_state`.
Defaults to None.
Returns:
ReplayBuffer: The replay buffer with offline dataset transitions.
"""
# We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
# a replay buffer than from a lerobot dataset.
if capacity is None:
capacity = len(lerobot_dataset)
if capacity < len(lerobot_dataset):
raise ValueError(
"The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
)
replay_buffer = cls(capacity=capacity, device=device, state_keys=state_keys)
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
# Fill the replay buffer with the lerobot dataset transitions
for data in list_transition:
for k, v in data.items():
if isinstance(v, dict):
for key, tensor in v.items():
v[key] = tensor.to(device)
elif isinstance(v, torch.Tensor):
data[k] = v.to(device)
if action_mask is not None:
if data["action"].dim() == 1:
data["action"] = data["action"][action_mask]
else:
data["action"] = data["action"][:, action_mask]
if action_delta is not None:
data["action"] = data["action"] / action_delta
replay_buffer.add(
state=data["state"],
action=data["action"],
reward=data["reward"],
next_state=data["next_state"],
done=data["done"],
)
return replay_buffer
@staticmethod
def _lerobotdataset_to_transitions(
dataset: LeRobotDataset,
state_keys: Optional[Sequence[str]] = None,
) -> list[Transition]:
"""
Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.
Args:
dataset (LeRobotDataset):
The dataset to convert. Each item in the dataset is expected to have
at least the following keys:
{
"action": ...
"next.reward": ...
"next.done": ...
"episode_index": ...
}
plus whatever your 'state_keys' specify.
state_keys (Optional[Sequence[str]]):
The dataset keys to include in 'state' and 'next_state'. Their names
will be kept as-is in the output transitions. E.g.
["observation.state", "observation.environment_state"].
If None, you must handle or define default keys.
Returns:
transitions (List[Transition]):
A list of Transition dictionaries with the same length as `dataset`.
"""
# If not provided, you can either raise an error or define a default:
if state_keys is None:
raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
transitions: list[Transition] = []
num_frames = len(dataset)
for i in tqdm(range(num_frames)):
current_sample = dataset[i]
# ----- 1) Current state -----
current_state: dict[str, torch.Tensor] = {}
for key in state_keys:
val = current_sample[key]
current_state[key] = val.unsqueeze(0) # Add batch dimension
# ----- 2) Action -----
action = current_sample["action"].unsqueeze(0) # Add batch dimension
# ----- 3) Reward and done -----
reward = float(current_sample["next.reward"].item()) # ensure float
done = bool(current_sample["next.done"].item()) # ensure bool
# ----- 4) Next state -----
# If not done and the next sample is in the same episode, we pull the next sample's state.
# Otherwise (done=True or next sample crosses to a new episode), next_state = current_state.
next_state = current_state # default
if not done and (i < num_frames - 1):
next_sample = dataset[i + 1]
if next_sample["episode_index"] == current_sample["episode_index"]:
# Build next_state from the same keys
next_state_data: dict[str, torch.Tensor] = {}
for key in state_keys:
val = next_sample[key]
next_state_data[key] = val.unsqueeze(0) # Add batch dimension
next_state = next_state_data
# ----- Construct the Transition -----
transition = Transition(
state=current_state,
action=action,
reward=reward,
next_state=next_state,
done=done,
)
transitions.append(transition)
return transitions
def sample(self, batch_size: int) -> BatchTransition:
"""Sample a random batch of transitions and collate them into batched tensors."""
batch_size = min(batch_size, len(self.memory))
list_of_transitions = random.sample(self.memory, batch_size)
# -- Build batched states --
batch_state = {}
for key in self.state_keys:
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
if key.startswith("observation.image") and self.use_drq:
batch_state[key] = self.image_augmentation_function(batch_state[key])
# -- Build batched actions --
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
# -- Build batched rewards --
batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
# -- Build batched next states --
batch_next_state = {}
for key in self.state_keys:
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
if key.startswith("observation.image") and self.use_drq:
batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
# -- Build batched dones --
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
# Return a BatchTransition typed dict
return BatchTransition(
state=batch_state,
action=batch_actions,
reward=batch_rewards,
next_state=batch_next_state,
done=batch_dones,
)
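A minimal sketch of filling and sampling the buffer on CPU, assuming a single image observation key (all values are dummies):
```
buffer = ReplayBuffer(capacity=100, device="cpu", state_keys=["observation.image"], storage_device="cpu")
obs = {"observation.image": torch.rand(1, 3, 64, 64)}
next_obs = {"observation.image": torch.rand(1, 3, 64, 64)}
buffer.add(state=obs, action=torch.zeros(1, 4), reward=1.0, next_state=next_obs, done=False)
batch = buffer.sample(batch_size=1)
# batch["state"]["observation.image"] keeps shape (1, 3, 64, 64); with use_drq=True it is randomly shifted.
```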
def to_lerobot_dataset(
self,
repo_id: str,
fps=1, # If you have real timestamps, adjust this
root=None,
task_name="from_replay_buffer",
) -> LeRobotDataset:
"""
Converts all transitions in this ReplayBuffer into a single LeRobotDataset object,
splitting episodes by transitions where 'done=True'.
Returns:
LeRobotDataset: The resulting offline dataset.
"""
if len(self.memory) == 0:
raise ValueError("The replay buffer is empty. Cannot convert to a dataset.")
# Infer the shapes and dtypes of your features
# We'll create a features dict that is suitable for LeRobotDataset
# --------------------------------------------------------------------------------------------
# First, grab one transition to inspect shapes
first_transition = self.memory[0]
# We'll store default metadata for every episode: indexes, timestamps, etc.
features = {
"index": {"dtype": "int64", "shape": [1]}, # global index across episodes
"episode_index": {"dtype": "int64", "shape": [1]}, # which episode
"frame_index": {"dtype": "int64", "shape": [1]}, # index inside an episode
"timestamp": {"dtype": "float32", "shape": [1]}, # for now we store dummy
"task_index": {"dtype": "int64", "shape": [1]},
}
# Add "action"
act_info = guess_feature_info(
first_transition["action"].squeeze(dim=0), "action"
) # Remove batch dimension
features["action"] = act_info
# Add "reward" (scalars)
features["next.reward"] = {"dtype": "float32", "shape": (1,)}
# Add "done" (boolean scalars)
features["next.done"] = {"dtype": "bool", "shape": (1,)}
# Add state keys
for key in self.state_keys:
sample_val = first_transition["state"][key].squeeze(dim=0) # Remove batch dimension
if not isinstance(sample_val, torch.Tensor):
raise ValueError(
f"State key '{key}' is not a torch.Tensor. Please ensure your states are stored as torch.Tensors."
)
f_info = guess_feature_info(sample_val, key)
features[key] = f_info
# --------------------------------------------------------------------------------------------
# Create an empty LeRobotDataset
# We'll store all frames as separate images only if we detect shape = (3, H, W) or (1, H, W).
# By default we won't do videos, but feel free to adapt if you have them.
# --------------------------------------------------------------------------------------------
lerobot_dataset = LeRobotDataset.create(
repo_id=repo_id,
fps=fps, # If you have real timestamps, adjust this
root=root, # Or some local path where you'd like the dataset files to go
robot=None,
robot_type=None,
features=features,
use_videos=True, # We won't do actual video encoding for a replay buffer
)
# Start writing images if needed. If you have no image features, this is harmless.
# Set num_processes or num_threads if you want concurrency.
lerobot_dataset.start_image_writer(num_processes=0, num_threads=2)
# --------------------------------------------------------------------------------------------
# Convert transitions into episodes and frames
# We detect episode boundaries by `done == True`.
# --------------------------------------------------------------------------------------------
episode_index = 0
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index)
frame_idx_in_episode = 0
for global_frame_idx, transition in enumerate(self.memory):
frame_dict = {}
# Fill the data for state keys
for key in self.state_keys:
# Expand dimension to match what the dataset expects (the dataset wants the raw shape)
# We assume your buffer has shape [C, H, W] (if image) or [D] if vector
# This is typically already correct, but if needed you can reshape below.
frame_dict[key] = transition["state"][key].cpu().squeeze(dim=0) # Remove batch dimension
# Fill action, reward, done
# Make sure they are shape (X,) or (X,Y,...) as needed.
frame_dict["action"] = transition["action"].cpu().squeeze(dim=0) # Remove batch dimension
frame_dict["next.reward"] = (
torch.tensor([transition["reward"]], dtype=torch.float32).cpu().squeeze(dim=0)
)
frame_dict["next.done"] = (
torch.tensor([transition["done"]], dtype=torch.bool).cpu().squeeze(dim=0)
)
# Add to the dataset's buffer
lerobot_dataset.add_frame(frame_dict)
# Move to next frame
frame_idx_in_episode += 1
# If we reached an episode boundary, call save_episode, reset counters
if transition["done"]:
# Use some placeholder name for the task
lerobot_dataset.save_episode(task="from_replay_buffer")
episode_index += 1
frame_idx_in_episode = 0
# Start a new buffer for the next episode
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index)
# We are done adding frames
# If the last transition wasn't done=True, we still have an open buffer with frames.
# We'll consider that an incomplete episode and still save it:
if lerobot_dataset.episode_buffer["size"] > 0:
lerobot_dataset.save_episode(task=task_name)
lerobot_dataset.stop_image_writer()
lerobot_dataset.consolidate(run_compute_stats=False, keep_image_files=False)
return lerobot_dataset
# Utility function to guess shapes/dtypes from a tensor
def guess_feature_info(t: torch.Tensor, name: str):
"""
Return a dictionary with the 'dtype' and 'shape' for a given tensor or array.
If it looks like a 3D (C,H,W) shape, we might consider it an 'image'.
Otherwise default to 'float32' for numeric. You can customize as needed.
"""
shape = tuple(t.shape)
# Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image'
if len(shape) == 3 and shape[0] in [1, 3]:
return {
"dtype": "image",
"shape": shape,
}
else:
# Otherwise treat as numeric
return {
"dtype": "float32",
"shape": shape,
}
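For instance, under the heuristic above:
```
guess_feature_info(torch.zeros(3, 128, 128), "observation.image")  # {"dtype": "image", "shape": (3, 128, 128)}
guess_feature_info(torch.zeros(6), "observation.state")            # {"dtype": "float32", "shape": (6,)}
```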
def concatenate_batch_transitions(
left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
) -> BatchTransition:
"""NOTE: Be careful it change the left_batch_transitions in place"""
left_batch_transitions["state"] = {
key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
for key in left_batch_transitions["state"]
}
left_batch_transitions["action"] = torch.cat(
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0
)
left_batch_transitions["reward"] = torch.cat(
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
)
left_batch_transitions["next_state"] = {
key: torch.cat(
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
)
for key in left_batch_transitions["next_state"]
}
left_batch_transitions["done"] = torch.cat(
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0
)
return left_batch_transitions
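A hedged usage sketch, e.g. mixing offline and online batches (the two buffer names are assumptions); note that the left-hand batch is modified in place:
```
offline_batch = offline_replay_buffer.sample(batch_size)
online_batch = online_replay_buffer.sample(batch_size)
mixed_batch = concatenate_batch_transitions(offline_batch, online_batch)
# `mixed_batch` is `offline_batch` itself, now holding 2 * batch_size transitions.
```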
# if __name__ == "__main__":
# dataset_name = "aractingi/push_green_cube_hf_cropped_resized"
# dataset = LeRobotDataset(repo_id=dataset_name)
# replay_buffer = ReplayBuffer.from_lerobot_dataset(
# lerobot_dataset=dataset, state_keys=["observation.image", "observation.state"]
# )
# replay_buffer_converted = replay_buffer.to_lerobot_dataset(repo_id="AdilZtn/pusht_image_converted")
# for i in range(len(replay_buffer_converted)):
# replay_convert = replay_buffer_converted[i]
# dataset_convert = dataset[i]
# for key in replay_convert.keys():
# if key in {"index", "episode_index", "frame_index", "timestamp", "task_index"}:
# continue
# if key in dataset_convert.keys():
# assert torch.equal(replay_convert[key], dataset_convert[key])
# print(f"Key {key} is equal : {replay_convert[key].size()}, {dataset_convert[key].size()}")
# re_reconverted_dataset = ReplayBuffer.from_lerobot_dataset(
# replay_buffer_converted, state_keys=["observation.image", "observation.state"], device="cpu"
# )
# for _ in range(20):
# batch = re_reconverted_dataset.sample(32)
# for key in batch.keys():
# if key in {"state", "next_state"}:
# for key_state in batch[key].keys():
# print(key_state, batch[key][key_state].size())
# continue
# print(key, batch[key].size())

View File

@@ -1,287 +0,0 @@
import argparse # noqa: I001
import json
from copy import deepcopy
from typing import Dict, Tuple
from pathlib import Path
import cv2
# import torch.nn.functional as F # noqa: N812
import torchvision.transforms.functional as F # type: ignore # noqa: N812
from tqdm import tqdm # type: ignore
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def select_rect_roi(img):
"""
Allows the user to draw a rectangular ROI on the image.
The user must click and drag to draw the rectangle.
- While dragging, the rectangle is dynamically drawn.
- On mouse button release, the rectangle is fixed.
- Press 'c' to confirm the selection.
- Press 'r' to reset the selection.
- Press ESC to cancel.
Returns:
A tuple (top, left, height, width) representing the rectangular ROI,
or None if no valid ROI is selected.
"""
# Create a working copy of the image
clone = img.copy()
working_img = clone.copy()
roi = None # Will store the final ROI as (top, left, height, width)
drawing = False
ix, iy = -1, -1 # Initial click coordinates
def mouse_callback(event, x, y, flags, param):
nonlocal ix, iy, drawing, roi, working_img
if event == cv2.EVENT_LBUTTONDOWN:
# Start drawing: record starting coordinates
drawing = True
ix, iy = x, y
elif event == cv2.EVENT_MOUSEMOVE:
if drawing:
# Compute the top-left and bottom-right corners regardless of drag direction
top = min(iy, y)
left = min(ix, x)
bottom = max(iy, y)
right = max(ix, x)
# Show a temporary image with the current rectangle drawn
temp = working_img.copy()
cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2)
cv2.imshow("Select ROI", temp)
elif event == cv2.EVENT_LBUTTONUP:
# Finish drawing
drawing = False
top = min(iy, y)
left = min(ix, x)
bottom = max(iy, y)
right = max(ix, x)
height = bottom - top
width = right - left
roi = (top, left, height, width) # (top, left, height, width)
# Draw the final rectangle on the working image and display it
working_img = clone.copy()
cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2)
cv2.imshow("Select ROI", working_img)
# Create the window and set the callback
cv2.namedWindow("Select ROI")
cv2.setMouseCallback("Select ROI", mouse_callback)
cv2.imshow("Select ROI", working_img)
print("Instructions for ROI selection:")
print(" - Click and drag to draw a rectangular ROI.")
print(" - Press 'c' to confirm the selection.")
print(" - Press 'r' to reset and draw again.")
print(" - Press ESC to cancel the selection.")
# Wait until the user confirms with 'c', resets with 'r', or cancels with ESC
while True:
key = cv2.waitKey(1) & 0xFF
# Confirm ROI if one has been drawn
if key == ord("c") and roi is not None:
break
# Reset: clear the ROI and restore the original image
elif key == ord("r"):
working_img = clone.copy()
roi = None
cv2.imshow("Select ROI", working_img)
# Cancel selection for this image
elif key == 27: # ESC key
roi = None
break
cv2.destroyWindow("Select ROI")
return roi
def select_square_roi_for_images(images: dict) -> dict:
"""
For each image in the provided dictionary, open a window to allow the user
to select a rectangular ROI. Returns a dictionary mapping each key to a tuple
(top, left, height, width) representing the ROI.
Parameters:
images (dict): Dictionary where keys are identifiers and values are OpenCV images.
Returns:
dict: Mapping of image keys to the selected rectangular ROI.
"""
selected_rois = {}
for key, img in images.items():
if img is None:
print(f"Image for key '{key}' is None, skipping.")
continue
print(f"\nSelect rectangular ROI for image with key: '{key}'")
roi = select_rect_roi(img)
if roi is None:
print(f"No valid ROI selected for '{key}'.")
else:
selected_rois[key] = roi
print(f"ROI for '{key}': {roi}")
return selected_rois
def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
"""
Extract the images from the first row of the dataset so they can be used for ROI selection.
"""
row = dataset[0]
image_dict = {}
for k in row:
if "image" in k:
image_dict[k] = deepcopy(row[k])
return image_dict
def convert_lerobot_dataset_to_cropper_lerobot_dataset(
original_dataset: LeRobotDataset,
crop_params_dict: Dict[str, Tuple[int, int, int, int]],
new_repo_id: str,
new_dataset_root: str,
resize_size: Tuple[int, int] = (128, 128),
) -> LeRobotDataset:
"""
Converts an existing LeRobotDataset by iterating over its episodes and frames,
applying cropping and resizing to image observations, and saving a new dataset
with the transformed data.
Args:
original_dataset (LeRobotDataset): The source dataset.
crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
A dictionary mapping observation keys to crop parameters (top, left, height, width).
new_repo_id (str): Repository id for the new dataset.
new_dataset_root (str): The root directory where the new dataset will be written.
resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
Defaults to (128, 128).
Returns:
LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped
and resized.
"""
# 1. Create a new (empty) LeRobotDataset for writing.
new_dataset = LeRobotDataset.create(
repo_id=new_repo_id,
fps=original_dataset.fps,
root=new_dataset_root,
robot_type=original_dataset.meta.robot_type,
features=original_dataset.meta.info["features"],
use_videos=len(original_dataset.meta.video_keys) > 0,
)
# Update the metadata for every image key that will be cropped:
# (Here we simply set the shape to be the final resize_size.)
for key in crop_params_dict:
if key in new_dataset.meta.info["features"]:
new_dataset.meta.info["features"][key]["shape"] = list(resize_size)
# 2. Process each episode in the original dataset.
episodes_info = original_dataset.meta.episodes
# (Sort episodes by episode_index for consistency.)
episodes_info = sorted(episodes_info, key=lambda x: x["episode_index"])
# Use the first task from the episode metadata (or "unknown" if not provided)
task = episodes_info[0]["tasks"][0] if episodes_info[0].get("tasks") else "unknown"
last_episode_index = 0
for sample in tqdm(original_dataset):
episode_index = sample.pop("episode_index")
if episode_index != last_episode_index:
new_dataset.save_episode(task, encode_videos=True)
last_episode_index = episode_index
sample.pop("frame_index")
# Make a shallow copy of the sample (the values—e.g. torch tensors—are assumed immutable)
new_sample = sample.copy()
# Loop over each observation key that should be cropped/resized.
for key, params in crop_params_dict.items():
if key in new_sample:
top, left, height, width = params
# Apply crop then resize.
cropped = F.crop(new_sample[key], top, left, height, width)
resized = F.resize(cropped, resize_size)
new_sample[key] = resized
# Add the transformed frame to the new dataset.
new_dataset.add_frame(new_sample)
# save last episode
new_dataset.save_episode(task, encode_videos=True)
# Optionally, consolidate the new dataset to compute statistics and update video info.
new_dataset.consolidate(run_compute_stats=True, keep_image_files=True)
new_dataset.push_to_hub(tags=None)
return new_dataset
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
parser.add_argument(
"--repo-id",
type=str,
default="lerobot",
help="The repository id of the LeRobot dataset to process.",
)
parser.add_argument(
"--root",
type=str,
default=None,
help="The root directory of the LeRobot dataset.",
)
parser.add_argument(
"--crop-params-path",
type=str,
default=None,
help="The path to the JSON file containing the ROIs.",
)
args = parser.parse_args()
local_files_only = args.root is not None
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)
images = get_image_from_lerobot_dataset(dataset)
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
images = {k: (v * 255).astype("uint8") for k, v in images.items()}
if args.crop_params_path is None:
rois = select_square_roi_for_images(images)
else:
with open(args.crop_params_path, "r") as f:
rois = json.load(f)
# rois = {
# "observation.images.front": [102, 43, 358, 523],
# "observation.images.side": [92, 123, 379, 349],
# }
# Print the selected rectangular ROIs
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
for key, roi in rois.items():
print(f"{key}: {roi}")
new_repo_id = args.repo_id + "_cropped_resized"
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
croped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
original_dataset=dataset,
crop_params_dict=rois,
new_repo_id=new_repo_id,
new_dataset_root=new_dataset_root,
resize_size=(128, 128),
)
meta_dir = new_dataset_root / "meta"
meta_dir.mkdir(exist_ok=True)
with open(meta_dir / "crop_params.json", "w") as f:
json.dump(rois, f, indent=4)
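As a quick illustration of the (top, left, height, width) convention used by this script, here is a small sketch with the same torchvision calls used above (the ROI values are the example ones from the commented-out rois dict):

import torch
import torchvision.transforms.functional as F  # noqa: N812

img = torch.rand(3, 480, 640)  # (C, H, W)
top, left, height, width = 102, 43, 358, 523
cropped = F.crop(img, top, left, height, width)  # -> (3, 358, 523)
resized = F.resize(cropped, (128, 128))  # -> (3, 128, 128)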

View File

@@ -1,65 +0,0 @@
import argparse
import time
import cv2
import numpy as np
from lerobot.common.robot_devices.control_utils import is_headless
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config
def find_joint_bounds(
robot,
control_time_s=20,
display_cameras=False,
):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
robot.connect()
control_time_s = float("inf")
timestamp = 0
start_episode_t = time.perf_counter()
pos_list = []
while timestamp < control_time_s:
observation, action = robot.teleop_step(record_data=True)
pos_list.append(robot.follower_arms["main"].read("Present_Position"))
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
timestamp = time.perf_counter() - start_episode_t
if timestamp > 60:
max_pos = np.max(np.stack(pos_list), 0)
min_pos = np.min(np.stack(pos_list), 0)
print(f"Max angle position per joint {max_pos}")
print(f"Min angle position per joint {min_pos}")
break
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--robot-path",
type=str,
default="lerobot/configs/robot/koch.yaml",
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
)
parser.add_argument(
"--robot-overrides",
type=str,
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
args = parser.parse_args()
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
robot = make_robot(robot_cfg)
find_joint_bounds(robot, control_time_s=args.control_time_s)
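A hedged sketch of how the printed per-joint min/max could be turned into the joint_position_relative_bounds tensors that the robot config is expected to carry later in this PR (the exact config format here is an assumption):

import numpy as np
import torch

pos_list = [np.zeros(6), np.ones(6)]  # placeholder for the positions recorded during teleoperation
max_pos = np.max(np.stack(pos_list), 0)
min_pos = np.min(np.stack(pos_list), 0)
joint_position_relative_bounds = {
    "min": torch.from_numpy(min_pos),
    "max": torch.from_numpy(max_pos),
}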

View File

@@ -1,861 +0,0 @@
import argparse
import logging
import time
from threading import Lock
from typing import Annotated, Any, Callable, Dict, Optional, Tuple
import gymnasium as gym
import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config, log_say
from lerobot.scripts.server.maniskill_manipulator import make_maniskill
logging.basicConfig(level=logging.INFO)
class HILSerlRobotEnv(gym.Env):
"""
Gym-compatible environment for evaluating robotic control policies with integrated human intervention.
This environment wraps a robot interface to provide a consistent API for policy evaluation. It supports both relative (delta)
and absolute joint position commands and automatically configures its observation and action spaces based on the robot's
sensors and configuration.
The environment can switch between executing actions from a policy or using teleoperated actions (human intervention) during
each step. When teleoperation is used, the override action is captured and returned in the `info` dict along with a flag
`is_intervention`.
"""
def __init__(
self,
robot,
use_delta_action_space: bool = True,
delta: float | None = None,
display_cameras: bool = False,
):
"""
Initialize the HILSerlRobotEnv environment.
The environment is set up with a robot interface, which is used to capture observations and send joint commands. The setup
supports both relative (delta) adjustments and absolute joint positions for controlling the robot.
Args:
robot: The robot interface object used to connect and interact with the physical robot.
use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute
joint positions are used.
delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between
0 and 1 when using a delta action space.
display_cameras (bool): If True, the robot's camera feeds will be displayed during execution.
"""
super().__init__()
self.robot = robot
self.display_cameras = display_cameras
# Connect to the robot if not already connected.
if not self.robot.is_connected:
self.robot.connect()
self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")
# Episode tracking.
self.current_step = 0
self.episode_data = None
self.delta = delta
self.use_delta_action_space = use_delta_action_space
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
# Retrieve the size of the joint position interval bound.
self.relative_bounds_size = (
self.robot.config.joint_position_relative_bounds["max"]
- self.robot.config.joint_position_relative_bounds["min"]
)
self.delta_relative_bounds_size = self.relative_bounds_size * self.delta
self.robot.config.max_relative_target = self.delta_relative_bounds_size.float()
# Dynamically configure the observation and action spaces.
self._setup_spaces()
def _setup_spaces(self):
"""
Dynamically configure the observation and action spaces based on the robot's capabilities.
Observation Space:
- For keys with "image": A Box space with pixel values ranging from 0 to 255.
- For non-image keys: A nested Dict space is created under 'observation.state' with a suitable range.
Action Space:
- The action space is defined as a Tuple where:
• The first element is a Box space representing joint position commands. It is defined as relative (delta)
or absolute, based on the configuration.
• The second element is a Discrete space (with 2 values) serving as a flag for intervention (teleoperation).
"""
example_obs = self.robot.capture_observation()
# Define observation spaces for images and other states.
image_keys = [key for key in example_obs if "image" in key]
state_keys = [key for key in example_obs if "image" not in key]
observation_spaces = {
key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8)
for key in image_keys
}
observation_spaces["observation.state"] = gym.spaces.Dict(
{
key: gym.spaces.Box(low=0, high=10, shape=example_obs[key].shape, dtype=np.float32)
for key in state_keys
}
)
self.observation_space = gym.spaces.Dict(observation_spaces)
# Define the action space for joint positions along with setting an intervention flag.
action_dim = len(self.robot.follower_arms["main"].read("Present_Position"))
if self.use_delta_action_space:
action_space_robot = gym.spaces.Box(
low=-self.relative_bounds_size.cpu().numpy(),
high=self.relative_bounds_size.cpu().numpy(),
shape=(action_dim,),
dtype=np.float32,
)
else:
action_space_robot = gym.spaces.Box(
low=self.robot.config.joint_position_relative_bounds["min"].cpu().numpy(),
high=self.robot.config.joint_position_relative_bounds["max"].cpu().numpy(),
shape=(action_dim,),
dtype=np.float32,
)
self.action_space = gym.spaces.Tuple(
(
action_space_robot,
gym.spaces.Discrete(2),
),
)
def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
"""
Reset the environment to its initial state.
This method resets the step counter and clears any episodic data.
Args:
seed (Optional[int]): A seed for random number generation to ensure reproducibility.
options (Optional[dict]): Additional options to influence the reset behavior.
Returns:
A tuple containing:
- observation (dict): The initial sensor observation.
- info (dict): A dictionary with supplementary information, including the key "initial_position".
"""
super().reset(seed=seed, options=options)
# Capture the initial observation.
observation = self.robot.capture_observation()
# Reset episode tracking variables.
self.current_step = 0
self.episode_data = None
return observation, {"initial_position": self.initial_follower_position}
def step(
self, action: Tuple[np.ndarray, bool]
) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]:
"""
Execute a single step within the environment using the specified action.
The provided action is a tuple comprised of:
• A policy action (joint position commands) that may be either in absolute values or as a delta.
• A boolean flag indicating whether teleoperation (human intervention) should be used for this step.
Behavior:
- When the intervention flag is False, the environment processes and sends the policy action to the robot.
- When True, a teleoperation step is executed. If using a delta action space, an absolute teleop action is converted
to relative change based on the current joint positions.
Args:
action (tuple): A tuple with two elements:
- policy_action (np.ndarray or torch.Tensor): The commanded joint positions.
- intervention_bool (bool): True if the human operator intervenes by providing a teleoperation input.
Returns:
tuple: A tuple containing:
- observation (dict): The new sensor observation after taking the step.
- reward (float): The step reward (default is 0.0 within this wrapper).
- terminated (bool): True if the episode has reached a terminal state.
- truncated (bool): True if the episode was truncated (e.g., time constraints).
- info (dict): Additional debugging information including:
"action_intervention": The teleop action if intervention was used.
"is_intervention": Flag indicating whether teleoperation was employed.
"""
policy_action, intervention_bool = action
teleop_action = None
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
if isinstance(policy_action, torch.Tensor):
policy_action = policy_action.cpu().numpy()
policy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high)
if not intervention_bool:
if self.use_delta_action_space:
target_joint_positions = self.current_joint_positions + self.delta * policy_action
else:
target_joint_positions = policy_action
self.robot.send_action(torch.from_numpy(target_joint_positions))
observation = self.robot.capture_observation()
else:
observation, teleop_action = self.robot.teleop_step(record_data=True)
teleop_action = teleop_action["action"] # Convert tensor to appropriate format
# When applying the delta action space, convert teleop absolute values to relative differences.
if self.use_delta_action_space:
teleop_action = (teleop_action - self.current_joint_positions) / self.delta
if torch.any(teleop_action < -self.relative_bounds_size) or torch.any(
teleop_action > self.relative_bounds_size
):
logging.debug(
f"Relative teleop delta exceeded bounds {self.relative_bounds_size}, teleop_action {teleop_action}\n"
f"lower bounds condition {teleop_action < -self.relative_bounds_size}\n"
f"upper bounds condition {teleop_action > self.relative_bounds_size}"
)
teleop_action = torch.clamp(
teleop_action, -self.relative_bounds_size, self.relative_bounds_size
)
# NOTE: To mimic the shape of a neural network output, we add a batch dimension to the teleop action.
if teleop_action.dim() == 1:
teleop_action = teleop_action.unsqueeze(0)
# self.render()
self.current_step += 1
reward = 0.0
terminated = False
truncated = False
return (
observation,
reward,
terminated,
truncated,
{"action_intervention": teleop_action, "is_intervention": teleop_action is not None},
)
def render(self):
"""
Render the current state of the environment by displaying the robot's camera feeds.
"""
import cv2
observation = self.robot.capture_observation()
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
def close(self):
"""
Close the environment and clean up resources by disconnecting the robot.
If the robot is currently connected, this method properly terminates the connection to ensure that all
associated resources are released.
"""
if self.robot.is_connected:
self.robot.disconnect()
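A minimal usage sketch of the raw environment API documented above (assuming robot is an already-connected robot instance):

env = HILSerlRobotEnv(robot=robot, use_delta_action_space=True, delta=0.1)
obs, info = env.reset()
policy_action = env.action_space.spaces[0].sample()  # joint command in the (delta) action space
# The second tuple element is the intervention flag; False means the policy action is executed.
obs, reward, terminated, truncated, info = env.step((policy_action, False))
print(info["is_intervention"], info["action_intervention"])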
class ActionRepeatWrapper(gym.Wrapper):
def __init__(self, env, nb_repeat: int = 1):
super().__init__(env)
self.nb_repeat = nb_repeat
def step(self, action):
for _ in range(self.nb_repeat):
obs, reward, done, truncated, info = self.env.step(action)
if done or truncated:
break
return obs, reward, done, truncated, info
class RewardWrapper(gym.Wrapper):
def __init__(self, env, reward_classifier, device: torch.device = "cuda"):
"""
Wrapper that adds reward prediction to the environment using a trained classifier.
Args:
env: The environment to wrap
reward_classifier: The reward classifier model
device: The device to run the model on
"""
super().__init__(env)
self.env = env
# NOTE: We got 15% speedup by compiling the model
self.reward_classifier = torch.compile(reward_classifier)
if isinstance(device, str):
device = torch.device(device)
self.device = device
def step(self, action):
observation, _, terminated, truncated, info = self.env.step(action)
images = [
observation[key].to(self.device, non_blocking=self.device.type == "cuda")
for key in observation
if "image" in key
]
start_time = time.perf_counter()
with torch.inference_mode():
reward = (
self.reward_classifier.predict_reward(images, threshold=0.8)
if self.reward_classifier is not None
else 0.0
)
info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
# logging.info(f"Reward: {reward}")
if reward == 1.0:
terminated = True
return observation, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
return self.env.reset(seed=seed, options=options)
class JointMaskingActionSpace(gym.Wrapper):
def __init__(self, env, mask):
"""
Wrapper to mask out dimensions of the action space.
Args:
env: The environment to wrap
mask: Binary mask array where 0 indicates dimensions to remove
"""
super().__init__(env)
# Validate mask matches action space
# Keep only dimensions where mask is 1
self.active_dims = np.where(mask)[0]
if isinstance(env.action_space, gym.spaces.Box):
if len(mask) != env.action_space.shape[0]:
raise ValueError("Mask length must match action space dimensions")
low = env.action_space.low[self.active_dims]
high = env.action_space.high[self.active_dims]
self.action_space = gym.spaces.Box(low=low, high=high, dtype=env.action_space.dtype)
if isinstance(env.action_space, gym.spaces.Tuple):
if len(mask) != env.action_space[0].shape[0]:
raise ValueError("Mask length must match action space 0 dimensions")
low = env.action_space[0].low[self.active_dims]
high = env.action_space[0].high[self.active_dims]
action_space_masked = gym.spaces.Box(low=low, high=high, dtype=env.action_space[0].dtype)
self.action_space = gym.spaces.Tuple((action_space_masked, env.action_space[1]))
# Create new action space with masked dimensions
def action(self, action):
"""
Convert masked action back to full action space.
Args:
action: Action in masked space. For Tuple spaces, the first element is masked.
Returns:
Action in original space with masked dims set to 0.
"""
# Determine whether we are handling a Tuple space or a Box.
if isinstance(self.env.action_space, gym.spaces.Tuple):
# Extract the masked component from the tuple.
masked_action = action[0] if isinstance(action, tuple) else action
# Create a full action for the Box element.
full_box_action = np.zeros(self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype)
full_box_action[self.active_dims] = masked_action
# Return a tuple with the reconstructed Box action and the unchanged remainder.
return (full_box_action, action[1])
else:
# For Box action spaces.
masked_action = action if not isinstance(action, tuple) else action[0]
full_action = np.zeros(self.env.action_space.shape, dtype=self.env.action_space.dtype)
full_action[self.active_dims] = masked_action
return full_action
def step(self, action):
action = self.action(action)
obs, reward, terminated, truncated, info = self.env.step(action)
if "action_intervention" in info and info["action_intervention"] is not None:
if info["action_intervention"].dim() == 1:
info["action_intervention"] = info["action_intervention"][self.active_dims]
else:
info["action_intervention"] = info["action_intervention"][:, self.active_dims]
return obs, reward, terminated, truncated, info
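For clarity, a small sketch of how the masking maps a reduced action back to the full command (assuming a 6-DoF arm and mask [1, 1, 1, 0, 0, 0]):

import numpy as np

mask = np.array([1, 1, 1, 0, 0, 0])
active_dims = np.where(mask)[0]  # array([0, 1, 2]): only the first three joints stay controllable
masked_action = np.array([0.1, -0.2, 0.3], dtype=np.float32)
full_action = np.zeros(len(mask), dtype=np.float32)
full_action[active_dims] = masked_action  # -> [0.1, -0.2, 0.3, 0.0, 0.0, 0.0]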
class TimeLimitWrapper(gym.Wrapper):
def __init__(self, env, control_time_s, fps):
super().__init__(env)
self.env = env
self.control_time_s = control_time_s
self.fps = fps
self.last_timestamp = 0.0
self.episode_time_in_s = 0.0
self.max_episode_steps = int(self.control_time_s * self.fps)
self.current_step = 0
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
time_since_last_step = time.perf_counter() - self.last_timestamp
self.episode_time_in_s += time_since_last_step
self.last_timestamp = time.perf_counter()
self.current_step += 1
# check if last timestep took more time than the expected fps
if 1.0 / time_since_last_step < self.fps:
logging.debug(f"Current timestep exceeded expected fps {self.fps}")
if self.episode_time_in_s > self.control_time_s:
# if self.current_step >= self.max_episode_steps:
# Terminated = True
terminated = True
return obs, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
self.episode_time_in_s = 0.0
self.last_timestamp = time.perf_counter()
self.current_step = 0
return self.env.reset(seed=seed, options=options)
class ImageCropResizeWrapper(gym.Wrapper):
def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None):
super().__init__(env)
self.env = env
self.crop_params_dict = crop_params_dict
print(f"obs_keys , {self.env.observation_space}")
print(f"crop params dict {crop_params_dict.keys()}")
for key_crop in crop_params_dict:
if key_crop not in self.env.observation_space.keys(): # noqa: SIM118
raise ValueError(f"Key {key_crop} not in observation space")
for key in crop_params_dict:
top, left, height, width = crop_params_dict[key]
new_shape = (top + height, left + width)
self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)
self.resize_size = resize_size
if self.resize_size is None:
self.resize_size = (128, 128)
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
for k in self.crop_params_dict:
device = obs[k].device
# Check for NaNs before processing
if torch.isnan(obs[k]).any():
logging.error(f"NaN values detected in observation {k} before crop and resize")
if device == torch.device("mps:0"):
obs[k] = obs[k].cpu()
obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
obs[k] = F.resize(obs[k], self.resize_size)
# Check for NaNs after processing
if torch.isnan(obs[k]).any():
logging.error(f"NaN values detected in observation {k} after crop and resize")
obs[k] = obs[k].to(device)
return obs, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
obs, info = self.env.reset(seed=seed, options=options)
for k in self.crop_params_dict:
device = obs[k].device
if device == torch.device("mps:0"):
obs[k] = obs[k].cpu()
obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
obs[k] = F.resize(obs[k], self.resize_size)
obs[k] = obs[k].to(device)
return obs, info
class ConvertToLeRobotObservation(gym.ObservationWrapper):
def __init__(self, env, device):
super().__init__(env)
if isinstance(device, str):
device = torch.device(device)
self.device = device
def observation(self, observation):
observation = preprocess_observation(observation)
observation = {
key: observation[key].to(self.device, non_blocking=self.device.type == "cuda")
for key in observation
}
observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()}
return observation
class KeyboardInterfaceWrapper(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
self.listener = None
self.events = {
"exit_early": False,
"pause_policy": False,
"reset_env": False,
"human_intervention_step": False,
"episode_success": False,
}
self.event_lock = Lock() # Thread-safe access to events
self._init_keyboard_listener()
def _init_keyboard_listener(self):
"""Initialize keyboard listener if not in headless mode"""
if is_headless():
logging.warning(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
return
try:
from pynput import keyboard
def on_press(key):
with self.event_lock:
try:
if key == keyboard.Key.right or key == keyboard.Key.esc:
print("Right arrow key pressed. Exiting loop...")
self.events["exit_early"] = True
return
if hasattr(key, "char") and key.char == "s":
print("Key 's' pressed. Episode success triggered.")
self.events["episode_success"] = True
return
if key == keyboard.Key.space and not self.events["exit_early"]:
if not self.events["pause_policy"]:
print(
"Space key pressed. Human intervention required.\n"
"Place the leader in similar pose to the follower and press space again."
)
self.events["pause_policy"] = True
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
return
if self.events["pause_policy"] and not self.events["human_intervention_step"]:
self.events["human_intervention_step"] = True
print("Space key pressed. Human intervention starting.")
log_say("Starting human intervention.", play_sounds=True)
return
if self.events["pause_policy"] and self.events["human_intervention_step"]:
self.events["pause_policy"] = False
self.events["human_intervention_step"] = False
print("Space key pressed for a third time.")
log_say("Continuing with policy actions.", play_sounds=True)
return
except Exception as e:
print(f"Error handling key press: {e}")
self.listener = keyboard.Listener(on_press=on_press)
self.listener.start()
except ImportError:
logging.warning("Could not import pynput. Keyboard interface will not be available.")
self.listener = None
def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
is_intervention = False
terminated_by_keyboard = False
# Extract policy_action if needed
if isinstance(self.env.action_space, gym.spaces.Tuple):
policy_action = action[0]
# Check the event flags without holding the lock for too long.
with self.event_lock:
if self.events["exit_early"]:
terminated_by_keyboard = True
pause_policy = self.events["pause_policy"]
if pause_policy:
# Now, wait for human_intervention_step without holding the lock
while True:
with self.event_lock:
if self.events["human_intervention_step"]:
is_intervention = True
break
time.sleep(0.1) # Check more frequently if desired
# Execute the step in the underlying environment
obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))
# Override reward and termination if episode success event triggered
with self.event_lock:
if self.events["episode_success"]:
reward = 1
terminated_by_keyboard = True
return obs, reward, terminated or terminated_by_keyboard, truncated, info
def reset(self, **kwargs) -> Tuple[Any, Dict]:
"""
Reset the environment and clear any pending events
"""
with self.event_lock:
self.events = {k: False for k in self.events}
return self.env.reset(**kwargs)
def close(self):
"""
Properly clean up the keyboard listener when the environment is closed
"""
if self.listener is not None:
self.listener.stop()
super().close()
class ResetWrapper(gym.Wrapper):
def __init__(
self, env: HILSerlRobotEnv, reset_fn: Optional[Callable[[], None]] = None, reset_time_s: float = 5
):
super().__init__(env)
self.reset_fn = reset_fn
self.reset_time_s = reset_time_s
self.robot = self.unwrapped.robot
self.init_pos = self.unwrapped.initial_follower_position
def reset(self, *, seed=None, options=None):
if self.reset_fn is not None:
self.reset_fn(self.env)
else:
log_say(f"Manually reset the environment for {self.reset_time_s} seconds.", play_sounds=True)
start_time = time.perf_counter()
while time.perf_counter() - start_time < self.reset_time_s:
self.robot.teleop_step()
log_say("Manual reseting of the environment done.", play_sounds=True)
return super().reset(seed=seed, options=options)
class BatchCompitableWrapper(gym.ObservationWrapper):
def __init__(self, env):
super().__init__(env)
def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
for key in observation:
if "image" in key and observation[key].dim() == 3:
observation[key] = observation[key].unsqueeze(0)
if "state" in key and observation[key].dim() == 1:
observation[key] = observation[key].unsqueeze(0)
return observation
# TODO: REMOVE TH
def make_robot_env(
robot,
reward_classifier,
cfg,
n_envs: int = 1,
) -> gym.vector.VectorEnv:
"""
Factory function to create a vectorized robot environment.
Args:
robot: Robot instance to control
reward_classifier: Classifier model for computing rewards
cfg: Configuration object containing environment parameters
n_envs: Number of environments to create in parallel. Defaults to 1.
Returns:
A vectorized gym environment with all the necessary wrappers applied.
"""
if "maniskill" in cfg.name:
logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
env = make_maniskill(
task=cfg.task,
obs_mode=cfg.obs,
control_mode=cfg.control_mode,
render_mode=cfg.render_mode,
sensor_configs={"width": cfg.render_size, "height": cfg.render_size},
device=cfg.device,
)
return env
# Create base environment
env = HILSerlRobotEnv(
robot=robot,
display_cameras=cfg.wrapper.display_cameras,
delta=cfg.wrapper.delta_action,
use_delta_action_space=cfg.wrapper.use_relative_joint_positions,
)
# Add observation and image processing
env = ConvertToLeRobotObservation(env=env, device=cfg.device)
if cfg.wrapper.crop_params_dict is not None:
env = ImageCropResizeWrapper(
env=env, crop_params_dict=cfg.wrapper.crop_params_dict, resize_size=cfg.wrapper.resize_size
)
# Add reward computation and control wrappers
env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
env = KeyboardInterfaceWrapper(env=env)
env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.wrapper.reset_time_s)
env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
env = BatchCompitableWrapper(env=env)
return env
# batched version of the env that returns an observation of shape (b, c)
def get_classifier(pretrained_path, config_path, device="mps"):
if pretrained_path is None or config_path is None:
return None
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
cfg = init_hydra_config(config_path)
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to(device)
return model
def replay_episode(env, repo_id, root=None, episode=0):
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
local_files_only = root is not None
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
actions = dataset.hf_dataset.select_columns("action")
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
action = actions[idx]["action"][:4]
print(action)
env.step((action / env.unwrapped.delta, False))
dt_s = time.perf_counter() - start_episode_t
busy_wait(1 / 10 - dt_s)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fps", type=int, default=30, help="control frequency")
parser.add_argument(
"--robot-path",
type=str,
default="lerobot/configs/robot/koch.yaml",
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
)
parser.add_argument(
"--robot-overrides",
type=str,
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
parser.add_argument(
"-p",
"--pretrained-policy-name-or-path",
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
"(useful for debugging). This argument is mutually exclusive with `--config`."
),
)
parser.add_argument(
"--config",
help=(
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
),
)
parser.add_argument(
"--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
)
parser.add_argument(
"--reward-classifier-pretrained-path",
type=str,
default=None,
help="Path to the pretrained classifier weights.",
)
parser.add_argument(
"--reward-classifier-config-file",
type=str,
default=None,
help="Path to a yaml config file that is necessary to build the reward classifier model.",
)
parser.add_argument("--env-path", type=str, default=None, help="Path to the env yaml file")
parser.add_argument("--env-overrides", type=str, default=None, help="Overrides for the env yaml file")
parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
parser.add_argument("--reset-follower-pos", type=int, default=1, help="Reset follower between episodes")
parser.add_argument("--replay-repo-id", type=str, default=None, help="Repo ID of the episode to replay")
parser.add_argument("--replay-root", type=str, default=None, help="Root of the dataset to replay")
parser.add_argument("--replay-episode", type=int, default=0, help="Episode to replay")
args = parser.parse_args()
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
robot = make_robot(robot_cfg)
reward_classifier = get_classifier(
args.reward_classifier_pretrained_path, args.reward_classifier_config_file
)
user_relative_joint_positions = True
cfg = init_hydra_config(args.env_path, args.env_overrides)
env = make_robot_env(
robot,
reward_classifier,
cfg.env, # .wrapper,
)
env.reset()
if args.replay_repo_id is not None:
replay_episode(env, args.replay_repo_id, root=args.replay_root, episode=args.replay_episode)
exit()
# Retrieve the robot's action space for joint commands.
action_space_robot = env.action_space.spaces[0]
# Initialize the smoothed action as a random sample.
smoothed_action = action_space_robot.sample()
# Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
# A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
alpha = 0.4
while True:
start_loop_s = time.perf_counter()
# Sample a new random action from the robot's action space.
new_random_action = action_space_robot.sample()
# Update the smoothed action using an exponential moving average.
smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
# Execute the step: wrap the NumPy action in a torch tensor.
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
if terminated or truncated:
env.reset()
dt_s = time.perf_counter() - start_loop_s
busy_wait(1 / args.fps - dt_s)

View File

@@ -1,58 +0,0 @@
// !/usr/bin/env python
// Copyright 2024 The HuggingFace Inc. team.
// All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package hil_serl;
// LearnerService: the Actor calls this to push transitions.
// The Learner implements this service.
service LearnerService {
// Actor -> Learner to store transitions
rpc SendTransition(Transition) returns (Empty);
rpc SendInteractionMessage(InteractionMessage) returns (Empty);
}
// ActorService: the Learner calls this to push parameters.
// The Actor implements this service.
service ActorService {
// Learner -> Actor to send new parameters
rpc StreamTransition(Empty) returns (stream ActorInformation) {};
rpc SendParameters(Parameters) returns (Empty);
}
message ActorInformation {
oneof data {
Transition transition = 1;
InteractionMessage interaction_message = 2;
}
}
// Messages
message Transition {
bytes transition_bytes = 1;
}
message Parameters {
bytes parameter_bytes = 1;
}
message InteractionMessage {
bytes interaction_message_bytes = 1;
}
message Empty {}
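For context, a hedged sketch of how an actor process might package a transition into the ActorInformation message defined above, mirroring the torch.load deserialization the learner performs on transition_bytes (pack_transition is an illustrative helper, not part of this PR):

import io

import torch

import hilserl_pb2

def pack_transition(transition: dict) -> hilserl_pb2.ActorInformation:
    # Serialize the transition (a dict of tensors) so the learner can torch.load it.
    buf = io.BytesIO()
    torch.save(transition, buf)
    return hilserl_pb2.ActorInformation(
        transition=hilserl_pb2.Transition(transition_bytes=buf.getvalue())
    )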

View File

@@ -1,48 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: hilserl.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
29,
0,
'',
'hilserl.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"\x83\x01\n\x10\x41\x63torInformation\x12*\n\ntransition\x18\x01 \x01(\x0b\x32\x14.hil_serl.TransitionH\x00\x12;\n\x13interaction_message\x18\x02 \x01(\x0b\x32\x1c.hil_serl.InteractionMessageH\x00\x42\x06\n\x04\x64\x61ta\"&\n\nTransition\x12\x18\n\x10transition_bytes\x18\x01 \x01(\x0c\"%\n\nParameters\x12\x17\n\x0fparameter_bytes\x18\x01 \x01(\x0c\"7\n\x12InteractionMessage\x12!\n\x19interaction_message_bytes\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty2\x92\x01\n\x0eLearnerService\x12\x37\n\x0eSendTransition\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty2\x8c\x01\n\x0c\x41\x63torService\x12\x43\n\x10StreamTransition\x12\x0f.hil_serl.Empty\x1a\x1a.hil_serl.ActorInformation\"\x00\x30\x01\x12\x37\n\x0eSendParameters\x12\x14.hil_serl.Parameters\x1a\x0f.hil_serl.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_ACTORINFORMATION']._serialized_start=28
_globals['_ACTORINFORMATION']._serialized_end=159
_globals['_TRANSITION']._serialized_start=161
_globals['_TRANSITION']._serialized_end=199
_globals['_PARAMETERS']._serialized_start=201
_globals['_PARAMETERS']._serialized_end=238
_globals['_INTERACTIONMESSAGE']._serialized_start=240
_globals['_INTERACTIONMESSAGE']._serialized_end=295
_globals['_EMPTY']._serialized_start=297
_globals['_EMPTY']._serialized_end=304
_globals['_LEARNERSERVICE']._serialized_start=307
_globals['_LEARNERSERVICE']._serialized_end=453
_globals['_ACTORSERVICE']._serialized_start=456
_globals['_ACTORSERVICE']._serialized_end=596
# @@protoc_insertion_point(module_scope)

View File

@@ -1,269 +0,0 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
import hilserl_pb2 as hilserl__pb2
GRPC_GENERATED_VERSION = '1.70.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in hilserl_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class LearnerServiceStub(object):
"""LearnerService: the Actor calls this to push transitions.
The Learner implements this service.
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SendTransition = channel.unary_unary(
'/hil_serl.LearnerService/SendTransition',
request_serializer=hilserl__pb2.Transition.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
self.SendInteractionMessage = channel.unary_unary(
'/hil_serl.LearnerService/SendInteractionMessage',
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
class LearnerServiceServicer(object):
"""LearnerService: the Actor calls this to push transitions.
The Learner implements this service.
"""
def SendTransition(self, request, context):
"""Actor -> Learner to store transitions
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendInteractionMessage(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_LearnerServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendTransition': grpc.unary_unary_rpc_method_handler(
servicer.SendTransition,
request_deserializer=hilserl__pb2.Transition.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
'SendInteractionMessage': grpc.unary_unary_rpc_method_handler(
servicer.SendInteractionMessage,
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'hil_serl.LearnerService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('hil_serl.LearnerService', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class LearnerService(object):
"""LearnerService: the Actor calls this to push transitions.
The Learner implements this service.
"""
@staticmethod
def SendTransition(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/hil_serl.LearnerService/SendTransition',
hilserl__pb2.Transition.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendInteractionMessage(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/hil_serl.LearnerService/SendInteractionMessage',
hilserl__pb2.InteractionMessage.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
class ActorServiceStub(object):
"""ActorService: the Learner calls this to push parameters.
The Actor implements this service.
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.StreamTransition = channel.unary_stream(
'/hil_serl.ActorService/StreamTransition',
request_serializer=hilserl__pb2.Empty.SerializeToString,
response_deserializer=hilserl__pb2.ActorInformation.FromString,
_registered_method=True)
self.SendParameters = channel.unary_unary(
'/hil_serl.ActorService/SendParameters',
request_serializer=hilserl__pb2.Parameters.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
class ActorServiceServicer(object):
"""ActorService: the Learner calls this to push parameters.
The Actor implements this service.
"""
def StreamTransition(self, request, context):
"""Learner -> Actor to send new parameters
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendParameters(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_ActorServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'StreamTransition': grpc.unary_stream_rpc_method_handler(
servicer.StreamTransition,
request_deserializer=hilserl__pb2.Empty.FromString,
response_serializer=hilserl__pb2.ActorInformation.SerializeToString,
),
'SendParameters': grpc.unary_unary_rpc_method_handler(
servicer.SendParameters,
request_deserializer=hilserl__pb2.Parameters.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'hil_serl.ActorService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('hil_serl.ActorService', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class ActorService(object):
"""ActorService: the Learner calls this to push parameters.
The Actor implements this service.
"""
@staticmethod
def StreamTransition(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(
request,
target,
'/hil_serl.ActorService/StreamTransition',
hilserl__pb2.Empty.SerializeToString,
hilserl__pb2.ActorInformation.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendParameters(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/hil_serl.ActorService/SendParameters',
hilserl__pb2.Parameters.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
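Since StreamTransition is a server-streaming RPC, an actor-side servicer would typically implement it as a generator. A hedged sketch of wiring it into a gRPC server (the queue-based design and port are assumptions, not the repository's implementation):

import queue
from concurrent import futures

import grpc

import hilserl_pb2
import hilserl_pb2_grpc

class ActorServiceImpl(hilserl_pb2_grpc.ActorServiceServicer):
    def __init__(self):
        self.out_queue = queue.Queue()

    def StreamTransition(self, request, context):
        # Yield ActorInformation messages to the learner as they become available.
        while True:
            yield self.out_queue.get()

    def SendParameters(self, request, context):
        # Deserialize and apply the new policy parameters here.
        return hilserl_pb2.Empty()

server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
hilserl_pb2_grpc.add_ActorServiceServicer_to_server(ActorServiceImpl(), server)
server.add_insecure_port("[::]:50052")
server.start()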

View File

@@ -1,676 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import pickle
import queue
import shutil
import time
from pprint import pformat
from threading import Lock, Thread
import grpc
# Import generated stubs
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore
import hydra
import torch
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import nn
from torch.optim.optimizer import Optimizer
from lerobot.common.datasets.factory import make_dataset
# TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.utils.utils import (
format_big_number,
get_global_random_state,
get_safe_torch_device,
init_hydra_config,
init_logging,
set_global_random_state,
set_global_seed,
)
from lerobot.scripts.server.buffer import (
ReplayBuffer,
concatenate_batch_transitions,
move_state_dict_to_device,
move_transition_to_device,
)
logging.basicConfig(level=logging.INFO)
transition_queue = queue.Queue()
interaction_message_queue = queue.Queue()
def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
if not cfg.resume:
if Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError(
f"Output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. "
"Use `resume=true` to resume training."
)
return cfg
# if resume == True
checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
if not checkpoint_dir.exists():
raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
logging.info(
colored(
"Resume=True detected, resuming previous run",
color="yellow",
attrs=["bold"],
)
)
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
if len(diff) > 0:
logging.warning(
f"Differences between the checkpoint config and the provided config detected: \n{pformat(diff)}\n"
"Checkpoint configuration takes precedence."
)
checkpoint_cfg.resume = True
return checkpoint_cfg
def load_training_state(
cfg: DictConfig,
logger: Logger,
optimizers: Optimizer | dict,
):
if not cfg.resume:
return None, None
training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
if isinstance(training_state["optimizer"], dict):
assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
for k, v in training_state["optimizer"].items():
optimizers[k].load_state_dict(v)
else:
optimizers.load_state_dict(training_state["optimizer"])
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
return training_state["step"], training_state["interaction_step"]
def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(out_dir)
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.training.online_steps=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> ReplayBuffer:
if not cfg.resume:
return ReplayBuffer(
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
)
dataset = LeRobotDataset(
repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
)
return ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=dataset,
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
)
def start_learner_threads(
cfg: DictConfig,
device: str,
replay_buffer: ReplayBuffer,
offline_replay_buffer: ReplayBuffer,
batch_size: int,
optimizers: dict,
policy: SACPolicy,
policy_lock: Lock,
logger: Logger,
resume_optimization_step: int | None = None,
resume_interaction_step: int | None = None,
) -> None:
actor_ip = cfg.actor_learner_config.actor_ip
port = cfg.actor_learner_config.port
server_thread = Thread(
target=stream_transitions_from_actor,
args=(
actor_ip,
port,
),
daemon=True,
)
transition_thread = Thread(
target=add_actor_information_and_train,
daemon=True,
args=(
cfg,
device,
replay_buffer,
offline_replay_buffer,
batch_size,
optimizers,
policy,
policy_lock,
logger,
resume_optimization_step,
resume_interaction_step,
),
)
param_push_thread = Thread(
target=learner_push_parameters,
args=(policy, policy_lock, actor_ip, port, 15),
daemon=True,
)
server_thread.start()
transition_thread.start()
param_push_thread.start()
param_push_thread.join()
transition_thread.join()
server_thread.join()
def stream_transitions_from_actor(host="127.0.0.1", port=50051):
"""
Runs a gRPC client that listens for transition and interaction messages from an Actor service.
This function establishes a gRPC connection with the given `host` and `port`, then continuously
streams transition data from the `ActorServiceStub`. The received transition data is deserialized
and stored in a queue (`transition_queue`). Similarly, interaction messages are also deserialized
and stored in a separate queue (`interaction_message_queue`).
Args:
host (str, optional): The IP address or hostname of the gRPC server. Defaults to `"127.0.0.1"`.
port (int, optional): The port number on which the gRPC server is running. Defaults to `50051`.
"""
# NOTE: This is waiting for the handshake to be done
# In the future we will do it in a canonical way with a proper handshake
time.sleep(10)
channel = grpc.insecure_channel(
f"{host}:{port}",
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
)
stub = hilserl_pb2_grpc.ActorServiceStub(channel)
for response in stub.StreamTransition(hilserl_pb2.Empty()):
if response.HasField("transition"):
buffer = io.BytesIO(response.transition.transition_bytes)
transition = torch.load(buffer)
transition_queue.put(transition)
if response.HasField("interaction_message"):
content = pickle.loads(response.interaction_message.interaction_message_bytes)
interaction_message_queue.put(content)
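# --- Illustrative sketch (not part of the original script) ---------------------------------
# The loop above deserializes transitions with torch.load on a BytesIO buffer, so the actor
# side is assumed to produce the bytes symmetrically with torch.save. The hypothetical helper
# below mirrors that serialization, shown only to document the wire format.
def _example_serialize_transitions(transition_list: list) -> bytes:
    buf = io.BytesIO()
    torch.save(transition_list, buf)  # a list of transition dicts, as consumed by the loop above
    return buf.getvalue()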
def learner_push_parameters(
policy: nn.Module, policy_lock: Lock, actor_host="127.0.0.1", actor_port=50052, seconds_between_pushes=5
):
"""
As a client, connect to the Actor's gRPC server (ActorService)
and periodically push new parameters.
"""
time.sleep(10)
channel = grpc.insecure_channel(
f"{actor_host}:{actor_port}",
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
)
actor_stub = hilserl_pb2_grpc.ActorServiceStub(channel)
while True:
with policy_lock:
params_dict = policy.actor.state_dict()
if policy.config.vision_encoder_name is not None:
if policy.config.freeze_vision_encoder:
params_dict: dict[str, torch.Tensor] = {
k: v for k, v in params_dict.items() if not k.startswith("encoder.")
}
else:
raise NotImplementedError(
"Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
)
params_dict = move_state_dict_to_device(params_dict, device="cpu")
# Serialize
buf = io.BytesIO()
torch.save(params_dict, buf)
params_bytes = buf.getvalue()
# Push them to the Actor's "SendParameters" method
logging.info("[LEARNER] Publishing parameters to the Actor")
response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes)) # noqa: F841
time.sleep(seconds_between_pushes)
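# --- Illustrative sketch (not part of the original script) ---------------------------------
# Assumed counterpart of the push above on the actor side: the received bytes are loaded with
# torch.load and applied to the local actor. strict=False is used here because the frozen
# encoder weights are filtered out of the state_dict before it is sent.
def _example_actor_load_parameters(policy: nn.Module, parameter_bytes: bytes) -> None:
    buf = io.BytesIO(parameter_bytes)
    params_dict = torch.load(buf)
    policy.actor.load_state_dict(params_dict, strict=False)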
def check_nan_in_transition(observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor):
for k in observations:
if torch.isnan(observations[k]).any():
logging.error(f"observations[{k}] contains NaN values")
for k in next_state:
if torch.isnan(next_state[k]).any():
logging.error(f"next_state[{k}] contains NaN values")
if torch.isnan(actions).any():
logging.error("actions contains NaN values")
def add_actor_information_and_train(
cfg,
device: str,
replay_buffer: ReplayBuffer,
offline_replay_buffer: ReplayBuffer,
batch_size: int,
optimizers: dict[str, torch.optim.Optimizer],
policy: nn.Module,
policy_lock: Lock,
logger: Logger,
resume_optimization_step: int | None = None,
resume_interaction_step: int | None = None,
):
"""
Handles data transfer from the actor to the learner, manages training updates,
and logs training progress in an online reinforcement learning setup.
This function continuously:
- Transfers transitions from the actor to the replay buffer.
- Logs received interaction messages.
- Ensures training begins only when the replay buffer has a sufficient number of transitions.
- Samples batches from the replay buffer and performs multiple critic updates.
- Periodically updates the actor, critic, and temperature optimizers.
- Logs training statistics, including loss values and optimization frequency.
**NOTE:**
- This function performs multiple responsibilities (data transfer, training, and logging).
It should ideally be split into smaller functions in the future.
- Due to Python's **Global Interpreter Lock (GIL)**, running separate threads for different tasks
significantly reduces performance. Instead, this function executes all operations in a single thread.
Args:
cfg: Configuration object containing hyperparameters.
device (str): The computing device (`"cpu"` or `"cuda"`).
replay_buffer (ReplayBuffer): The primary replay buffer storing online transitions.
offline_replay_buffer (ReplayBuffer): An additional buffer for offline transitions.
batch_size (int): The number of transitions to sample per training step.
optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`).
policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters.
policy_lock (Lock): A threading lock to ensure safe policy updates.
logger (Logger): Logger instance for tracking training progress.
resume_optimization_step (int | None): In the case of resume training, start from the last optimization step reached.
resume_interaction_step (int | None): In the case of resume training, shift the interaction step with the last saved step in order to not break logging.
"""
# NOTE: This function doesn't have a single responsibility; it should be split into multiple functions
# in the future. We keep it monolithic because of Python's GIL: running these tasks in separate threads
# slows performance dramatically (roughly 200x), so a single thread does all the work.
time.time()
logging.info("Starting learner thread")
interaction_message, transition = None, None
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
while True:
while not transition_queue.empty():
transition_list = transition_queue.get()
for transition in transition_list:
transition = move_transition_to_device(transition, device=device)
replay_buffer.add(**transition)
if transition.get("complementary_info", {}).get("is_intervention"):
offline_replay_buffer.add(**transition)
while not interaction_message_queue.empty():
interaction_message = interaction_message_queue.get()
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
interaction_message["Interaction step"] += interaction_step_shift
logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
# logging.info(f"Interaction message: {interaction_message}")
if len(replay_buffer) < cfg.training.online_step_before_learning:
continue
# logging.info(f"Size of replay buffer: {len(replay_buffer)}")
# logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
time_for_one_optimization_step = time.time()
for _ in range(cfg.policy.utd_ratio - 1):
batch = replay_buffer.sample(batch_size)
if cfg.dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(batch, batch_offline)
actions = batch["action"]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
# Precompute encoder features from the frozen vision encoder if enabled
obs_features, next_obs_features = None, None
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
with torch.no_grad():
obs_features = (
policy.actor.encoder(observations)
if policy.actor.encoder is not None
else None
)
next_obs_features = (
policy.actor.encoder(next_observations)
if policy.actor.encoder is not None
else None
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
obs_features=obs_features, # pass precomputed features
next_obs_features=next_obs_features, # for target computation
)
optimizers["critic"].zero_grad()
loss_critic.backward()
optimizers["critic"].step()
batch = replay_buffer.sample(batch_size)
if cfg.dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch["action"]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
# Precompute encoder features from the frozen vision encoder if enabled
obs_features, next_obs_features = None, None
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
with torch.no_grad():
obs_features = (
policy.actor.encoder(observations)
if policy.actor.encoder is not None
else None
)
next_obs_features = (
policy.actor.encoder(next_observations)
if policy.actor.encoder is not None
else None
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
obs_features=obs_features, # pass precomputed features
next_obs_features=next_obs_features, # for target computation
)
optimizers["critic"].zero_grad()
loss_critic.backward()
optimizers["critic"].step()
training_infos = {}
training_infos["loss_critic"] = loss_critic.item()
if optimization_step % cfg.training.policy_update_freq == 0:
for _ in range(cfg.training.policy_update_freq):
with policy_lock:
loss_actor = policy.compute_loss_actor(
observations=observations,
obs_features=obs_features, # reuse precomputed features here
)
optimizers["actor"].zero_grad()
loss_actor.backward()
optimizers["actor"].step()
training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(
observations=observations,
obs_features=obs_features, # and for temperature loss as well
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
training_infos["loss_temperature"] = loss_temperature.item()
policy.update_target_networks()
if optimization_step % cfg.training.log_freq == 0:
training_infos["Optimization step"] = optimization_step
logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
# logging.info(f"Training infos: {training_infos}")
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
logger.log_dict(
{
"Optimization frequency loop [Hz]": frequency_for_one_optimization_step,
"Optimization step": optimization_step,
},
mode="train",
custom_step_key="Optimization step",
)
optimization_step += 1
if optimization_step % cfg.training.log_freq == 0:
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
if cfg.training.save_checkpoint and (
optimization_step % cfg.training.save_freq == 0 or optimization_step == cfg.training.online_steps
):
logging.info(f"Checkpoint policy after step {optimization_step}")
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
# needed (choose 6 as a minimum for consistency without being overkill).
_num_digits = max(6, len(str(cfg.training.online_steps)))
step_identifier = f"{optimization_step:0{_num_digits}d}"
interaction_step = (
interaction_message["Interaction step"] if interaction_message is not None else 0
)
logger.save_checkpoint(
optimization_step,
policy,
optimizers,
scheduler=None,
identifier=step_identifier,
interaction_step=interaction_step,
)
# TODO: temporarily save the replay buffer here; remove this once running on the robot.
# We want to control this with keyboard inputs instead.
dataset_dir = logger.log_dir / "dataset"
if dataset_dir.exists() and dataset_dir.is_dir():
shutil.rmtree(
dataset_dir,
)
replay_buffer.to_lerobot_dataset(
cfg.dataset_repo_id, fps=cfg.fps, root=logger.log_dir / "dataset"
)
logging.info("Resume training")
def make_optimizers_and_scheduler(cfg, policy: nn.Module):
"""
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
This function sets up Adam optimizers for:
- The **actor network**, ensuring that only relevant parameters are optimized.
- The **critic ensemble**, which evaluates the value function.
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
It also initializes a learning rate scheduler, though currently, it is set to `None`.
**NOTE:**
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
Args:
cfg: Configuration object containing hyperparameters.
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
Returns:
Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
A tuple containing:
- `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
- `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
"""
optimizer_actor = torch.optim.Adam(
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
params=policy.actor.parameters_to_optimize,
lr=policy.config.actor_lr,
)
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
"critic": optimizer_critic,
"temperature": optimizer_temperature,
}
return optimizers, lr_scheduler
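# --- Illustrative sketch (not part of the original script) ---------------------------------
# The returned dictionary is keyed by component name, which is how the training loop indexes
# it (optimizers["critic"], optimizers["actor"], optimizers["temperature"]) and how
# load_training_state() above expects the checkpointed optimizer state to be organized.
def _example_optimizer_state_for_checkpoint(optimizers: dict[str, torch.optim.Optimizer]) -> dict:
    return {name: opt.state_dict() for name, opt in optimizers.items()}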
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
if out_dir is None:
raise NotImplementedError()
if job_name is None:
raise NotImplementedError()
init_logging()
logging.info(pformat(OmegaConf.to_container(cfg)))
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
cfg = handle_resume_logic(cfg, out_dir)
set_global_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("make_policy")
### Instantiate the policy in both the actor and learner processes.
### To avoid sending a SACPolicy object through the port, we create a policy instance
### on both sides; the learner sends the updated parameters every n steps to update the actor's parameters.
# TODO: At some point we should only need to call make_policy for SAC here
policy_lock = Lock()
policy: SACPolicy = make_policy(
hydra_cfg=cfg,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: if we do online training, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)
# compile policy
# policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)
log_training_info(cfg, out_dir, policy)
replay_buffer = initialize_replay_buffer(cfg, logger, device)
batch_size = cfg.training.batch_size
offline_replay_buffer = None
if cfg.dataset_repo_id is not None:
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
logging.info("Convertion to a offline replay buffer")
active_action_dims = [i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask]
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action,
)
batch_size: int = batch_size // 2  # We will sample from both replay buffers
start_learner_threads(
cfg,
device,
replay_buffer,
offline_replay_buffer,
batch_size,
optimizers,
policy,
policy_lock,
logger,
resume_optimization_step,
resume_interaction_step,
)
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def train_cli(cfg: dict):
train(
cfg,
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
)
if __name__ == "__main__":
train_cli()
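# --- Illustrative sketch (not part of the original script) ---------------------------------
# A minimal, assumed way to run the learner without the Hydra CLI entry point: compose the
# config manually and call train() directly. The config name and output paths below are
# placeholders, not values defined by this repository.
def _example_compose_and_train():
    with hydra.initialize(version_base="1.2", config_path="../../configs"):
        cfg = hydra.compose(config_name="default")
    train(cfg, out_dir="outputs/train/hilserl_learner_example", job_name="hilserl_learner_example")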


@@ -1,176 +0,0 @@
import einops
import numpy as np
import gymnasium as gym
import torch
"""Make ManiSkill3 gym environment"""
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
"""Convert environment observation to LeRobot format observation.
Args:
observation: Dictionary of observation batches from a Gym vector environment.
Returns:
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
"""
# map to expected inputs for the policy
return_observations = {}
# TODO: merge all tensors from the "agent" and "extra" keys,
# drop the sensor parameter keys from the observation,
# and keep only the RGB sensor data.
q_pos = observations["agent"]["qpos"]
q_vel = observations["agent"]["qvel"]
tcp_pos = observations["extra"]["tcp_pose"]
img = observations["sensor_data"]["base_camera"]["rgb"]
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
return_observations["observation.image"] = img
return_observations["observation.state"] = state
return return_observations
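# --- Illustrative sketch (not part of the original module) ---------------------------------
# A dummy call with made-up shapes showing the conversion: a batched ManiSkill observation
# dict goes in, and LeRobot-style "observation.image" (float32, channel-first, in [0, 1]) and
# "observation.state" (qpos, qvel and tcp_pose concatenated) tensors come out.
def _example_preprocess_call() -> dict[str, torch.Tensor]:
    observations = {
        "agent": {"qpos": torch.zeros(1, 9), "qvel": torch.zeros(1, 9)},
        "extra": {"tcp_pose": torch.zeros(1, 7)},
        "sensor_data": {"base_camera": {"rgb": torch.zeros(1, 64, 64, 3, dtype=torch.uint8)}},
    }
    return preprocess_maniskill_observation(observations)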
class ManiSkillObservationWrapper(gym.ObservationWrapper):
def __init__(self, env):
super().__init__(env)
def observation(self, observation):
return preprocess_maniskill_observation(observation)
class ManiSkillToDeviceWrapper(gym.Wrapper):
def __init__(self, env, device: torch.device = "cuda"):
super().__init__(env)
self.device = device
def reset(self, seed=None, options=None):
obs, info = self.env.reset(seed=seed, options=options)
obs = {k: v.to(self.device) for k, v in obs.items()}
return obs, info
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
obs = {k: v.to(self.device) for k, v in obs.items()}
return obs, reward, terminated, truncated, info
class ManiSkillCompat(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
reward = reward.item()
terminated = terminated.item()
truncated = truncated.item()
return obs, reward, terminated, truncated, info
class ManiSkillActionWrapper(gym.ActionWrapper):
def __init__(self, env):
super().__init__(env)
self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2)))
def action(self, action):
action, telop = action
return action
class ManiSkillMultiplyActionWrapper(gym.Wrapper):
def __init__(self, env, multiply_factor: float = 10):
super().__init__(env)
self.multiply_factor = multiply_factor
action_space_agent: gym.spaces.Box = env.action_space[0]
action_space_agent.low = action_space_agent.low * multiply_factor
action_space_agent.high = action_space_agent.high * multiply_factor
self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2)))
def step(self, action):
if isinstance(action, tuple):
action, telop = action
else:
telop = 0
action = action / self.multiply_factor
obs, reward, terminated, truncated, info = self.env.step((action, telop))
return obs, reward, terminated, truncated, info
def make_maniskill(
task: str = "PushCube-v1",
obs_mode: str = "rgb",
control_mode: str = "pd_ee_delta_pose",
render_mode: str = "rgb_array",
sensor_configs: dict[str, int] | None = None,
n_envs: int = 1,
device: torch.device = "cuda",
) -> gym.Env:
"""
Factory function to create a ManiSkill environment with standard wrappers.
Args:
task: Name of the ManiSkill task
obs_mode: Observation mode (rgb, rgbd, etc)
control_mode: Control mode for the robot
render_mode: Rendering mode
sensor_configs: Camera sensor configurations
n_envs: Number of parallel environments
device: Device to move observation tensors to
Returns:
A wrapped ManiSkill environment
"""
if sensor_configs is None:
sensor_configs = {"width": 64, "height": 64}
env = gym.make(
task,
obs_mode=obs_mode,
control_mode=control_mode,
render_mode=render_mode,
sensor_configs=sensor_configs,
num_envs=n_envs,
)
env = ManiSkillCompat(env)
env = ManiSkillObservationWrapper(env)
env = ManiSkillActionWrapper(env)
env = ManiSkillMultiplyActionWrapper(env)
env = ManiSkillToDeviceWrapper(env, device=device)
return env
if __name__ == "__main__":
import argparse
import hydra
from omegaconf import OmegaConf
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml")
args = parser.parse_args()
# Initialize config
with hydra.initialize(version_base=None, config_path="../../configs"):
cfg = hydra.compose(config_name="env/maniskill_example.yaml")
env = make_maniskill(
task=cfg.env.task,
obs_mode=cfg.env.obs,
control_mode=cfg.env.control_mode,
render_mode=cfg.env.render_mode,
sensor_configs={"width": cfg.env.render_size, "height": cfg.env.render_size},
)
print("env done")
obs, info = env.reset()
random_action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(random_action)


@@ -1,3 +1,5 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,7 +20,6 @@ from pathlib import Path
from pprint import pformat
import hydra
import numpy as np
import torch
import torch.nn as nn
import wandb
@@ -26,9 +27,8 @@ from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import optim
from torch.autograd import profiler
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from tqdm import tqdm
from lerobot.common.datasets.factory import resolve_delta_timestamps
@@ -43,7 +43,6 @@ from lerobot.common.utils.utils import (
init_hydra_config,
set_global_seed,
)
from lerobot.scripts.server.buffer import random_shift
def get_model(cfg, logger): # noqa I001
@@ -81,7 +80,6 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
for batch_idx, batch in enumerate(pbar):
start_time = time.perf_counter()
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
images = [random_shift(img, 4) for img in images]
labels = batch[cfg.training.label_key].float().to(device)
# Forward pass with optional AMP
@@ -118,7 +116,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{current_acc:.2f}%"})
def validate(model, val_loader, criterion, device, logger, cfg):
def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_log=8):
# Validation loop with metric tracking and sample logging
model.eval()
correct = 0
@@ -126,7 +124,6 @@ def validate(model, val_loader, criterion, device, logger, cfg):
batch_start_time = time.perf_counter()
samples = []
running_loss = 0
inference_times = []
with (
torch.no_grad(),
@@ -136,18 +133,7 @@ def validate(model, val_loader, criterion, device, logger, cfg):
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
labels = batch[cfg.training.label_key].float().to(device)
if cfg.training.profile_inference_time and logger._cfg.wandb.enable:
with (
profiler.profile(record_shapes=True) as prof,
profiler.record_function("model_inference"),
):
outputs = model(images)
inference_times.append(
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
)
else:
outputs = model(images)
outputs = model(images)
loss = criterion(outputs.logits, labels)
# Track metrics
@@ -160,18 +146,15 @@ def validate(model, val_loader, criterion, device, logger, cfg):
running_loss += loss.item()
# Log sample predictions for visualization
if len(samples) < cfg.eval.num_samples_to_log:
for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
if len(samples) < num_samples_to_log:
for i in range(min(num_samples_to_log - len(samples), len(images))):
if model.config.num_classes == 2:
confidence = round(outputs.probabilities[i].item(), 3)
else:
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
samples.append(
{
**{
f"image_{img_key}": wandb.Image(images[img_idx][i].cpu())
for img_idx, img_key in enumerate(cfg.training.image_keys)
},
"image": wandb.Image(images[i].cpu()),
"true_label": labels[i].item(),
"predicted": predictions[i].item(),
"confidence": confidence,
@@ -187,83 +170,16 @@ def validate(model, val_loader, criterion, device, logger, cfg):
"accuracy": accuracy,
"eval_s": time.perf_counter() - batch_start_time,
"eval/prediction_samples": wandb.Table(
data=[list(s.values()) for s in samples],
columns=list(samples[0].keys()),
data=[[s["image"], s["true_label"], s["predicted"], f"{s['confidence']}"] for s in samples],
columns=["Image", "True Label", "Predicted", "Confidence"],
)
if logger._cfg.wandb.enable
else None,
}
if len(inference_times) > 0:
eval_info["inference_time_avg"] = np.mean(inference_times)
eval_info["inference_time_median"] = np.median(inference_times)
eval_info["inference_time_std"] = np.std(inference_times)
eval_info["inference_time_batch_size"] = val_loader.batch_size
print(
f"Inference mean time: {eval_info['inference_time_avg']:.2f} us, median: {eval_info['inference_time_median']:.2f} us, std: {eval_info['inference_time_std']:.2f} us, with {len(inference_times)} iterations on {device.type} device, batch size: {eval_info['inference_time_batch_size']}"
)
return accuracy, eval_info
def benchmark_inference_time(model, dataset, logger, cfg, device, step):
if not cfg.training.profile_inference_time:
return
iters = cfg.training.profile_inference_time_iters
inference_times = []
loader = DataLoader(
dataset,
batch_size=1,
num_workers=cfg.training.num_workers,
sampler=RandomSampler(dataset),
pin_memory=True,
)
model.eval()
with torch.no_grad():
for _ in tqdm(range(iters), desc="Benchmarking inference time"):
x = next(iter(loader))
x = [x[img_key].to(device) for img_key in cfg.training.image_keys]
# Warm up
for _ in range(10):
_ = model(x)
# sync the device
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
_ = model(x)
inference_times.append(
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
)
inference_times = np.array(inference_times)
avg, median, std = inference_times.mean(), np.median(inference_times), inference_times.std()
print(
f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device"
)
if logger._cfg.wandb.enable:
logger.log_dict(
{
"inference_time_benchmark_avg": avg,
"inference_time_benchmark_median": median,
"inference_time_benchmark_std": std,
},
step + 1,
mode="eval",
)
return avg, median, std
@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
def train(cfg: DictConfig) -> None:
# Main training pipeline with support for resuming training
@@ -273,19 +189,17 @@ def train(cfg: DictConfig) -> None:
device = get_safe_torch_device(cfg.device, log=True)
set_global_seed(cfg.seed)
out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "frozen_resnet10_2"
out_dir = Path(cfg.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)
# Setup dataset and dataloaders
dataset = LeRobotDataset(
cfg.dataset_repo_id, root=cfg.dataset_root, local_files_only=cfg.local_files_only
)
dataset = LeRobotDataset(cfg.dataset_repo_id)
logging.info(f"Dataset size: {len(dataset)}")
n_total = len(dataset)
n_train = int(cfg.train_split_proportion * len(dataset))
train_dataset = torch.utils.data.Subset(dataset, range(0, n_train))
val_dataset = torch.utils.data.Subset(dataset, range(n_train, n_total))
train_size = int(cfg.train_split_proportion * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
sampler = create_balanced_sampler(train_dataset, cfg)
train_loader = DataLoader(
@@ -293,7 +207,7 @@ def train(cfg: DictConfig) -> None:
batch_size=cfg.training.batch_size,
num_workers=cfg.training.num_workers,
sampler=sampler,
pin_memory=device.type == "cuda",
pin_memory=True,
)
val_loader = DataLoader(
@@ -301,7 +215,7 @@ def train(cfg: DictConfig) -> None:
batch_size=cfg.eval.batch_size,
shuffle=False,
num_workers=cfg.training.num_workers,
pin_memory=device.type == "cuda",
pin_memory=True,
)
# Resume training if requested
@@ -399,8 +313,6 @@ def train(cfg: DictConfig) -> None:
step += len(train_loader)
benchmark_inference_time(model, dataset, logger, cfg, device, step)
logging.info("Training completed")


@@ -13,25 +13,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
import random
import functools
from pprint import pformat
from typing import Callable, Optional, Sequence, TypedDict
import random
from typing import Optional, Sequence, TypedDict, Callable
import hydra
import torch
import torch.nn.functional as F
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from torch import nn
from tqdm import tqdm
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
# TODO: Remove the import of maniskill
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.envs.factory import make_env, make_maniskill_env
from lerobot.common.envs.utils import preprocess_maniskill_observation, preprocess_observation
from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
@@ -176,7 +177,6 @@ class ReplayBuffer:
)
self.position: int = (self.position + 1) % self.capacity
# TODO: ADD image_augmentation and use_drq arguments in this function in order to instantiate the class with them
@classmethod
def from_lerobot_dataset(
cls,