forked from tangger/lerobot
Compare commits
215 Commits
refactor/u
...
user/adil-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
51ae84cd48 | ||
|
|
b5869709ae | ||
|
|
98ad1cbae0 | ||
|
|
539dbd18ce | ||
|
|
fcd9da52eb | ||
|
|
1722f6e504 | ||
|
|
2051dd38fc | ||
|
|
c7a3973653 | ||
|
|
bfa775da46 | ||
|
|
15c7545b41 | ||
|
|
bbc6b7d841 | ||
|
|
69ece1407b | ||
|
|
b104f8b012 | ||
|
|
fb9bb89cb4 | ||
|
|
e22411ff22 | ||
|
|
7ee56676b9 | ||
|
|
bdd9229576 | ||
|
|
633edcb3af | ||
|
|
d6d8317050 | ||
|
|
b36ec31fea | ||
|
|
adbf8bb85e | ||
|
|
32fb13c81e | ||
|
|
6792c3de8f | ||
|
|
ad132c9c39 | ||
|
|
70d55c77e9 | ||
|
|
5998203a33 | ||
|
|
6fa7df35df | ||
|
|
fb7c288c94 | ||
|
|
4257fe5045 | ||
|
|
ea89b29fe5 | ||
|
|
50e9a8ed6a | ||
|
|
1d4f660075 | ||
|
|
bd4db8d747 | ||
|
|
a8da4a347e | ||
|
|
b8c2b0bb93 | ||
|
|
c58b504a9e | ||
|
|
671ac3411f | ||
|
|
299effe0f1 | ||
|
|
a0018240d5 | ||
|
|
cf03ca930f | ||
|
|
ecc960bf8a | ||
|
|
b77cee7cc6 | ||
|
|
5231752487 | ||
|
|
4ce3362724 | ||
|
|
6230840397 | ||
|
|
c5845ee203 | ||
|
|
0030ff3f74 | ||
|
|
dc726cb9a3 | ||
|
|
a7a51cfc9c | ||
|
|
0d70f0b85c | ||
|
|
c1ee25d9f7 | ||
|
|
9886520d33 | ||
|
|
3b24ad3c84 | ||
|
|
54c3c6d684 | ||
|
|
fb92935601 | ||
|
|
dcd850feab | ||
|
|
1ce368503d | ||
|
|
fb075a709d | ||
|
|
3424644ecd | ||
|
|
c37936f2c9 | ||
|
|
c5382a450c | ||
|
|
2f7339b410 | ||
|
|
9e5f254db0 | ||
|
|
8122721f6d | ||
|
|
5c352ae558 | ||
|
|
9386892f8e | ||
|
|
267a837a2c | ||
|
|
28b595c651 | ||
|
|
9fd4c21d4d | ||
|
|
02e1ed0bfb | ||
|
|
e18274bc9a | ||
|
|
68c271ad25 | ||
|
|
a3ada81816 | ||
|
|
203315d378 | ||
|
|
78c640b6d8 | ||
|
|
d5a87f67cf | ||
|
|
8bcf41761d | ||
|
|
1efaf02df9 | ||
|
|
cf58890bb0 | ||
|
|
7c2c67fc3c | ||
|
|
70130b9841 | ||
|
|
6167886472 | ||
|
|
f9fb9d4594 | ||
|
|
d86d29fe21 | ||
|
|
f83d215e7a | ||
|
|
7361a11a4d | ||
|
|
0cce2fe0fa | ||
|
|
88d26ae976 | ||
|
|
3a2308d86f | ||
|
|
fdd04efdb7 | ||
|
|
ff18be18ad | ||
|
|
427720426b | ||
|
|
66693965c0 | ||
|
|
334cf8143e | ||
|
|
5b49601072 | ||
|
|
0185a0b6fd | ||
|
|
70d418935d | ||
|
|
eb44a06a9b | ||
|
|
8eb3c1510c | ||
|
|
4d5ecb082e | ||
|
|
6e687e2910 | ||
|
|
eb710647bf | ||
|
|
176557d770 | ||
|
|
3beab33fac | ||
|
|
c0ba4b4954 | ||
|
|
8fb373aeb2 | ||
|
|
5a0ee06651 | ||
|
|
05a237ce10 | ||
|
|
88cc2b8fc8 | ||
|
|
b69132c79d | ||
|
|
db897a1619 | ||
|
|
0b5b62c8fb | ||
|
|
056f79d358 | ||
|
|
114ec644d0 | ||
|
|
26ee8b6ae5 | ||
|
|
38e8864284 | ||
|
|
80d566eb56 | ||
|
|
bb5a95889f | ||
|
|
0ea27704f6 | ||
|
|
2abbd60a0d | ||
|
|
1c8daf11fd | ||
|
|
cdcf346061 | ||
|
|
42f95e827d | ||
|
|
618ed00d45 | ||
|
|
50d8db481e | ||
|
|
e4a5971ffd | ||
|
|
36f9ccd851 | ||
|
|
787aee0e60 | ||
|
|
0341a38fdd | ||
|
|
ffbed4a141 | ||
|
|
03fe0f054b | ||
|
|
fd74c194b6 | ||
|
|
0959694bab | ||
|
|
7b01e16439 | ||
|
|
66816fd871 | ||
|
|
599326508f | ||
|
|
2f04d0d2b9 | ||
|
|
e002c5ec56 | ||
|
|
3dfb37e976 | ||
|
|
b6a2200983 | ||
|
|
85fe8a3f4e | ||
|
|
bb69cb3c8c | ||
|
|
ae51c19b3c | ||
|
|
9ea79f8a76 | ||
|
|
1d4ec50a58 | ||
|
|
4c73891575 | ||
|
|
d3b84ecd6f | ||
|
|
e1d55c7a44 | ||
|
|
85242cac67 | ||
|
|
0d88a5ee09 | ||
|
|
62e237bdee | ||
|
|
c85f88fb62 | ||
|
|
a90f4872f2 | ||
|
|
a16ea283f5 | ||
|
|
8209a6dfb7 | ||
|
|
b5fbeb7401 | ||
|
|
2ac25b02e2 | ||
|
|
39fe4b1301 | ||
|
|
140e30e386 | ||
|
|
ddcc0415e4 | ||
|
|
5195f40fd3 | ||
|
|
98c6557869 | ||
|
|
ee820859d3 | ||
|
|
5d6879d93a | ||
|
|
fae47d58d3 | ||
|
|
3a07301365 | ||
|
|
f1af97dc9c | ||
|
|
f2266101df | ||
|
|
9784d8a47f | ||
|
|
af769abd8d | ||
|
|
12c13e320e | ||
|
|
273fa2e6e1 | ||
|
|
d143043037 | ||
|
|
ca45c34ad5 | ||
|
|
b1679050de | ||
|
|
d2c41b35db | ||
|
|
bc7b6d3daf | ||
|
|
2516101cba | ||
|
|
aebea08a99 | ||
|
|
03616db82c | ||
|
|
93c4fc198f | ||
|
|
8cd44ae163 | ||
|
|
2ae657f568 | ||
|
|
508f5d1407 | ||
|
|
c8b1132846 | ||
|
|
ef777993cd | ||
|
|
760d60ad4b | ||
|
|
875c0271b7 | ||
|
|
57344bfde5 | ||
|
|
46827fb002 | ||
|
|
2fd78879f6 | ||
|
|
e8449e9630 | ||
|
|
a0e2be8b92 | ||
|
|
181727c0fe | ||
|
|
d1d6ffd23c | ||
|
|
e5801f467f | ||
|
|
c6ca9523de | ||
|
|
642e3a3274 | ||
|
|
146148c48c | ||
|
|
8f15835daa | ||
|
|
022bd65125 | ||
|
|
63d8c96514 | ||
|
|
4624a836e5 | ||
|
|
ad7eea132d | ||
|
|
22a1899ff4 | ||
|
|
17a3a31b5f | ||
|
|
1a8b99e360 | ||
|
|
6db2154f28 | ||
|
|
be3adda95f | ||
|
|
9d48d236c1 | ||
|
|
b57d6a7776 | ||
|
|
d1f76cba8e | ||
|
|
d78cef1fee | ||
|
|
30a808c0ae | ||
|
|
4a7f85a6ec |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -26,6 +26,7 @@ outputs
|
||||
|
||||
# VS Code
|
||||
.vscode
|
||||
.devcontainer
|
||||
|
||||
# HPC
|
||||
nautilus/*.yaml
|
||||
|
||||
@@ -46,7 +46,8 @@ repos:
|
||||
rev: v3.19.1
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
|
||||
# Exclude generated protobuf files
|
||||
exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.11.5
|
||||
hooks:
|
||||
|
||||
13
README.md
13
README.md
@@ -408,6 +408,19 @@ Additionally, if you are using any of the particular policy architecture, pretra
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
- [HIL-SERL](https://hil-serl.github.io/)
|
||||
```bibtex
|
||||
@Article{luo2024hilserl,
|
||||
title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning},
|
||||
author={Jianlan Luo and Charles Xu and Jeffrey Wu and Sergey Levine},
|
||||
year={2024},
|
||||
eprint={2410.21845},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.RO}
|
||||
}
|
||||
```
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#huggingface/lerobot&Timeline)
|
||||
|
||||
@@ -14,10 +14,12 @@
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
|
||||
@@ -154,3 +156,122 @@ class XarmEnv(EnvConfig):
|
||||
"visualization_height": self.visualization_height,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoRecordConfig:
|
||||
"""Configuration for video recording in ManiSkill environments."""
|
||||
|
||||
enabled: bool = False
|
||||
record_dir: str = "videos"
|
||||
trajectory_name: str = "trajectory"
|
||||
|
||||
|
||||
@dataclass
|
||||
class EEActionSpaceConfig:
|
||||
"""Configuration parameters for end-effector action space."""
|
||||
|
||||
x_step_size: float
|
||||
y_step_size: float
|
||||
z_step_size: float
|
||||
bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds
|
||||
control_mode: str = "gamepad"
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvTransformConfig:
|
||||
"""Configuration for environment wrappers."""
|
||||
|
||||
ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
|
||||
display_cameras: bool = False
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_current_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
|
||||
resize_size: Optional[Tuple[int, int]] = None
|
||||
control_time_s: float = 20.0
|
||||
fixed_reset_joint_positions: Optional[Any] = None
|
||||
reset_time_s: float = 5.0
|
||||
use_gripper: bool = False
|
||||
gripper_quantization_threshold: float | None = 0.8
|
||||
gripper_penalty: float = 0.0
|
||||
gripper_penalty_in_reward: bool = False
|
||||
|
||||
|
||||
@EnvConfig.register_subclass(name="gym_manipulator")
|
||||
@dataclass
|
||||
class HILSerlRobotEnvConfig(EnvConfig):
|
||||
"""Configuration for the HILSerlRobotEnv environment."""
|
||||
|
||||
robot: Optional[RobotConfig] = None
|
||||
wrapper: Optional[EnvTransformConfig] = None
|
||||
fps: int = 10
|
||||
name: str = "real_robot"
|
||||
mode: str = None # Either "record", "replay", None
|
||||
repo_id: Optional[str] = None
|
||||
dataset_root: Optional[str] = None
|
||||
task: str = ""
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: Optional[str] = None
|
||||
reward_classifier_pretrained_path: Optional[str] = None
|
||||
# For the reward classifier, to record more positive examples after a success
|
||||
number_of_steps_after_success: int = 0
|
||||
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("hil")
|
||||
@dataclass
|
||||
class HILEnvConfig(EnvConfig):
|
||||
"""Configuration for the HIL environment."""
|
||||
|
||||
type: str = "hil"
|
||||
name: str = "PandaPickCube"
|
||||
task: str = "PandaPickCubeKeyboard-v0"
|
||||
use_viewer: bool = True
|
||||
gripper_penalty: float = 0.0
|
||||
use_gamepad: bool = True
|
||||
state_dim: int = 18
|
||||
action_dim: int = 4
|
||||
fps: int = 100
|
||||
episode_length: int = 100
|
||||
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"observation.image": OBS_IMAGE,
|
||||
"observation.state": OBS_ROBOT,
|
||||
}
|
||||
)
|
||||
################# args from hilserlrobotenv
|
||||
reward_classifier_pretrained_path: Optional[str] = None
|
||||
robot: Optional[RobotConfig] = None
|
||||
wrapper: Optional[EnvTransformConfig] = None
|
||||
mode: str = None # Either "record", "replay", None
|
||||
repo_id: Optional[str] = None
|
||||
dataset_root: Optional[str] = None
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: Optional[str] = None
|
||||
############################
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"use_viewer": self.use_viewer,
|
||||
"use_gamepad": self.use_gamepad,
|
||||
"gripper_penalty": self.gripper_penalty,
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
|
||||
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
|
||||
|
||||
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
@@ -27,6 +27,8 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
return PushtEnv(**kwargs)
|
||||
elif env_type == "xarm":
|
||||
return XarmEnv(**kwargs)
|
||||
elif env_type == "hil":
|
||||
return HILEnvConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
@@ -65,5 +67,8 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
|
||||
env = env_cls(
|
||||
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
|
||||
)
|
||||
# TODO: add observation processor wrapper and remove preprocess_observation in the codebase
|
||||
# https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/wrappers/vector/vectorize_observation.py#L19,
|
||||
# env = ObservationProcessorWrapper(env=env)
|
||||
|
||||
return env
|
||||
|
||||
@@ -47,6 +47,10 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||
img = torch.from_numpy(img)
|
||||
|
||||
# When preprocessing observations in a non-vectorized environment, we need to add a batch dimension.
|
||||
# This is the case for human-in-the-loop RL where there is only one environment.
|
||||
if img.ndim == 3:
|
||||
img = img.unsqueeze(0)
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
@@ -62,13 +66,18 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
return_observations[imgkey] = img
|
||||
|
||||
if "environment_state" in observations:
|
||||
return_observations["observation.environment_state"] = torch.from_numpy(
|
||||
observations["environment_state"]
|
||||
).float()
|
||||
env_state = torch.from_numpy(observations["environment_state"]).float()
|
||||
if env_state.dim() == 1:
|
||||
env_state = env_state.unsqueeze(0)
|
||||
|
||||
return_observations["observation.environment_state"] = env_state
|
||||
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||
# requirement for "agent_pos"
|
||||
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
|
||||
if agent_pos.dim() == 1:
|
||||
agent_pos = agent_pos.unsqueeze(0)
|
||||
return_observations["observation.state"] = agent_pos
|
||||
|
||||
return return_observations
|
||||
|
||||
|
||||
|
||||
@@ -14,8 +14,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import torch
|
||||
@@ -44,7 +45,16 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
return "adam"
|
||||
|
||||
@abc.abstractmethod
|
||||
def build(self) -> torch.optim.Optimizer:
|
||||
def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
|
||||
"""
|
||||
Build the optimizer. It can be a single optimizer or a dictionary of optimizers.
|
||||
NOTE: Multiple optimizers are useful when you have different models to optimize.
|
||||
For example, you can have one optimizer for the policy and another one for the value function
|
||||
in reinforcement learning settings.
|
||||
|
||||
Returns:
|
||||
The optimizer or a dictionary of optimizers.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -94,7 +104,76 @@ class SGDConfig(OptimizerConfig):
|
||||
return torch.optim.SGD(params, **kwargs)
|
||||
|
||||
|
||||
def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
|
||||
@OptimizerConfig.register_subclass("multi_adam")
|
||||
@dataclass
|
||||
class MultiAdamConfig(OptimizerConfig):
|
||||
"""Configuration for multiple Adam optimizers with different parameter groups.
|
||||
|
||||
This creates a dictionary of Adam optimizers, each with its own hyperparameters.
|
||||
|
||||
Args:
|
||||
lr: Default learning rate (used if not specified for a group)
|
||||
weight_decay: Default weight decay (used if not specified for a group)
|
||||
optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
|
||||
grad_clip_norm: Gradient clipping norm
|
||||
"""
|
||||
|
||||
lr: float = 1e-3
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
|
||||
def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
|
||||
"""Build multiple Adam optimizers.
|
||||
|
||||
Args:
|
||||
params_dict: Dictionary mapping parameter group names to lists of parameters
|
||||
The keys should match the keys in optimizer_groups
|
||||
|
||||
Returns:
|
||||
Dictionary mapping parameter group names to their optimizers
|
||||
"""
|
||||
optimizers = {}
|
||||
|
||||
for name, params in params_dict.items():
|
||||
# Get group-specific hyperparameters or use defaults
|
||||
group_config = self.optimizer_groups.get(name, {})
|
||||
|
||||
# Create optimizer with merged parameters (defaults + group-specific)
|
||||
optimizer_kwargs = {
|
||||
"lr": group_config.get("lr", self.lr),
|
||||
"betas": group_config.get("betas", (0.9, 0.999)),
|
||||
"eps": group_config.get("eps", 1e-5),
|
||||
"weight_decay": group_config.get("weight_decay", self.weight_decay),
|
||||
}
|
||||
|
||||
optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
|
||||
|
||||
return optimizers
|
||||
|
||||
|
||||
def save_optimizer_state(
|
||||
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
|
||||
) -> None:
|
||||
"""Save optimizer state to disk.
|
||||
|
||||
Args:
|
||||
optimizer: Either a single optimizer or a dictionary of optimizers.
|
||||
save_dir: Directory to save the optimizer state.
|
||||
"""
|
||||
if isinstance(optimizer, dict):
|
||||
# Handle dictionary of optimizers
|
||||
for name, opt in optimizer.items():
|
||||
optimizer_dir = save_dir / name
|
||||
optimizer_dir.mkdir(exist_ok=True, parents=True)
|
||||
_save_single_optimizer_state(opt, optimizer_dir)
|
||||
else:
|
||||
# Handle single optimizer
|
||||
_save_single_optimizer_state(optimizer, save_dir)
|
||||
|
||||
|
||||
def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
|
||||
"""Save a single optimizer's state to disk."""
|
||||
state = optimizer.state_dict()
|
||||
param_groups = state.pop("param_groups")
|
||||
flat_state = flatten_dict(state)
|
||||
@@ -102,11 +181,44 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> No
|
||||
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
|
||||
|
||||
|
||||
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
|
||||
def load_optimizer_state(
|
||||
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
|
||||
) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
|
||||
"""Load optimizer state from disk.
|
||||
|
||||
Args:
|
||||
optimizer: Either a single optimizer or a dictionary of optimizers.
|
||||
save_dir: Directory to load the optimizer state from.
|
||||
|
||||
Returns:
|
||||
The updated optimizer(s) with loaded state.
|
||||
"""
|
||||
if isinstance(optimizer, dict):
|
||||
# Handle dictionary of optimizers
|
||||
loaded_optimizers = {}
|
||||
for name, opt in optimizer.items():
|
||||
optimizer_dir = save_dir / name
|
||||
if optimizer_dir.exists():
|
||||
loaded_optimizers[name] = _load_single_optimizer_state(opt, optimizer_dir)
|
||||
else:
|
||||
loaded_optimizers[name] = opt
|
||||
return loaded_optimizers
|
||||
else:
|
||||
# Handle single optimizer
|
||||
return _load_single_optimizer_state(optimizer, save_dir)
|
||||
|
||||
|
||||
def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
|
||||
"""Load a single optimizer's state from disk."""
|
||||
current_state_dict = optimizer.state_dict()
|
||||
flat_state = load_file(save_dir / OPTIMIZER_STATE)
|
||||
state = unflatten_dict(flat_state)
|
||||
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
|
||||
|
||||
# Handle case where 'state' key might not exist (for newly created optimizers)
|
||||
if "state" in state:
|
||||
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
|
||||
else:
|
||||
loaded_state_dict = {"state": {}}
|
||||
|
||||
if "param_groups" in current_state_dict:
|
||||
param_groups = deserialize_json_into_object(
|
||||
|
||||
@@ -27,6 +27,7 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionC
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
@@ -59,6 +60,14 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
|
||||
|
||||
return PI0FASTPolicy
|
||||
elif name == "sac":
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
return SACPolicy
|
||||
elif name == "reward_classifier":
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
return Classifier
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
@@ -76,6 +85,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi0fast":
|
||||
return PI0FASTConfig(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
@@ -151,6 +151,7 @@ class Normalize(nn.Module):
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# TODO: Remove this shallow copy
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
@@ -252,3 +253,168 @@ class Unnormalize(nn.Module):
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
|
||||
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
|
||||
# and remove the `Normalize` and `Unnormalize` classes.
|
||||
def _initialize_stats_buffers(
|
||||
module: nn.Module,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
) -> None:
|
||||
"""Register statistics buffers (mean/std or min/max) on the given *module*.
|
||||
|
||||
The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`,
|
||||
but is factored out so it can be reused by both classes and stay in sync.
|
||||
"""
|
||||
for key, ft in features.items():
|
||||
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
shape: tuple[int, ...] = tuple(ft.shape)
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
# reduce spatial dimensions, keep channel dimension only
|
||||
c, *_ = shape
|
||||
shape = (c, 1, 1)
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
std = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
|
||||
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
|
||||
mean_data = stats[key]["mean"]
|
||||
std_data = stats[key]["std"]
|
||||
if isinstance(mean_data, torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
mean = mean_data.clone().to(dtype=torch.float32)
|
||||
std = std_data.clone().to(dtype=torch.float32)
|
||||
else:
|
||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||
|
||||
module.register_buffer(f"{prefix}_mean", mean)
|
||||
module.register_buffer(f"{prefix}_std", std)
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
max_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
|
||||
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
|
||||
min_data = stats[key]["min"]
|
||||
max_data = stats[key]["max"]
|
||||
if isinstance(min_data, torch.Tensor):
|
||||
min_val = min_data.clone().to(dtype=torch.float32)
|
||||
max_val = max_data.clone().to(dtype=torch.float32)
|
||||
else:
|
||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||
|
||||
module.register_buffer(f"{prefix}_min", min_val)
|
||||
module.register_buffer(f"{prefix}_max", max_val)
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
|
||||
class NormalizeBuffer(nn.Module):
|
||||
"""Same as `Normalize` but statistics are stored as registered buffers rather than parameters."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
|
||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch)
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = getattr(self, f"{prefix}_mean")
|
||||
std = getattr(self, f"{prefix}_std")
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = getattr(self, f"{prefix}_min")
|
||||
max_val = getattr(self, f"{prefix}_max")
|
||||
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class UnnormalizeBuffer(nn.Module):
|
||||
"""Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
|
||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# batch = dict(batch)
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = getattr(self, f"{prefix}_mean")
|
||||
std = getattr(self, f"{prefix}_std")
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = batch[key] * std + mean
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = getattr(self, f"{prefix}_min")
|
||||
max_val = getattr(self, f"{prefix}_max")
|
||||
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max_val - min_val) + min_val
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
return batch
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
|
||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass(name="reward_classifier")
|
||||
@dataclass
|
||||
class RewardClassifierConfig(PreTrainedConfig):
|
||||
"""Configuration for the Reward Classifier model."""
|
||||
|
||||
name: str = "reward_classifier"
|
||||
num_classes: int = 2
|
||||
hidden_dim: int = 256
|
||||
latent_dim: int = 256
|
||||
image_embedding_pooling_dim: int = 8
|
||||
dropout_rate: float = 0.1
|
||||
model_name: str = "helper2424/resnet10"
|
||||
device: str = "cpu"
|
||||
model_type: str = "cnn" # "transformer" or "cnn"
|
||||
num_cameras: int = 2
|
||||
learning_rate: float = 1e-4
|
||||
weight_decay: float = 0.01
|
||||
grad_clip_norm: float = 1.0
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> List | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> List | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> List | None:
|
||||
return None
|
||||
|
||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.learning_rate,
|
||||
weight_decay=self.weight_decay,
|
||||
grad_clip_norm=self.grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate feature configurations."""
|
||||
has_image = any(key.startswith("observation.image") for key in self.input_features)
|
||||
if not has_image:
|
||||
raise ValueError(
|
||||
"You must provide an image observation (key starting with 'observation.image') in the input features"
|
||||
)
|
||||
301
lerobot/common/policies/reward_model/modeling_classifier.py
Normal file
301
lerobot/common/policies/reward_model/modeling_classifier.py
Normal file
@@ -0,0 +1,301 @@
|
||||
import logging
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.constants import OBS_IMAGE
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
|
||||
class ClassifierOutput:
|
||||
"""Wrapper for classifier outputs with additional metadata."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logits: Tensor,
|
||||
probabilities: Optional[Tensor] = None,
|
||||
hidden_states: Optional[Tensor] = None,
|
||||
):
|
||||
self.logits = logits
|
||||
self.probabilities = probabilities
|
||||
self.hidden_states = hidden_states
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"ClassifierOutput(logits={self.logits}, "
|
||||
f"probabilities={self.probabilities}, "
|
||||
f"hidden_states={self.hidden_states})"
|
||||
)
|
||||
|
||||
|
||||
class SpatialLearnedEmbeddings(nn.Module):
|
||||
def __init__(self, height, width, channel, num_features=8):
|
||||
"""
|
||||
PyTorch implementation of learned spatial embeddings
|
||||
|
||||
Args:
|
||||
height: Spatial height of input features
|
||||
width: Spatial width of input features
|
||||
channel: Number of input channels
|
||||
num_features: Number of output embedding dimensions
|
||||
"""
|
||||
super().__init__()
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.channel = channel
|
||||
self.num_features = num_features
|
||||
|
||||
self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features))
|
||||
|
||||
nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear")
|
||||
|
||||
def forward(self, features):
|
||||
"""
|
||||
Forward pass for spatial embedding
|
||||
|
||||
Args:
|
||||
features: Input tensor of shape [B, H, W, C] or [H, W, C] if no batch
|
||||
Returns:
|
||||
Output tensor of shape [B, C*F] or [C*F] if no batch
|
||||
"""
|
||||
|
||||
features = features.last_hidden_state
|
||||
|
||||
original_shape = features.shape
|
||||
if features.dim() == 3:
|
||||
features = features.unsqueeze(0) # Add batch dim
|
||||
|
||||
features_expanded = features.unsqueeze(-1) # [B, H, W, C, 1]
|
||||
kernel_expanded = self.kernel.unsqueeze(0) # [1, H, W, C, F]
|
||||
|
||||
# Element-wise multiplication and spatial reduction
|
||||
output = (features_expanded * kernel_expanded).sum(dim=(2, 3)) # Sum H,W
|
||||
|
||||
# Reshape to combine channel and feature dimensions
|
||||
output = output.view(output.size(0), -1) # [B, C*F]
|
||||
|
||||
# Remove batch dim
|
||||
if len(original_shape) == 3:
|
||||
output = output.squeeze(0)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class Classifier(PreTrainedPolicy):
|
||||
"""Image classifier built on top of a pre-trained encoder."""
|
||||
|
||||
name = "reward_classifier"
|
||||
config_class = RewardClassifierConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RewardClassifierConfig,
|
||||
dataset_stats: Dict[str, Dict[str, Tensor]] | None = None,
|
||||
):
|
||||
from transformers import AutoModel
|
||||
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Initialize normalization (standardized with the policy framework)
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# Set up encoder
|
||||
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
# Extract vision model if we're given a multimodal model
|
||||
if hasattr(encoder, "vision_model"):
|
||||
logging.info("Multimodal model detected - using vision encoder only")
|
||||
self.encoder = encoder.vision_model
|
||||
self.vision_config = encoder.config.vision_config
|
||||
else:
|
||||
self.encoder = encoder
|
||||
self.vision_config = getattr(encoder, "config", None)
|
||||
|
||||
# Model type from config
|
||||
self.is_cnn = self.config.model_type == "cnn"
|
||||
|
||||
# For CNNs, initialize backbone
|
||||
if self.is_cnn:
|
||||
self._setup_cnn_backbone()
|
||||
|
||||
self._freeze_encoder()
|
||||
|
||||
# Extract image keys from input_features
|
||||
self.image_keys = [
|
||||
key.replace(".", "_") for key in config.input_features if key.startswith(OBS_IMAGE)
|
||||
]
|
||||
|
||||
if self.is_cnn:
|
||||
self.encoders = nn.ModuleDict()
|
||||
for image_key in self.image_keys:
|
||||
encoder = self._create_single_encoder()
|
||||
self.encoders[image_key] = encoder
|
||||
|
||||
self._build_classifier_head()
|
||||
|
||||
def _setup_cnn_backbone(self):
|
||||
"""Set up CNN encoder"""
|
||||
if hasattr(self.encoder, "fc"):
|
||||
self.feature_dim = self.encoder.fc.in_features
|
||||
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
|
||||
elif hasattr(self.encoder.config, "hidden_sizes"):
|
||||
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
||||
else:
|
||||
raise ValueError("Unsupported CNN architecture")
|
||||
|
||||
def _freeze_encoder(self) -> None:
|
||||
"""Freeze the encoder parameters."""
|
||||
for param in self.encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def _create_single_encoder(self):
|
||||
encoder = nn.Sequential(
|
||||
self.encoder,
|
||||
SpatialLearnedEmbeddings(
|
||||
height=4,
|
||||
width=4,
|
||||
channel=self.feature_dim,
|
||||
num_features=self.config.image_embedding_pooling_dim,
|
||||
),
|
||||
nn.Dropout(self.config.dropout_rate),
|
||||
nn.Linear(self.feature_dim * self.config.image_embedding_pooling_dim, self.config.latent_dim),
|
||||
nn.LayerNorm(self.config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
return encoder
|
||||
|
||||
def _build_classifier_head(self) -> None:
|
||||
"""Initialize the classifier head architecture."""
|
||||
# Get input dimension based on model type
|
||||
if self.is_cnn:
|
||||
input_dim = self.config.latent_dim
|
||||
else: # Transformer models
|
||||
if hasattr(self.encoder.config, "hidden_size"):
|
||||
input_dim = self.encoder.config.hidden_size
|
||||
else:
|
||||
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
|
||||
|
||||
self.classifier_head = nn.Sequential(
|
||||
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
|
||||
nn.Dropout(self.config.dropout_rate),
|
||||
nn.LayerNorm(self.config.hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(
|
||||
self.config.hidden_dim,
|
||||
1 if self.config.num_classes == 2 else self.config.num_classes,
|
||||
),
|
||||
)
|
||||
|
||||
def _get_encoder_output(self, x: torch.Tensor, image_key: str) -> torch.Tensor:
|
||||
"""Extract the appropriate output from the encoder."""
|
||||
with torch.no_grad():
|
||||
if self.is_cnn:
|
||||
# The HF ResNet applies pooling internally
|
||||
outputs = self.encoders[image_key](x)
|
||||
return outputs
|
||||
else: # Transformer models
|
||||
outputs = self.encoder(x)
|
||||
return outputs.last_hidden_state[:, 0, :]
|
||||
|
||||
def extract_images_and_labels(self, batch: Dict[str, Tensor]) -> Tuple[list, Tensor]:
|
||||
"""Extract image tensors and label tensors from batch."""
|
||||
# Check for both OBS_IMAGE and OBS_IMAGES prefixes
|
||||
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
|
||||
labels = batch["next.reward"]
|
||||
|
||||
return images, labels
|
||||
|
||||
def predict(self, xs: list) -> ClassifierOutput:
|
||||
"""Forward pass of the classifier for inference."""
|
||||
encoder_outputs = torch.hstack(
|
||||
[self._get_encoder_output(x, img_key) for x, img_key in zip(xs, self.image_keys, strict=True)]
|
||||
)
|
||||
logits = self.classifier_head(encoder_outputs)
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
logits = logits.squeeze(-1)
|
||||
probabilities = torch.sigmoid(logits)
|
||||
else:
|
||||
probabilities = torch.softmax(logits, dim=-1)
|
||||
|
||||
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
|
||||
|
||||
def forward(self, batch: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
|
||||
"""Standard forward pass for training compatible with train.py."""
|
||||
# Normalize inputs if needed
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Extract images and labels
|
||||
images, labels = self.extract_images_and_labels(batch)
|
||||
|
||||
# Get predictions
|
||||
outputs = self.predict(images)
|
||||
|
||||
# Calculate loss
|
||||
if self.config.num_classes == 2:
|
||||
# Binary classification
|
||||
loss = nn.functional.binary_cross_entropy_with_logits(outputs.logits, labels)
|
||||
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
|
||||
else:
|
||||
# Multi-class classification
|
||||
loss = nn.functional.cross_entropy(outputs.logits, labels.long())
|
||||
predictions = torch.argmax(outputs.logits, dim=1)
|
||||
|
||||
# Calculate accuracy for logging
|
||||
correct = (predictions == labels).sum().item()
|
||||
total = labels.size(0)
|
||||
accuracy = 100 * correct / total
|
||||
|
||||
# Return loss and metrics for logging
|
||||
output_dict = {
|
||||
"accuracy": accuracy,
|
||||
"correct": correct,
|
||||
"total": total,
|
||||
}
|
||||
|
||||
return loss, output_dict
|
||||
|
||||
def predict_reward(self, batch, threshold=0.5):
|
||||
"""Eval method. Returns predicted reward with the decision threshold as argument."""
|
||||
# Check for both OBS_IMAGE and OBS_IMAGES prefixes
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Extract images from batch dict
|
||||
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
probs = self.predict(images).probabilities
|
||||
logging.debug(f"Predicted reward images: {probs}")
|
||||
return (probs > threshold).float()
|
||||
else:
|
||||
return torch.argmax(self.predict(images).probabilities, dim=1)
|
||||
|
||||
def get_optim_params(self):
|
||||
"""Return optimizer parameters for the policy."""
|
||||
return self.parameters()
|
||||
|
||||
def select_action(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for reward classifiers.
|
||||
The reward classifier is not an actor and does not select actions.
|
||||
"""
|
||||
raise NotImplementedError("Reward classifiers do not select actions")
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for reward classifiers.
|
||||
The reward classifier is not an actor and does not select actions.
|
||||
"""
|
||||
pass
|
||||
243
lerobot/common/policies/sac/configuration_sac.py
Normal file
243
lerobot/common/policies/sac/configuration_sac.py
Normal file
@@ -0,0 +1,243 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import MultiAdamConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
def is_image_feature(key: str) -> bool:
|
||||
"""Check if a feature key represents an image feature.
|
||||
|
||||
Args:
|
||||
key: The feature key to check
|
||||
|
||||
Returns:
|
||||
True if the key represents an image feature, False otherwise
|
||||
"""
|
||||
return key.startswith("observation.image")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConcurrencyConfig:
|
||||
"""Configuration for the concurrency of the actor and learner.
|
||||
Possible values are:
|
||||
- "threads": Use threads for the actor and learner.
|
||||
- "processes": Use processes for the actor and learner.
|
||||
"""
|
||||
|
||||
actor: str = "threads"
|
||||
learner: str = "threads"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorLearnerConfig:
|
||||
learner_host: str = "127.0.0.1"
|
||||
learner_port: int = 50051
|
||||
policy_parameters_push_frequency: int = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class CriticNetworkConfig:
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
activate_final: bool = True
|
||||
final_activation: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorNetworkConfig:
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
activate_final: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyConfig:
|
||||
use_tanh_squash: bool = True
|
||||
log_std_min: float = 1e-5
|
||||
log_std_max: float = 10.0
|
||||
init_final: float = 0.05
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("sac")
|
||||
@dataclass
|
||||
class SACConfig(PreTrainedConfig):
|
||||
"""Soft Actor-Critic (SAC) configuration.
|
||||
|
||||
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
|
||||
reinforcement learning framework. It learns a policy and a Q-function simultaneously
|
||||
using experience collected from the environment.
|
||||
|
||||
This configuration class contains all the parameters needed to define a SAC agent,
|
||||
including network architectures, optimization settings, and algorithm-specific
|
||||
hyperparameters.
|
||||
"""
|
||||
|
||||
# Mapping of feature types to normalization modes
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ENV": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
# Statistics for normalizing different types of inputs
|
||||
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
"observation.state": {
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
"action": {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Architecture specifics
|
||||
# Device to run the model on (e.g., "cuda", "cpu")
|
||||
device: str = "cpu"
|
||||
# Device to store the model on
|
||||
storage_device: str = "cpu"
|
||||
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
|
||||
vision_encoder_name: str | None = None
|
||||
# Whether to freeze the vision encoder during training
|
||||
freeze_vision_encoder: bool = True
|
||||
# Hidden dimension size for the image encoder
|
||||
image_encoder_hidden_dim: int = 32
|
||||
# Whether to use a shared encoder for actor and critic
|
||||
shared_encoder: bool = True
|
||||
# Number of discrete actions, eg for gripper actions
|
||||
num_discrete_actions: int | None = None
|
||||
# Dimension of the image embedding pooling
|
||||
image_embedding_pooling_dim: int = 8
|
||||
|
||||
# Training parameter
|
||||
# Number of steps for online training
|
||||
online_steps: int = 1000000
|
||||
# Seed for the online environment
|
||||
online_env_seed: int = 10000
|
||||
# Capacity of the online replay buffer
|
||||
online_buffer_capacity: int = 100000
|
||||
# Capacity of the offline replay buffer
|
||||
offline_buffer_capacity: int = 100000
|
||||
# Whether to use asynchronous prefetching for the buffers
|
||||
async_prefetch: bool = False
|
||||
# Number of steps before learning starts
|
||||
online_step_before_learning: int = 100
|
||||
# Frequency of policy updates
|
||||
policy_update_freq: int = 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
# Discount factor for the SAC algorithm
|
||||
discount: float = 0.99
|
||||
# Initial temperature value
|
||||
temperature_init: float = 1.0
|
||||
# Number of critics in the ensemble
|
||||
num_critics: int = 2
|
||||
# Number of subsampled critics for training
|
||||
num_subsample_critics: int | None = None
|
||||
# Learning rate for the critic network
|
||||
critic_lr: float = 3e-4
|
||||
# Learning rate for the actor network
|
||||
actor_lr: float = 3e-4
|
||||
# Learning rate for the temperature parameter
|
||||
temperature_lr: float = 3e-4
|
||||
# Weight for the critic target update
|
||||
critic_target_update_weight: float = 0.005
|
||||
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
|
||||
utd_ratio: int = 1
|
||||
# Hidden dimension size for the state encoder
|
||||
state_encoder_hidden_dim: int = 256
|
||||
# Dimension of the latent space
|
||||
latent_dim: int = 256
|
||||
# Target entropy for the SAC algorithm
|
||||
target_entropy: float | None = None
|
||||
# Whether to use backup entropy for the SAC algorithm
|
||||
use_backup_entropy: bool = True
|
||||
# Gradient clipping norm for the SAC algorithm
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
# Network configuration
|
||||
# Configuration for the critic network architecture
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for the actor network architecture
|
||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||
# Configuration for the policy parameters
|
||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||
# Configuration for the discrete critic network
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for actor-learner architecture
|
||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
|
||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||
|
||||
# Optimizations
|
||||
use_torch_compile: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Any validation specific to SAC configuration
|
||||
|
||||
def get_optimizer_preset(self) -> MultiAdamConfig:
|
||||
return MultiAdamConfig(
|
||||
weight_decay=0.0,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": self.actor_lr},
|
||||
"critic": {"lr": self.critic_lr},
|
||||
"temperature": {"lr": self.temperature_lr},
|
||||
},
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
has_image = any(is_image_feature(key) for key in self.input_features)
|
||||
has_state = "observation.state" in self.input_features
|
||||
|
||||
if not (has_state or has_image):
|
||||
raise ValueError(
|
||||
"You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
|
||||
)
|
||||
|
||||
if "action" not in self.output_features:
|
||||
raise ValueError("You must provide 'action' in the output features")
|
||||
|
||||
@property
|
||||
def image_features(self) -> list[str]:
|
||||
return [key for key in self.input_features if is_image_feature(key)]
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return None # SAC typically predicts one action at a time
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
1120
lerobot/common/policies/sac/modeling_sac.py
Normal file
1120
lerobot/common/policies/sac/modeling_sac.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -87,6 +87,8 @@ class RecordControlConfig(ControlConfig):
|
||||
play_sounds: bool = True
|
||||
# Resume recording on an existing dataset.
|
||||
resume: bool = False
|
||||
# Reset follower arms to an initial position.
|
||||
reset_follower_arms: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
|
||||
@@ -24,6 +24,7 @@ from contextlib import nullcontext
|
||||
from copy import copy
|
||||
from functools import cache
|
||||
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
@@ -129,9 +130,11 @@ def predict_action(observation, policy, device, use_amp):
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
"""
|
||||
Initializes a keyboard listener to enable early termination of an episode
|
||||
or environment reset by pressing the right arrow key ('->'). This may require
|
||||
sudo permissions to allow the terminal to monitor keyboard events.
|
||||
"""
|
||||
events = {}
|
||||
events["exit_early"] = False
|
||||
events["rerecord_episode"] = False
|
||||
@@ -160,6 +163,7 @@ def init_keyboard_listener():
|
||||
print("Escape key pressed. Stopping data recording...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
@@ -253,7 +257,10 @@ def control_loop(
|
||||
|
||||
if policy is not None:
|
||||
pred_action = predict_action(
|
||||
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
||||
observation,
|
||||
policy,
|
||||
get_safe_torch_device(policy.config.device),
|
||||
policy.config.use_amp,
|
||||
)
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset.
|
||||
@@ -301,7 +308,17 @@ def reset_environment(robot, events, reset_time_s, fps):
|
||||
)
|
||||
|
||||
|
||||
def stop_recording(robot, listener, display_data):
|
||||
def reset_follower_position(robot_arm, target_position):
|
||||
current_position = robot_arm.read("Present_Position")
|
||||
trajectory = torch.from_numpy(
|
||||
np.linspace(current_position, target_position, 50)
|
||||
) # NOTE: 30 is just an arbitrary number
|
||||
for pose in trajectory:
|
||||
robot_arm.write("Goal_Position", pose)
|
||||
busy_wait(0.015)
|
||||
|
||||
|
||||
def stop_recording(robot, listener, display_cameras):
|
||||
robot.disconnect()
|
||||
|
||||
if not is_headless() and listener is not None:
|
||||
@@ -327,12 +344,20 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
|
||||
|
||||
def sanity_check_dataset_robot_compatibility(
|
||||
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
|
||||
dataset: LeRobotDataset,
|
||||
robot: Robot,
|
||||
fps: int,
|
||||
use_videos: bool,
|
||||
extra_features: dict = None,
|
||||
) -> None:
|
||||
features_from_robot = get_features_from_robot(robot, use_videos)
|
||||
if extra_features is not None:
|
||||
features_from_robot.update(extra_features)
|
||||
|
||||
fields = [
|
||||
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
||||
("fps", dataset.fps, fps),
|
||||
("features", dataset.features, get_features_from_robot(robot, use_videos)),
|
||||
("features", dataset.features, features_from_robot),
|
||||
]
|
||||
|
||||
mismatches = []
|
||||
|
||||
@@ -18,9 +18,11 @@ import os
|
||||
import os.path as osp
|
||||
import platform
|
||||
import subprocess
|
||||
from copy import copy
|
||||
import time
|
||||
from copy import copy, deepcopy
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from statistics import mean
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -107,11 +109,17 @@ def is_amp_available(device: str):
|
||||
raise ValueError(f"Unknown device '{device}.")
|
||||
|
||||
|
||||
def init_logging():
|
||||
def init_logging(log_file: Path | None = None, display_pid: bool = False):
|
||||
def custom_format(record):
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
fnameline = f"{record.pathname}:{record.lineno}"
|
||||
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
|
||||
# NOTE: Display PID is useful for multi-process logging.
|
||||
if display_pid:
|
||||
pid_str = f"[PID: {os.getpid()}]"
|
||||
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
else:
|
||||
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
return message
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -125,6 +133,12 @@ def init_logging():
|
||||
console_handler.setFormatter(formatter)
|
||||
logging.getLogger().addHandler(console_handler)
|
||||
|
||||
if log_file is not None:
|
||||
# Additionally write logs to file
|
||||
file_handler = logging.FileHandler(log_file)
|
||||
file_handler.setFormatter(formatter)
|
||||
logging.getLogger().addHandler(file_handler)
|
||||
|
||||
|
||||
def format_big_number(num, precision=0):
|
||||
suffixes = ["", "K", "M", "B", "T", "Q"]
|
||||
@@ -228,3 +242,114 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
|
||||
except TypeError:
|
||||
# If a TypeError is raised, the string is not a valid dtype
|
||||
return False
|
||||
|
||||
|
||||
class TimerManager:
|
||||
"""
|
||||
Lightweight utility to measure elapsed time.
|
||||
|
||||
Examples
|
||||
--------
|
||||
```python
|
||||
# Example 1: Using context manager
|
||||
timer = TimerManager("Policy", log=False)
|
||||
for _ in range(3):
|
||||
with timer:
|
||||
time.sleep(0.01)
|
||||
print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01
|
||||
```
|
||||
|
||||
```python
|
||||
# Example 2: Using start/stop methods
|
||||
timer = TimerManager("Policy", log=False)
|
||||
timer.start()
|
||||
time.sleep(0.01)
|
||||
timer.stop()
|
||||
print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
label: str = "Elapsed-time",
|
||||
log: bool = True,
|
||||
logger: logging.Logger | None = None,
|
||||
):
|
||||
self.label = label
|
||||
self.log = log
|
||||
self.logger = logger
|
||||
self._start: float | None = None
|
||||
self._history: list[float] = []
|
||||
|
||||
def __enter__(self):
|
||||
return self.start()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.stop()
|
||||
|
||||
def start(self):
|
||||
self._start = time.perf_counter()
|
||||
return self
|
||||
|
||||
def stop(self) -> float:
|
||||
if self._start is None:
|
||||
raise RuntimeError("Timer was never started.")
|
||||
elapsed = time.perf_counter() - self._start
|
||||
self._history.append(elapsed)
|
||||
self._start = None
|
||||
if self.log:
|
||||
if self.logger is not None:
|
||||
self.logger.info(f"{self.label}: {elapsed:.6f} s")
|
||||
else:
|
||||
logging.info(f"{self.label}: {elapsed:.6f} s")
|
||||
return elapsed
|
||||
|
||||
def reset(self):
|
||||
self._history.clear()
|
||||
|
||||
@property
|
||||
def last(self) -> float:
|
||||
return self._history[-1] if self._history else 0.0
|
||||
|
||||
@property
|
||||
def avg(self) -> float:
|
||||
return mean(self._history) if self._history else 0.0
|
||||
|
||||
@property
|
||||
def total(self) -> float:
|
||||
return sum(self._history)
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
return len(self._history)
|
||||
|
||||
@property
|
||||
def history(self) -> list[float]:
|
||||
return deepcopy(self._history)
|
||||
|
||||
@property
|
||||
def fps_history(self) -> list[float]:
|
||||
return [1.0 / t for t in self._history]
|
||||
|
||||
@property
|
||||
def fps_last(self) -> float:
|
||||
return 0.0 if self.last == 0 else 1.0 / self.last
|
||||
|
||||
@property
|
||||
def fps_avg(self) -> float:
|
||||
return 0.0 if self.avg == 0 else 1.0 / self.avg
|
||||
|
||||
def percentile(self, p: float) -> float:
|
||||
"""
|
||||
Return the p-th percentile of recorded times.
|
||||
"""
|
||||
if not self._history:
|
||||
return 0.0
|
||||
return float(np.percentile(self._history, p))
|
||||
|
||||
def fps_percentile(self, p: float) -> float:
|
||||
"""
|
||||
FPS corresponding to the p-th percentile time.
|
||||
"""
|
||||
val = self.percentile(p)
|
||||
return 0.0 if val == 0 else 1.0 / val
|
||||
|
||||
@@ -30,9 +30,10 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st
|
||||
"""Return a group name for logging. Optionally returns group name as list."""
|
||||
lst = [
|
||||
f"policy:{cfg.policy.type}",
|
||||
f"dataset:{cfg.dataset.repo_id}",
|
||||
f"seed:{cfg.seed}",
|
||||
]
|
||||
if cfg.dataset is not None:
|
||||
lst.append(f"dataset:{cfg.dataset.repo_id}")
|
||||
if cfg.env is not None:
|
||||
lst.append(f"env:{cfg.env.type}")
|
||||
return lst if return_list else "-".join(lst)
|
||||
@@ -92,6 +93,12 @@ class WandBLogger:
|
||||
resume="must" if cfg.resume else None,
|
||||
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
|
||||
)
|
||||
run_id = wandb.run.id
|
||||
# NOTE: We will override the cfg.wandb.run_id with the wandb run id.
|
||||
# This is because we want to be able to resume the run from the wandb run id.
|
||||
cfg.wandb.run_id = run_id
|
||||
# Handle custom step key for rl asynchronous training.
|
||||
self._wandb_custom_step_key: set[str] | None = None
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
@@ -108,9 +115,26 @@ class WandBLogger:
|
||||
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
|
||||
self._wandb.log_artifact(artifact)
|
||||
|
||||
def log_dict(self, d: dict, step: int, mode: str = "train"):
|
||||
def log_dict(
|
||||
self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None
|
||||
):
|
||||
if mode not in {"train", "eval"}:
|
||||
raise ValueError(mode)
|
||||
if step is None and custom_step_key is None:
|
||||
raise ValueError("Either step or custom_step_key must be provided.")
|
||||
|
||||
# NOTE: This is not simple. Wandb step must always monotonically increase and it
|
||||
# increases with each wandb.log call, but in the case of asynchronous RL for example,
|
||||
# multiple time steps is possible. For example, the interaction step with the environment,
|
||||
# the training step, the evaluation step, etc. So we need to define a custom step key
|
||||
# to log the correct step for each metric.
|
||||
if custom_step_key is not None:
|
||||
if self._wandb_custom_step_key is None:
|
||||
self._wandb_custom_step_key = set()
|
||||
new_custom_key = f"{mode}/{custom_step_key}"
|
||||
if new_custom_key not in self._wandb_custom_step_key:
|
||||
self._wandb_custom_step_key.add(new_custom_key)
|
||||
self._wandb.define_metric(new_custom_key, hidden=True)
|
||||
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str)):
|
||||
@@ -118,7 +142,18 @@ class WandBLogger:
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
# Do not log the custom step key itself.
|
||||
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
|
||||
continue
|
||||
|
||||
if custom_step_key is not None:
|
||||
value_custom_step = d[custom_step_key]
|
||||
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
|
||||
self._wandb.log(data)
|
||||
continue
|
||||
|
||||
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
if mode not in {"train", "eval"}:
|
||||
|
||||
@@ -34,11 +34,10 @@ TRAIN_CONFIG_NAME = "train_config.json"
|
||||
|
||||
@dataclass
|
||||
class TrainPipelineConfig(HubMixin):
|
||||
dataset: DatasetConfig
|
||||
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset
|
||||
env: envs.EnvConfig | None = None
|
||||
policy: PreTrainedConfig | None = None
|
||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
||||
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||
output_dir: Path | None = None
|
||||
job_name: str | None = None
|
||||
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
|
||||
@@ -107,7 +106,7 @@ class TrainPipelineConfig(HubMixin):
|
||||
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
|
||||
self.output_dir = Path("outputs/train") / train_dir
|
||||
|
||||
if isinstance(self.dataset.repo_id, list):
|
||||
if self.dataset is not None and isinstance(self.dataset.repo_id, list):
|
||||
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
|
||||
|
||||
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
||||
|
||||
@@ -23,6 +23,7 @@ class FeatureType(str, Enum):
|
||||
VISUAL = "VISUAL"
|
||||
ENV = "ENV"
|
||||
ACTION = "ACTION"
|
||||
REWARD = "REWARD"
|
||||
|
||||
|
||||
class NormalizationMode(str, Enum):
|
||||
|
||||
722
lerobot/scripts/server/actor_server.py
Normal file
722
lerobot/scripts/server/actor_server.py
Normal file
@@ -0,0 +1,722 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Actor server runner for distributed HILSerl robot policy training.
|
||||
|
||||
This script implements the actor component of the distributed HILSerl architecture.
|
||||
It executes the policy in the robot environment, collects experience,
|
||||
and sends transitions to the learner server for policy updates.
|
||||
|
||||
Examples of usage:
|
||||
|
||||
- Start an actor server for real robot training with human-in-the-loop intervention:
|
||||
```bash
|
||||
python lerobot/scripts/server/actor_server.py --config_path lerobot/configs/train_config_hilserl_so100.json
|
||||
```
|
||||
|
||||
- Run with a specific robot type for a pick and place task:
|
||||
```bash
|
||||
python lerobot/scripts/server/actor_server.py \
|
||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||
--robot.type=so100 \
|
||||
--task=pick_and_place
|
||||
```
|
||||
|
||||
- Set a custom workspace bound for the robot's end-effector:
|
||||
```bash
|
||||
python lerobot/scripts/server/actor_server.py \
|
||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||
--env.ee_action_space_params.bounds.max="[0.24, 0.20, 0.10]" \
|
||||
--env.ee_action_space_params.bounds.min="[0.16, -0.08, 0.03]"
|
||||
```
|
||||
|
||||
- Run with specific camera crop parameters:
|
||||
```bash
|
||||
python lerobot/scripts/server/actor_server.py \
|
||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||
--env.crop_params_dict="{'observation.images.side': [180, 207, 180, 200], 'observation.images.front': [180, 250, 120, 150]}"
|
||||
```
|
||||
|
||||
**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner
|
||||
server is started before launching the actor.
|
||||
|
||||
**NOTE**: Human intervention is key to HILSerl training. Press the upper right trigger button on the
|
||||
gamepad to take control of the robot during training. Initially intervene frequently, then gradually
|
||||
reduce interventions as the policy improves.
|
||||
|
||||
**WORKFLOW**:
|
||||
1. Determine robot workspace bounds using `find_joint_limits.py`
|
||||
2. Record demonstrations with `gym_manipulator.py` in record mode
|
||||
3. Process the dataset and determine camera crops with `crop_dataset_roi.py`
|
||||
4. Start the learner server with the training configuration
|
||||
5. Start this actor server with the same configuration
|
||||
6. Use human interventions to guide policy learning
|
||||
|
||||
For more details on the complete HILSerl training workflow, see:
|
||||
https://github.com/michel-aractingi/lerobot-hilserl-guide
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from queue import Empty
|
||||
|
||||
import grpc
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.common.utils.utils import (
|
||||
TimerManager,
|
||||
get_safe_torch_device,
|
||||
init_logging,
|
||||
)
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
|
||||
from lerobot.scripts.server.buffer import Transition
|
||||
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
||||
from lerobot.scripts.server.network_utils import (
|
||||
bytes_to_state_dict,
|
||||
python_object_to_bytes,
|
||||
receive_bytes_in_chunks,
|
||||
send_bytes_in_chunks,
|
||||
transitions_to_bytes,
|
||||
)
|
||||
from lerobot.scripts.server.utils import (
|
||||
get_last_item_from_queue,
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
setup_process_handlers,
|
||||
)
|
||||
|
||||
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||
|
||||
|
||||
#################################################
|
||||
# Main entry point #
|
||||
#################################################
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def actor_cli(cfg: TrainPipelineConfig):
|
||||
cfg.validate()
|
||||
display_pid = False
|
||||
if not use_threads(cfg):
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
display_pid = True
|
||||
|
||||
# Create logs directory to ensure it exists
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=display_pid)
|
||||
logging.info(f"Actor logging initialized, writing to {log_file}")
|
||||
|
||||
shutdown_event = setup_process_handlers(use_threads(cfg))
|
||||
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
logging.info("[ACTOR] Establishing connection with Learner")
|
||||
if not establish_learner_connection(learner_client, shutdown_event):
|
||||
logging.error("[ACTOR] Failed to establish connection with Learner")
|
||||
return
|
||||
|
||||
if not use_threads(cfg):
|
||||
# If we use multithreading, we can reuse the channel
|
||||
grpc_channel.close()
|
||||
grpc_channel = None
|
||||
|
||||
logging.info("[ACTOR] Connection with Learner established")
|
||||
|
||||
parameters_queue = Queue()
|
||||
transitions_queue = Queue()
|
||||
interactions_queue = Queue()
|
||||
|
||||
concurrency_entity = None
|
||||
if use_threads(cfg):
|
||||
from threading import Thread
|
||||
|
||||
concurrency_entity = Thread
|
||||
else:
|
||||
from multiprocessing import Process
|
||||
|
||||
concurrency_entity = Process
|
||||
|
||||
receive_policy_process = concurrency_entity(
|
||||
target=receive_policy,
|
||||
args=(cfg, parameters_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
transitions_process = concurrency_entity(
|
||||
target=send_transitions,
|
||||
args=(cfg, transitions_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
interactions_process = concurrency_entity(
|
||||
target=send_interactions,
|
||||
args=(cfg, interactions_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
transitions_process.start()
|
||||
interactions_process.start()
|
||||
receive_policy_process.start()
|
||||
|
||||
act_with_policy(
|
||||
cfg=cfg,
|
||||
shutdown_event=shutdown_event,
|
||||
parameters_queue=parameters_queue,
|
||||
transitions_queue=transitions_queue,
|
||||
interactions_queue=interactions_queue,
|
||||
)
|
||||
logging.info("[ACTOR] Policy process joined")
|
||||
|
||||
logging.info("[ACTOR] Closing queues")
|
||||
transitions_queue.close()
|
||||
interactions_queue.close()
|
||||
parameters_queue.close()
|
||||
|
||||
transitions_process.join()
|
||||
logging.info("[ACTOR] Transitions process joined")
|
||||
interactions_process.join()
|
||||
logging.info("[ACTOR] Interactions process joined")
|
||||
receive_policy_process.join()
|
||||
logging.info("[ACTOR] Receive policy process joined")
|
||||
|
||||
logging.info("[ACTOR] join queues")
|
||||
transitions_queue.cancel_join_thread()
|
||||
interactions_queue.cancel_join_thread()
|
||||
parameters_queue.cancel_join_thread()
|
||||
|
||||
logging.info("[ACTOR] queues closed")
|
||||
|
||||
|
||||
#################################################
|
||||
# Core algorithm functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def act_with_policy(
|
||||
cfg: TrainPipelineConfig,
|
||||
shutdown_event: any, # Event,
|
||||
parameters_queue: Queue,
|
||||
transitions_queue: Queue,
|
||||
interactions_queue: Queue,
|
||||
):
|
||||
"""
|
||||
Executes policy interaction within the environment.
|
||||
|
||||
This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
|
||||
Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.
|
||||
|
||||
Args:
|
||||
cfg: Configuration settings for the interaction process.
|
||||
shutdown_event: Event to check if the process should shutdown.
|
||||
parameters_queue: Queue to receive updated network parameters from the learner.
|
||||
transitions_queue: Queue to send transitions to the learner.
|
||||
interactions_queue: Queue to send interactions to the learner.
|
||||
"""
|
||||
# Initialize logging for multiprocessing
|
||||
if not use_threads(cfg):
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log")
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor policy process logging initialized")
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
online_env = make_robot_env(cfg=cfg.env)
|
||||
|
||||
set_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("make_policy")
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy instance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
policy: SACPolicy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy = policy.eval()
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
list_transition_to_send_to_learner = []
|
||||
episode_intervention = False
|
||||
# Add counters for intervention rate calculation
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
|
||||
policy_timer = TimerManager("Policy inference", log=False)
|
||||
|
||||
for interaction_step in range(cfg.policy.online_steps):
|
||||
start_time = time.perf_counter()
|
||||
if shutdown_event.is_set():
|
||||
logging.info("[ACTOR] Shutting down act_with_policy")
|
||||
return
|
||||
|
||||
if interaction_step >= cfg.policy.online_step_before_learning:
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
action = policy.select_action(batch=obs)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
|
||||
else:
|
||||
action = online_env.action_space.sample()
|
||||
|
||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||
|
||||
sum_reward_episode += float(reward)
|
||||
# Increment total steps counter for intervention rate
|
||||
episode_total_steps += 1
|
||||
|
||||
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
|
||||
if "is_intervention" in info and info["is_intervention"]:
|
||||
# NOTE: The action space for demonstration before hand is with the full action space
|
||||
# but sometimes for example we want to deactivate the gripper
|
||||
action = info["action_intervention"]
|
||||
episode_intervention = True
|
||||
# Increment intervention steps counter
|
||||
episode_intervention_steps += 1
|
||||
|
||||
list_transition_to_send_to_learner.append(
|
||||
Transition(
|
||||
state=obs,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_obs,
|
||||
done=done,
|
||||
truncated=truncated, # TODO: (azouitine) Handle truncation properly
|
||||
complementary_info=info,
|
||||
)
|
||||
)
|
||||
# assign obs to the next obs and continue the rollout
|
||||
obs = next_obs
|
||||
|
||||
if done or truncated:
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
|
||||
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
|
||||
|
||||
if len(list_transition_to_send_to_learner) > 0:
|
||||
push_transitions_to_transport_queue(
|
||||
transitions=list_transition_to_send_to_learner,
|
||||
transitions_queue=transitions_queue,
|
||||
)
|
||||
list_transition_to_send_to_learner = []
|
||||
|
||||
stats = get_frequency_stats(policy_timer)
|
||||
policy_timer.reset()
|
||||
|
||||
# Calculate intervention rate
|
||||
intervention_rate = 0.0
|
||||
if episode_total_steps > 0:
|
||||
intervention_rate = episode_intervention_steps / episode_total_steps
|
||||
|
||||
# Send episodic reward to the learner
|
||||
interactions_queue.put(
|
||||
python_object_to_bytes(
|
||||
{
|
||||
"Episodic reward": sum_reward_episode,
|
||||
"Interaction step": interaction_step,
|
||||
"Episode intervention": int(episode_intervention),
|
||||
"Intervention rate": intervention_rate,
|
||||
**stats,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Reset intervention counters
|
||||
sum_reward_episode = 0.0
|
||||
episode_intervention = False
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
obs, info = online_env.reset()
|
||||
|
||||
if cfg.env.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
busy_wait(1 / cfg.env.fps - dt_time)
|
||||
|
||||
|
||||
#################################################
|
||||
# Communication Functions - Group all gRPC/messaging functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def establish_learner_connection(
|
||||
stub: hilserl_pb2_grpc.LearnerServiceStub,
|
||||
shutdown_event: Event, # type: ignore
|
||||
attempts: int = 30,
|
||||
):
|
||||
"""Establish a connection with the learner.
|
||||
|
||||
Args:
|
||||
stub (hilserl_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
|
||||
shutdown_event (Event): The event to check if the connection should be established.
|
||||
attempts (int): The number of attempts to establish the connection.
|
||||
Returns:
|
||||
bool: True if the connection is established, False otherwise.
|
||||
"""
|
||||
for _ in range(attempts):
|
||||
if shutdown_event.is_set():
|
||||
logging.info("[ACTOR] Shutting down establish_learner_connection")
|
||||
return False
|
||||
|
||||
# Force a connection attempt and check state
|
||||
try:
|
||||
logging.info("[ACTOR] Send ready message to Learner")
|
||||
if stub.Ready(hilserl_pb2.Empty()) == hilserl_pb2.Empty():
|
||||
return True
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
|
||||
time.sleep(2)
|
||||
return False
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def learner_service_client(
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 50051,
|
||||
) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
|
||||
import json
|
||||
|
||||
"""
|
||||
Returns a client for the learner service.
|
||||
|
||||
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
|
||||
So we need to create only one client and reuse it.
|
||||
"""
|
||||
|
||||
service_config = {
|
||||
"methodConfig": [
|
||||
{
|
||||
"name": [{}], # Applies to ALL methods in ALL services
|
||||
"retryPolicy": {
|
||||
"maxAttempts": 5, # Max retries (total attempts = 5)
|
||||
"initialBackoff": "0.1s", # First retry after 0.1s
|
||||
"maxBackoff": "2s", # Max wait time between retries
|
||||
"backoffMultiplier": 2, # Exponential backoff factor
|
||||
"retryableStatusCodes": [
|
||||
"UNAVAILABLE",
|
||||
"DEADLINE_EXCEEDED",
|
||||
], # Retries on network failures
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
service_config_json = json.dumps(service_config)
|
||||
|
||||
channel = grpc.insecure_channel(
|
||||
f"{host}:{port}",
|
||||
options=[
|
||||
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||
("grpc.enable_retries", 1),
|
||||
("grpc.service_config", service_config_json),
|
||||
],
|
||||
)
|
||||
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
|
||||
logging.info("[ACTOR] Learner service client created")
|
||||
return stub, channel
|
||||
|
||||
|
||||
def receive_policy(
|
||||
cfg: TrainPipelineConfig,
|
||||
parameters_queue: Queue,
|
||||
shutdown_event: Event, # type: ignore
|
||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
):
|
||||
"""Receive parameters from the learner.
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): The configuration for the actor.
|
||||
parameters_queue (Queue): The queue to receive the parameters.
|
||||
shutdown_event (Event): The event to check if the process should shutdown.
|
||||
"""
|
||||
logging.info("[ACTOR] Start receiving parameters from the Learner")
|
||||
if not use_threads(cfg):
|
||||
# Create a process-specific log file
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor receive policy process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
setup_process_handlers(use_threads=False)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
iterator = learner_client.StreamParameters(hilserl_pb2.Empty())
|
||||
receive_bytes_in_chunks(
|
||||
iterator,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
log_prefix="[ACTOR] parameters",
|
||||
)
|
||||
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Received policy loop stopped")
|
||||
|
||||
|
||||
def send_transitions(
|
||||
cfg: TrainPipelineConfig,
|
||||
transitions_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> hilserl_pb2.Empty:
|
||||
"""
|
||||
Sends transitions to the learner.
|
||||
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
- Transition Data:
|
||||
- A batch of transitions (observation, action, reward, next observation) is collected.
|
||||
- Transitions are moved to the CPU and serialized using PyTorch.
|
||||
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
|
||||
"""
|
||||
|
||||
if not use_threads(cfg):
|
||||
# Create a process-specific log file
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor transitions process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
setup_process_handlers(False)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
learner_client.SendTransitions(transitions_stream(shutdown_event, transitions_queue))
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
logging.info("[ACTOR] Finished streaming transitions")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Transitions process stopped")
|
||||
|
||||
|
||||
def send_interactions(
|
||||
cfg: TrainPipelineConfig,
|
||||
interactions_queue: Queue,
|
||||
shutdown_event: Event, # type: ignore
|
||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> hilserl_pb2.Empty:
|
||||
"""
|
||||
Sends interactions to the learner.
|
||||
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
- Interaction Messages:
|
||||
- Contains useful statistics about episodic rewards and policy timings.
|
||||
- The message is serialized using `pickle` and sent to the learner.
|
||||
"""
|
||||
|
||||
if not use_threads(cfg):
|
||||
# Create a process-specific log file
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor interactions process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
setup_process_handlers(False)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
learner_client.SendInteractions(interactions_stream(shutdown_event, interactions_queue))
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
logging.info("[ACTOR] Finished streaming interactions")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Interactions process stopped")
|
||||
|
||||
|
||||
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty: # type: ignore
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = transitions_queue.get(block=True, timeout=5)
|
||||
except Empty:
|
||||
logging.debug("[ACTOR] Transition queue is empty")
|
||||
continue
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
message, hilserl_pb2.Transition, log_prefix="[ACTOR] Send transitions"
|
||||
)
|
||||
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
|
||||
def interactions_stream(
|
||||
shutdown_event: Event, # type: ignore
|
||||
interactions_queue: Queue,
|
||||
) -> hilserl_pb2.Empty:
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = interactions_queue.get(block=True, timeout=5)
|
||||
except Empty:
|
||||
logging.debug("[ACTOR] Interaction queue is empty")
|
||||
continue
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
message,
|
||||
hilserl_pb2.InteractionMessage,
|
||||
log_prefix="[ACTOR] Send interactions",
|
||||
)
|
||||
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
|
||||
#################################################
|
||||
# Policy functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
|
||||
if not parameters_queue.empty():
|
||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
bytes_state_dict = get_last_item_from_queue(parameters_queue)
|
||||
state_dict = bytes_to_state_dict(bytes_state_dict)
|
||||
state_dict = move_state_dict_to_device(state_dict, device=device)
|
||||
policy.load_state_dict(state_dict)
|
||||
|
||||
|
||||
#################################################
|
||||
# Utilities functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
|
||||
"""Send transitions to learner in smaller chunks to avoid network issues.
|
||||
|
||||
Args:
|
||||
transitions: List of transitions to send
|
||||
message_queue: Queue to send messages to learner
|
||||
chunk_size: Size of each chunk to send
|
||||
"""
|
||||
transition_to_send_to_learner = []
|
||||
for transition in transitions:
|
||||
tr = move_transition_to_device(transition=transition, device="cpu")
|
||||
for key, value in tr["state"].items():
|
||||
if torch.isnan(value).any():
|
||||
logging.warning(f"Found NaN values in transition {key}")
|
||||
|
||||
transition_to_send_to_learner.append(tr)
|
||||
|
||||
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
|
||||
|
||||
|
||||
def get_frequency_stats(timer: TimerManager) -> dict[str, float]:
|
||||
"""Get the frequency statistics of the policy.
|
||||
|
||||
Args:
|
||||
timer (TimerManager): The timer with collected metrics.
|
||||
|
||||
Returns:
|
||||
dict[str, float]: The frequency statistics of the policy.
|
||||
"""
|
||||
stats = {}
|
||||
if timer.count > 1:
|
||||
avg_fps = timer.fps_avg
|
||||
p90_fps = timer.fps_percentile(90)
|
||||
logging.debug(f"[ACTOR] Average policy frame rate: {avg_fps}")
|
||||
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {p90_fps}")
|
||||
stats = {
|
||||
"Policy frequency [Hz]": avg_fps,
|
||||
"Policy frequency 90th-p [Hz]": p90_fps,
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
def log_policy_frequency_issue(policy_fps: float, cfg: TrainPipelineConfig, interaction_step: int):
|
||||
if policy_fps < cfg.env.fps:
|
||||
logging.warning(
|
||||
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}"
|
||||
)
|
||||
|
||||
|
||||
def use_threads(cfg: TrainPipelineConfig) -> bool:
|
||||
return cfg.policy.concurrency.actor == "threads"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
actor_cli()
|
||||
820
lerobot/scripts/server/buffer.py
Normal file
820
lerobot/scripts/server/buffer.py
Normal file
@@ -0,0 +1,820 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from contextlib import suppress
|
||||
from typing import Callable, Optional, Sequence, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.scripts.server.utils import Transition
|
||||
|
||||
|
||||
class BatchTransition(TypedDict):
|
||||
state: dict[str, torch.Tensor]
|
||||
action: torch.Tensor
|
||||
reward: torch.Tensor
|
||||
next_state: dict[str, torch.Tensor]
|
||||
done: torch.Tensor
|
||||
truncated: torch.Tensor
|
||||
complementary_info: dict[str, torch.Tensor | float | int] | None = None
|
||||
|
||||
|
||||
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
|
||||
"""
|
||||
Perform a per-image random crop over a batch of images in a vectorized way.
|
||||
(Same as shown previously.)
|
||||
"""
|
||||
B, C, H, W = images.shape # noqa: N806
|
||||
crop_h, crop_w = output_size
|
||||
|
||||
if crop_h > H or crop_w > W:
|
||||
raise ValueError(
|
||||
f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
|
||||
)
|
||||
|
||||
tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
|
||||
lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)
|
||||
|
||||
rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
|
||||
cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)
|
||||
|
||||
rows = rows.unsqueeze(2).expand(-1, -1, crop_w) # (B, crop_h, crop_w)
|
||||
cols = cols.unsqueeze(1).expand(-1, crop_h, -1) # (B, crop_h, crop_w)
|
||||
|
||||
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
|
||||
|
||||
# Gather pixels
|
||||
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
|
||||
# cropped_hwcn => (B, crop_h, crop_w, C)
|
||||
|
||||
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
|
||||
return cropped
|
||||
|
||||
|
||||
def random_shift(images: torch.Tensor, pad: int = 4):
|
||||
"""Vectorized random shift, imgs: (B,C,H,W), pad: #pixels"""
|
||||
_, _, h, w = images.shape
|
||||
images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
|
||||
return random_crop_vectorized(images=images, output_size=(h, w))
|
||||
|
||||
|
||||
class ReplayBuffer:
|
||||
def __init__(
|
||||
self,
|
||||
capacity: int,
|
||||
device: str = "cuda:0",
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
image_augmentation_function: Optional[Callable] = None,
|
||||
use_drq: bool = True,
|
||||
storage_device: str = "cpu",
|
||||
optimize_memory: bool = False,
|
||||
):
|
||||
"""
|
||||
Replay buffer for storing transitions.
|
||||
It will allocate tensors on the specified device, when the first transition is added.
|
||||
NOTE: If you encounter memory issues, you can try to use the `optimize_memory` flag to save memory or
|
||||
and use the `storage_device` flag to store the buffer on a different device.
|
||||
Args:
|
||||
capacity (int): Maximum number of transitions to store in the buffer.
|
||||
device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
|
||||
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
|
||||
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
|
||||
and returns a batch of augmented images. If None, a default augmentation function is used.
|
||||
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
|
||||
storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
|
||||
Using "cpu" can help save GPU memory.
|
||||
optimize_memory (bool): If True, optimizes memory by not storing duplicate next_states when
|
||||
they can be derived from states. This is useful for large datasets where next_state[i] = state[i+1].
|
||||
"""
|
||||
if capacity <= 0:
|
||||
raise ValueError("Capacity must be greater than 0.")
|
||||
|
||||
self.capacity = capacity
|
||||
self.device = device
|
||||
self.storage_device = storage_device
|
||||
self.position = 0
|
||||
self.size = 0
|
||||
self.initialized = False
|
||||
self.optimize_memory = optimize_memory
|
||||
|
||||
# Track episode boundaries for memory optimization
|
||||
self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
|
||||
|
||||
# If no state_keys provided, default to an empty list
|
||||
self.state_keys = state_keys if state_keys is not None else []
|
||||
|
||||
self.image_augmentation_function = image_augmentation_function
|
||||
|
||||
if image_augmentation_function is None:
|
||||
base_function = functools.partial(random_shift, pad=4)
|
||||
self.image_augmentation_function = torch.compile(base_function)
|
||||
self.use_drq = use_drq
|
||||
|
||||
def _initialize_storage(
|
||||
self,
|
||||
state: dict[str, torch.Tensor],
|
||||
action: torch.Tensor,
|
||||
complementary_info: Optional[dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""Initialize the storage tensors based on the first transition."""
|
||||
# Determine shapes from the first transition
|
||||
state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
|
||||
action_shape = action.squeeze(0).shape
|
||||
|
||||
# Pre-allocate tensors for storage
|
||||
self.states = {
|
||||
key: torch.empty((self.capacity, *shape), device=self.storage_device)
|
||||
for key, shape in state_shapes.items()
|
||||
}
|
||||
self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device)
|
||||
self.rewards = torch.empty((self.capacity,), device=self.storage_device)
|
||||
|
||||
if not self.optimize_memory:
|
||||
# Standard approach: store states and next_states separately
|
||||
self.next_states = {
|
||||
key: torch.empty((self.capacity, *shape), device=self.storage_device)
|
||||
for key, shape in state_shapes.items()
|
||||
}
|
||||
else:
|
||||
# Memory-optimized approach: don't allocate next_states buffer
|
||||
# Just create a reference to states for consistent API
|
||||
self.next_states = self.states # Just a reference for API consistency
|
||||
|
||||
self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||
self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||
|
||||
# Initialize storage for complementary_info
|
||||
self.has_complementary_info = complementary_info is not None
|
||||
self.complementary_info_keys = []
|
||||
self.complementary_info = {}
|
||||
|
||||
if self.has_complementary_info:
|
||||
self.complementary_info_keys = list(complementary_info.keys())
|
||||
# Pre-allocate tensors for each key in complementary_info
|
||||
for key, value in complementary_info.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
value_shape = value.squeeze(0).shape
|
||||
self.complementary_info[key] = torch.empty(
|
||||
(self.capacity, *value_shape), device=self.storage_device
|
||||
)
|
||||
elif isinstance(value, (int, float)):
|
||||
# Handle scalar values similar to reward
|
||||
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]")
|
||||
|
||||
self.initialized = True
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
||||
def add(
|
||||
self,
|
||||
state: dict[str, torch.Tensor],
|
||||
action: torch.Tensor,
|
||||
reward: float,
|
||||
next_state: dict[str, torch.Tensor],
|
||||
done: bool,
|
||||
truncated: bool,
|
||||
complementary_info: Optional[dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
|
||||
# Initialize storage if this is the first transition
|
||||
if not self.initialized:
|
||||
self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
|
||||
|
||||
# Store the transition in pre-allocated tensors
|
||||
for key in self.states:
|
||||
self.states[key][self.position].copy_(state[key].squeeze(dim=0))
|
||||
|
||||
if not self.optimize_memory:
|
||||
# Only store next_states if not optimizing memory
|
||||
self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
|
||||
|
||||
self.actions[self.position].copy_(action.squeeze(dim=0))
|
||||
self.rewards[self.position] = reward
|
||||
self.dones[self.position] = done
|
||||
self.truncateds[self.position] = truncated
|
||||
|
||||
# Handle complementary_info if provided and storage is initialized
|
||||
if complementary_info is not None and self.has_complementary_info:
|
||||
# Store the complementary_info
|
||||
for key in self.complementary_info_keys:
|
||||
if key in complementary_info:
|
||||
value = complementary_info[key]
|
||||
if isinstance(value, torch.Tensor):
|
||||
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
|
||||
elif isinstance(value, (int, float)):
|
||||
self.complementary_info[key][self.position] = value
|
||||
|
||||
self.position = (self.position + 1) % self.capacity
|
||||
self.size = min(self.size + 1, self.capacity)
|
||||
|
||||
def sample(self, batch_size: int) -> BatchTransition:
|
||||
"""Sample a random batch of transitions and collate them into batched tensors."""
|
||||
if not self.initialized:
|
||||
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
|
||||
|
||||
batch_size = min(batch_size, self.size)
|
||||
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
|
||||
|
||||
# Random indices for sampling - create on the same device as storage
|
||||
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
|
||||
|
||||
# Identify image keys that need augmentation
|
||||
image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []
|
||||
|
||||
# Create batched state and next_state
|
||||
batch_state = {}
|
||||
batch_next_state = {}
|
||||
|
||||
# First pass: load all state tensors to target device
|
||||
for key in self.states:
|
||||
batch_state[key] = self.states[key][idx].to(self.device)
|
||||
|
||||
if not self.optimize_memory:
|
||||
# Standard approach - load next_states directly
|
||||
batch_next_state[key] = self.next_states[key][idx].to(self.device)
|
||||
else:
|
||||
# Memory-optimized approach - get next_state from the next index
|
||||
next_idx = (idx + 1) % self.capacity
|
||||
batch_next_state[key] = self.states[key][next_idx].to(self.device)
|
||||
|
||||
# Apply image augmentation in a batched way if needed
|
||||
if self.use_drq and image_keys:
|
||||
# Concatenate all images from state and next_state
|
||||
all_images = []
|
||||
for key in image_keys:
|
||||
all_images.append(batch_state[key])
|
||||
all_images.append(batch_next_state[key])
|
||||
|
||||
# Optimization: Batch all images and apply augmentation once
|
||||
all_images_tensor = torch.cat(all_images, dim=0)
|
||||
augmented_images = self.image_augmentation_function(all_images_tensor)
|
||||
|
||||
# Split the augmented images back to their sources
|
||||
for i, key in enumerate(image_keys):
|
||||
# Calculate offsets for the current image key:
|
||||
# For each key, we have 2*batch_size images (batch_size for states, batch_size for next_states)
|
||||
# States start at index i*2*batch_size and take up batch_size slots
|
||||
batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
|
||||
# Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
|
||||
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
|
||||
|
||||
# Sample other tensors
|
||||
batch_actions = self.actions[idx].to(self.device)
|
||||
batch_rewards = self.rewards[idx].to(self.device)
|
||||
batch_dones = self.dones[idx].to(self.device).float()
|
||||
batch_truncateds = self.truncateds[idx].to(self.device).float()
|
||||
|
||||
# Sample complementary_info if available
|
||||
batch_complementary_info = None
|
||||
if self.has_complementary_info:
|
||||
batch_complementary_info = {}
|
||||
for key in self.complementary_info_keys:
|
||||
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
|
||||
|
||||
return BatchTransition(
|
||||
state=batch_state,
|
||||
action=batch_actions,
|
||||
reward=batch_rewards,
|
||||
next_state=batch_next_state,
|
||||
done=batch_dones,
|
||||
truncated=batch_truncateds,
|
||||
complementary_info=batch_complementary_info,
|
||||
)
|
||||
|
||||
def get_iterator(
|
||||
self,
|
||||
batch_size: int,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
"""
|
||||
Creates an infinite iterator that yields batches of transitions.
|
||||
Will automatically restart when internal iterator is exhausted.
|
||||
|
||||
Args:
|
||||
batch_size (int): Size of batches to sample
|
||||
async_prefetch (bool): Whether to use asynchronous prefetching with threads (default: True)
|
||||
queue_size (int): Number of batches to prefetch (default: 2)
|
||||
|
||||
Yields:
|
||||
BatchTransition: Batched transitions
|
||||
"""
|
||||
while True: # Create an infinite loop
|
||||
if async_prefetch:
|
||||
# Get the standard iterator
|
||||
iterator = self._get_async_iterator(queue_size=queue_size, batch_size=batch_size)
|
||||
else:
|
||||
iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size)
|
||||
|
||||
# Yield all items from the iterator
|
||||
with suppress(StopIteration):
|
||||
yield from iterator
|
||||
|
||||
def _get_async_iterator(self, batch_size: int, queue_size: int = 2):
|
||||
"""
|
||||
Creates an iterator that prefetches batches in a background thread.
|
||||
|
||||
Args:
|
||||
queue_size (int): Number of batches to prefetch (default: 2)
|
||||
batch_size (int): Size of batches to sample (default: 128)
|
||||
|
||||
Yields:
|
||||
BatchTransition: Prefetched batch transitions
|
||||
"""
|
||||
import queue
|
||||
import threading
|
||||
|
||||
# Use thread-safe queue
|
||||
data_queue = queue.Queue(maxsize=queue_size)
|
||||
running = [True] # Use list to allow modification in nested function
|
||||
|
||||
def prefetch_worker():
|
||||
while running[0]:
|
||||
try:
|
||||
# Sample data and add to queue
|
||||
data = self.sample(batch_size)
|
||||
data_queue.put(data, block=True, timeout=0.5)
|
||||
except queue.Full:
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"Prefetch error: {e}")
|
||||
break
|
||||
|
||||
# Start prefetching thread
|
||||
thread = threading.Thread(target=prefetch_worker, daemon=True)
|
||||
thread.start()
|
||||
|
||||
try:
|
||||
while running[0]:
|
||||
try:
|
||||
yield data_queue.get(block=True, timeout=0.5)
|
||||
except queue.Empty:
|
||||
if not thread.is_alive():
|
||||
break
|
||||
finally:
|
||||
# Clean up
|
||||
running[0] = False
|
||||
thread.join(timeout=1.0)
|
||||
|
||||
def _get_naive_iterator(self, batch_size: int, queue_size: int = 2):
|
||||
"""
|
||||
Creates a simple non-threaded iterator that yields batches.
|
||||
|
||||
Args:
|
||||
batch_size (int): Size of batches to sample
|
||||
queue_size (int): Number of initial batches to prefetch
|
||||
|
||||
Yields:
|
||||
BatchTransition: Batch transitions
|
||||
"""
|
||||
import collections
|
||||
|
||||
queue = collections.deque()
|
||||
|
||||
def enqueue(n):
|
||||
for _ in range(n):
|
||||
data = self.sample(batch_size)
|
||||
queue.append(data)
|
||||
|
||||
enqueue(queue_size)
|
||||
while queue:
|
||||
yield queue.popleft()
|
||||
enqueue(1)
|
||||
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
cls,
|
||||
lerobot_dataset: LeRobotDataset,
|
||||
device: str = "cuda:0",
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
capacity: Optional[int] = None,
|
||||
image_augmentation_function: Optional[Callable] = None,
|
||||
use_drq: bool = True,
|
||||
storage_device: str = "cpu",
|
||||
optimize_memory: bool = False,
|
||||
) -> "ReplayBuffer":
|
||||
"""
|
||||
Convert a LeRobotDataset into a ReplayBuffer.
|
||||
|
||||
Args:
|
||||
lerobot_dataset (LeRobotDataset): The dataset to convert.
|
||||
device (str): The device for sampling tensors. Defaults to "cuda:0".
|
||||
state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`.
|
||||
capacity (Optional[int]): Buffer capacity. If None, uses dataset length.
|
||||
action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep.
|
||||
image_augmentation_function (Optional[Callable]): Function for image augmentation.
|
||||
If None, uses default random shift with pad=4.
|
||||
use_drq (bool): Whether to use DrQ image augmentation when sampling.
|
||||
storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory.
|
||||
optimize_memory (bool): If True, reduces memory usage by not duplicating state data.
|
||||
|
||||
Returns:
|
||||
ReplayBuffer: The replay buffer with dataset transitions.
|
||||
"""
|
||||
if capacity is None:
|
||||
capacity = len(lerobot_dataset)
|
||||
|
||||
if capacity < len(lerobot_dataset):
|
||||
raise ValueError(
|
||||
"The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
|
||||
)
|
||||
|
||||
# Create replay buffer with image augmentation and DrQ settings
|
||||
replay_buffer = cls(
|
||||
capacity=capacity,
|
||||
device=device,
|
||||
state_keys=state_keys,
|
||||
image_augmentation_function=image_augmentation_function,
|
||||
use_drq=use_drq,
|
||||
storage_device=storage_device,
|
||||
optimize_memory=optimize_memory,
|
||||
)
|
||||
|
||||
# Convert dataset to transitions
|
||||
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
|
||||
|
||||
# Initialize the buffer with the first transition to set up storage tensors
|
||||
if list_transition:
|
||||
first_transition = list_transition[0]
|
||||
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
|
||||
first_action = first_transition["action"].to(device)
|
||||
|
||||
# Get complementary info if available
|
||||
first_complementary_info = None
|
||||
if (
|
||||
"complementary_info" in first_transition
|
||||
and first_transition["complementary_info"] is not None
|
||||
):
|
||||
first_complementary_info = {
|
||||
k: v.to(device) for k, v in first_transition["complementary_info"].items()
|
||||
}
|
||||
|
||||
replay_buffer._initialize_storage(
|
||||
state=first_state, action=first_action, complementary_info=first_complementary_info
|
||||
)
|
||||
|
||||
# Fill the buffer with all transitions
|
||||
for data in list_transition:
|
||||
for k, v in data.items():
|
||||
if isinstance(v, dict):
|
||||
for key, tensor in v.items():
|
||||
v[key] = tensor.to(storage_device)
|
||||
elif isinstance(v, torch.Tensor):
|
||||
data[k] = v.to(storage_device)
|
||||
|
||||
action = data["action"]
|
||||
|
||||
replay_buffer.add(
|
||||
state=data["state"],
|
||||
action=action,
|
||||
reward=data["reward"],
|
||||
next_state=data["next_state"],
|
||||
done=data["done"],
|
||||
truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset
|
||||
complementary_info=data.get("complementary_info", None),
|
||||
)
|
||||
|
||||
return replay_buffer
|
||||
|
||||
def to_lerobot_dataset(
|
||||
self,
|
||||
repo_id: str,
|
||||
fps=1,
|
||||
root=None,
|
||||
task_name="from_replay_buffer",
|
||||
) -> LeRobotDataset:
|
||||
"""
|
||||
Converts all transitions in this ReplayBuffer into a single LeRobotDataset object.
|
||||
"""
|
||||
if self.size == 0:
|
||||
raise ValueError("The replay buffer is empty. Cannot convert to a dataset.")
|
||||
|
||||
# Create features dictionary for the dataset
|
||||
features = {
|
||||
"index": {"dtype": "int64", "shape": [1]}, # global index across episodes
|
||||
"episode_index": {"dtype": "int64", "shape": [1]}, # which episode
|
||||
"frame_index": {"dtype": "int64", "shape": [1]}, # index inside an episode
|
||||
"timestamp": {"dtype": "float32", "shape": [1]}, # for now we store dummy
|
||||
"task_index": {"dtype": "int64", "shape": [1]},
|
||||
}
|
||||
|
||||
# Add "action"
|
||||
sample_action = self.actions[0]
|
||||
act_info = guess_feature_info(t=sample_action, name="action")
|
||||
features["action"] = act_info
|
||||
|
||||
# Add "reward" and "done"
|
||||
features["next.reward"] = {"dtype": "float32", "shape": (1,)}
|
||||
features["next.done"] = {"dtype": "bool", "shape": (1,)}
|
||||
|
||||
# Add state keys
|
||||
for key in self.states:
|
||||
sample_val = self.states[key][0]
|
||||
f_info = guess_feature_info(t=sample_val, name=key)
|
||||
features[key] = f_info
|
||||
|
||||
# Add complementary_info keys if available
|
||||
if self.has_complementary_info:
|
||||
for key in self.complementary_info_keys:
|
||||
sample_val = self.complementary_info[key][0]
|
||||
if isinstance(sample_val, torch.Tensor) and sample_val.ndim == 0:
|
||||
sample_val = sample_val.unsqueeze(0)
|
||||
f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}")
|
||||
features[f"complementary_info.{key}"] = f_info
|
||||
|
||||
# Create an empty LeRobotDataset
|
||||
lerobot_dataset = LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=fps,
|
||||
root=root,
|
||||
robot=None, # TODO: (azouitine) Handle robot
|
||||
robot_type=None,
|
||||
features=features,
|
||||
use_videos=True,
|
||||
)
|
||||
|
||||
# Start writing images if needed
|
||||
lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
|
||||
|
||||
# Convert transitions into episodes and frames
|
||||
episode_index = 0
|
||||
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index)
|
||||
|
||||
frame_idx_in_episode = 0
|
||||
for idx in range(self.size):
|
||||
actual_idx = (self.position - self.size + idx) % self.capacity
|
||||
|
||||
frame_dict = {}
|
||||
|
||||
# Fill the data for state keys
|
||||
for key in self.states:
|
||||
frame_dict[key] = self.states[key][actual_idx].cpu()
|
||||
|
||||
# Fill action, reward, done
|
||||
frame_dict["action"] = self.actions[actual_idx].cpu()
|
||||
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
||||
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
|
||||
|
||||
# Add complementary_info if available
|
||||
if self.has_complementary_info:
|
||||
for key in self.complementary_info_keys:
|
||||
val = self.complementary_info[key][actual_idx]
|
||||
# Convert tensors to CPU
|
||||
if isinstance(val, torch.Tensor):
|
||||
if val.ndim == 0:
|
||||
val = val.unsqueeze(0)
|
||||
frame_dict[f"complementary_info.{key}"] = val.cpu()
|
||||
# Non-tensor values can be used directly
|
||||
else:
|
||||
frame_dict[f"complementary_info.{key}"] = val
|
||||
|
||||
# Add task field which is required by LeRobotDataset
|
||||
frame_dict["task"] = task_name
|
||||
|
||||
# Add to the dataset's buffer
|
||||
lerobot_dataset.add_frame(frame_dict)
|
||||
|
||||
# Move to next frame
|
||||
frame_idx_in_episode += 1
|
||||
|
||||
# If we reached an episode boundary, call save_episode, reset counters
|
||||
if self.dones[actual_idx] or self.truncateds[actual_idx]:
|
||||
lerobot_dataset.save_episode()
|
||||
episode_index += 1
|
||||
frame_idx_in_episode = 0
|
||||
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
|
||||
episode_index=episode_index
|
||||
)
|
||||
|
||||
# Save any remaining frames in the buffer
|
||||
if lerobot_dataset.episode_buffer["size"] > 0:
|
||||
lerobot_dataset.save_episode()
|
||||
|
||||
lerobot_dataset.stop_image_writer()
|
||||
|
||||
return lerobot_dataset
|
||||
|
||||
@staticmethod
|
||||
def _lerobotdataset_to_transitions(
|
||||
dataset: LeRobotDataset,
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
) -> list[Transition]:
|
||||
"""
|
||||
Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.
|
||||
|
||||
Args:
|
||||
dataset (LeRobotDataset):
|
||||
The dataset to convert. Each item in the dataset is expected to have
|
||||
at least the following keys:
|
||||
{
|
||||
"action": ...
|
||||
"next.reward": ...
|
||||
"next.done": ...
|
||||
"episode_index": ...
|
||||
}
|
||||
plus whatever your 'state_keys' specify.
|
||||
|
||||
state_keys (Optional[Sequence[str]]):
|
||||
The dataset keys to include in 'state' and 'next_state'. Their names
|
||||
will be kept as-is in the output transitions. E.g.
|
||||
["observation.state", "observation.environment_state"].
|
||||
If None, you must handle or define default keys.
|
||||
|
||||
Returns:
|
||||
transitions (List[Transition]):
|
||||
A list of Transition dictionaries with the same length as `dataset`.
|
||||
"""
|
||||
if state_keys is None:
|
||||
raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")
|
||||
|
||||
transitions = []
|
||||
num_frames = len(dataset)
|
||||
|
||||
# Check if the dataset has "next.done" key
|
||||
sample = dataset[0]
|
||||
has_done_key = "next.done" in sample
|
||||
|
||||
# Check for complementary_info keys
|
||||
complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")]
|
||||
has_complementary_info = len(complementary_info_keys) > 0
|
||||
|
||||
# If not, we need to infer it from episode boundaries
|
||||
if not has_done_key:
|
||||
print("'next.done' key not found in dataset. Inferring from episode boundaries...")
|
||||
|
||||
for i in tqdm(range(num_frames)):
|
||||
current_sample = dataset[i]
|
||||
|
||||
# ----- 1) Current state -----
|
||||
current_state: dict[str, torch.Tensor] = {}
|
||||
for key in state_keys:
|
||||
val = current_sample[key]
|
||||
current_state[key] = val.unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 2) Action -----
|
||||
action = current_sample["action"].unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 3) Reward and done -----
|
||||
reward = float(current_sample["next.reward"].item()) # ensure float
|
||||
|
||||
# Determine done flag - use next.done if available, otherwise infer from episode boundaries
|
||||
if has_done_key:
|
||||
done = bool(current_sample["next.done"].item()) # ensure bool
|
||||
else:
|
||||
# If this is the last frame or if next frame is in a different episode, mark as done
|
||||
done = False
|
||||
if i == num_frames - 1:
|
||||
done = True
|
||||
elif i < num_frames - 1:
|
||||
next_sample = dataset[i + 1]
|
||||
if next_sample["episode_index"] != current_sample["episode_index"]:
|
||||
done = True
|
||||
|
||||
# TODO: (azouitine) Handle truncation (using the same value as done for now)
|
||||
truncated = done
|
||||
|
||||
# ----- 4) Next state -----
|
||||
# If not done and the next sample is in the same episode, we pull the next sample's state.
|
||||
# Otherwise (done=True or next sample crosses to a new episode), next_state = current_state.
|
||||
next_state = current_state # default
|
||||
if not done and (i < num_frames - 1):
|
||||
next_sample = dataset[i + 1]
|
||||
if next_sample["episode_index"] == current_sample["episode_index"]:
|
||||
# Build next_state from the same keys
|
||||
next_state_data: dict[str, torch.Tensor] = {}
|
||||
for key in state_keys:
|
||||
val = next_sample[key]
|
||||
next_state_data[key] = val.unsqueeze(0) # Add batch dimension
|
||||
next_state = next_state_data
|
||||
|
||||
# ----- 5) Complementary info (if available) -----
|
||||
complementary_info = None
|
||||
if has_complementary_info:
|
||||
complementary_info = {}
|
||||
for key in complementary_info_keys:
|
||||
# Strip the "complementary_info." prefix to get the actual key
|
||||
clean_key = key[len("complementary_info.") :]
|
||||
val = current_sample[key]
|
||||
# Handle tensor and non-tensor values differently
|
||||
if isinstance(val, torch.Tensor):
|
||||
complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension
|
||||
else:
|
||||
# TODO: (azouitine) Check if it's necessary to convert to tensor
|
||||
# For non-tensor values, use directly
|
||||
complementary_info[clean_key] = val
|
||||
|
||||
# ----- Construct the Transition -----
|
||||
transition = Transition(
|
||||
state=current_state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=done,
|
||||
truncated=truncated,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
transitions.append(transition)
|
||||
|
||||
return transitions
|
||||
|
||||
|
||||
# Utility function to guess shapes/dtypes from a tensor
|
||||
def guess_feature_info(t, name: str):
|
||||
"""
|
||||
Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value.
|
||||
If it looks like a 3D (C,H,W) shape, we might consider it an 'image'.
|
||||
Otherwise default to appropriate dtype for numeric.
|
||||
"""
|
||||
|
||||
shape = tuple(t.shape)
|
||||
# Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image'
|
||||
if len(shape) == 3 and shape[0] in [1, 3]:
|
||||
return {
|
||||
"dtype": "image",
|
||||
"shape": shape,
|
||||
}
|
||||
else:
|
||||
# Otherwise treat as numeric
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": shape,
|
||||
}
|
||||
|
||||
|
||||
def concatenate_batch_transitions(
|
||||
left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
|
||||
) -> BatchTransition:
|
||||
"""NOTE: Be careful it change the left_batch_transitions in place"""
|
||||
# Concatenate state fields
|
||||
left_batch_transitions["state"] = {
|
||||
key: torch.cat(
|
||||
[left_batch_transitions["state"][key], right_batch_transition["state"][key]],
|
||||
dim=0,
|
||||
)
|
||||
for key in left_batch_transitions["state"]
|
||||
}
|
||||
|
||||
# Concatenate basic fields
|
||||
left_batch_transitions["action"] = torch.cat(
|
||||
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0
|
||||
)
|
||||
left_batch_transitions["reward"] = torch.cat(
|
||||
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
|
||||
)
|
||||
|
||||
# Concatenate next_state fields
|
||||
left_batch_transitions["next_state"] = {
|
||||
key: torch.cat(
|
||||
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]],
|
||||
dim=0,
|
||||
)
|
||||
for key in left_batch_transitions["next_state"]
|
||||
}
|
||||
|
||||
# Concatenate done and truncated fields
|
||||
left_batch_transitions["done"] = torch.cat(
|
||||
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0
|
||||
)
|
||||
left_batch_transitions["truncated"] = torch.cat(
|
||||
[left_batch_transitions["truncated"], right_batch_transition["truncated"]],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Handle complementary_info
|
||||
left_info = left_batch_transitions.get("complementary_info")
|
||||
right_info = right_batch_transition.get("complementary_info")
|
||||
|
||||
# Only process if right_info exists
|
||||
if right_info is not None:
|
||||
# Initialize left complementary_info if needed
|
||||
if left_info is None:
|
||||
left_batch_transitions["complementary_info"] = right_info
|
||||
else:
|
||||
# Concatenate each field
|
||||
for key in right_info:
|
||||
if key in left_info:
|
||||
left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0)
|
||||
else:
|
||||
left_info[key] = right_info[key]
|
||||
|
||||
return left_batch_transitions
|
||||
303
lerobot/scripts/server/crop_dataset_roi.py
Normal file
303
lerobot/scripts/server/crop_dataset_roi.py
Normal file
@@ -0,0 +1,303 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import cv2
|
||||
|
||||
# import torch.nn.functional as F # noqa: N812
|
||||
import torchvision.transforms.functional as F # type: ignore # noqa: N812
|
||||
from tqdm import tqdm # type: ignore
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def select_rect_roi(img):
|
||||
"""
|
||||
Allows the user to draw a rectangular ROI on the image.
|
||||
|
||||
The user must click and drag to draw the rectangle.
|
||||
- While dragging, the rectangle is dynamically drawn.
|
||||
- On mouse button release, the rectangle is fixed.
|
||||
- Press 'c' to confirm the selection.
|
||||
- Press 'r' to reset the selection.
|
||||
- Press ESC to cancel.
|
||||
|
||||
Returns:
|
||||
A tuple (top, left, height, width) representing the rectangular ROI,
|
||||
or None if no valid ROI is selected.
|
||||
"""
|
||||
# Create a working copy of the image
|
||||
clone = img.copy()
|
||||
working_img = clone.copy()
|
||||
|
||||
roi = None # Will store the final ROI as (top, left, height, width)
|
||||
drawing = False
|
||||
index_x, index_y = -1, -1 # Initial click coordinates
|
||||
|
||||
def mouse_callback(event, x, y, flags, param):
|
||||
nonlocal index_x, index_y, drawing, roi, working_img
|
||||
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
# Start drawing: record starting coordinates
|
||||
drawing = True
|
||||
index_x, index_y = x, y
|
||||
|
||||
elif event == cv2.EVENT_MOUSEMOVE:
|
||||
if drawing:
|
||||
# Compute the top-left and bottom-right corners regardless of drag direction
|
||||
top = min(index_y, y)
|
||||
left = min(index_x, x)
|
||||
bottom = max(index_y, y)
|
||||
right = max(index_x, x)
|
||||
# Show a temporary image with the current rectangle drawn
|
||||
temp = working_img.copy()
|
||||
cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2)
|
||||
cv2.imshow("Select ROI", temp)
|
||||
|
||||
elif event == cv2.EVENT_LBUTTONUP:
|
||||
# Finish drawing
|
||||
drawing = False
|
||||
top = min(index_y, y)
|
||||
left = min(index_x, x)
|
||||
bottom = max(index_y, y)
|
||||
right = max(index_x, x)
|
||||
height = bottom - top
|
||||
width = right - left
|
||||
roi = (top, left, height, width) # (top, left, height, width)
|
||||
# Draw the final rectangle on the working image and display it
|
||||
working_img = clone.copy()
|
||||
cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2)
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
|
||||
# Create the window and set the callback
|
||||
cv2.namedWindow("Select ROI")
|
||||
cv2.setMouseCallback("Select ROI", mouse_callback)
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
|
||||
print("Instructions for ROI selection:")
|
||||
print(" - Click and drag to draw a rectangular ROI.")
|
||||
print(" - Press 'c' to confirm the selection.")
|
||||
print(" - Press 'r' to reset and draw again.")
|
||||
print(" - Press ESC to cancel the selection.")
|
||||
|
||||
# Wait until the user confirms with 'c', resets with 'r', or cancels with ESC
|
||||
while True:
|
||||
key = cv2.waitKey(1) & 0xFF
|
||||
# Confirm ROI if one has been drawn
|
||||
if key == ord("c") and roi is not None:
|
||||
break
|
||||
# Reset: clear the ROI and restore the original image
|
||||
elif key == ord("r"):
|
||||
working_img = clone.copy()
|
||||
roi = None
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
# Cancel selection for this image
|
||||
elif key == 27: # ESC key
|
||||
roi = None
|
||||
break
|
||||
|
||||
cv2.destroyWindow("Select ROI")
|
||||
return roi
|
||||
|
||||
|
||||
def select_square_roi_for_images(images: dict) -> dict:
|
||||
"""
|
||||
For each image in the provided dictionary, open a window to allow the user
|
||||
to select a rectangular ROI. Returns a dictionary mapping each key to a tuple
|
||||
(top, left, height, width) representing the ROI.
|
||||
|
||||
Parameters:
|
||||
images (dict): Dictionary where keys are identifiers and values are OpenCV images.
|
||||
|
||||
Returns:
|
||||
dict: Mapping of image keys to the selected rectangular ROI.
|
||||
"""
|
||||
selected_rois = {}
|
||||
|
||||
for key, img in images.items():
|
||||
if img is None:
|
||||
print(f"Image for key '{key}' is None, skipping.")
|
||||
continue
|
||||
|
||||
print(f"\nSelect rectangular ROI for image with key: '{key}'")
|
||||
roi = select_rect_roi(img)
|
||||
|
||||
if roi is None:
|
||||
print(f"No valid ROI selected for '{key}'.")
|
||||
else:
|
||||
selected_rois[key] = roi
|
||||
print(f"ROI for '{key}': {roi}")
|
||||
|
||||
return selected_rois
|
||||
|
||||
|
||||
def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
|
||||
"""
|
||||
Find the first row in the dataset and extract the image in order to be used for the crop.
|
||||
"""
|
||||
row = dataset[0]
|
||||
image_dict = {}
|
||||
for k in row:
|
||||
if "image" in k:
|
||||
image_dict[k] = deepcopy(row[k])
|
||||
return image_dict
|
||||
|
||||
|
||||
def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
original_dataset: LeRobotDataset,
|
||||
crop_params_dict: Dict[str, Tuple[int, int, int, int]],
|
||||
new_repo_id: str,
|
||||
new_dataset_root: str,
|
||||
resize_size: Tuple[int, int] = (128, 128),
|
||||
push_to_hub: bool = False,
|
||||
) -> LeRobotDataset:
|
||||
"""
|
||||
Converts an existing LeRobotDataset by iterating over its episodes and frames,
|
||||
applying cropping and resizing to image observations, and saving a new dataset
|
||||
with the transformed data.
|
||||
|
||||
Args:
|
||||
original_dataset (LeRobotDataset): The source dataset.
|
||||
crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
|
||||
A dictionary mapping observation keys to crop parameters (top, left, height, width).
|
||||
new_repo_id (str): Repository id for the new dataset.
|
||||
new_dataset_root (str): The root directory where the new dataset will be written.
|
||||
resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
|
||||
Defaults to (128, 128).
|
||||
|
||||
Returns:
|
||||
LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped
|
||||
and resized.
|
||||
"""
|
||||
# 1. Create a new (empty) LeRobotDataset for writing.
|
||||
new_dataset = LeRobotDataset.create(
|
||||
repo_id=new_repo_id,
|
||||
fps=original_dataset.fps,
|
||||
root=new_dataset_root,
|
||||
robot_type=original_dataset.meta.robot_type,
|
||||
features=original_dataset.meta.info["features"],
|
||||
use_videos=len(original_dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
# Update the metadata for every image key that will be cropped:
|
||||
# (Here we simply set the shape to be the final resize_size.)
|
||||
for key in crop_params_dict:
|
||||
if key in new_dataset.meta.info["features"]:
|
||||
new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
|
||||
|
||||
# TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset
|
||||
prev_episode_index = 0
|
||||
for frame_idx in tqdm(range(len(original_dataset))):
|
||||
frame = original_dataset[frame_idx]
|
||||
|
||||
# Create a copy of the frame to add to the new dataset
|
||||
new_frame = {}
|
||||
for key, value in frame.items():
|
||||
if key in ("task_index", "timestamp", "episode_index", "frame_index", "index"):
|
||||
continue
|
||||
if key in ("next.done", "next.reward"):
|
||||
# if not isinstance(value, str) and len(value.shape) == 0:
|
||||
value = value.unsqueeze(0)
|
||||
|
||||
if key in crop_params_dict:
|
||||
top, left, height, width = crop_params_dict[key]
|
||||
# Apply crop then resize.
|
||||
cropped = F.crop(value, top, left, height, width)
|
||||
value = F.resize(cropped, resize_size)
|
||||
value = value.clamp(0, 1)
|
||||
|
||||
new_frame[key] = value
|
||||
|
||||
new_dataset.add_frame(new_frame)
|
||||
|
||||
if frame["episode_index"].item() != prev_episode_index:
|
||||
# Save the episode
|
||||
new_dataset.save_episode()
|
||||
prev_episode_index = frame["episode_index"].item()
|
||||
|
||||
if push_to_hub:
|
||||
new_dataset.push_to_hub()
|
||||
|
||||
return new_dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="lerobot",
|
||||
help="The repository id of the LeRobot dataset to process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The root directory of the LeRobot dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crop-params-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path to the JSON file containing the ROIs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Whether to push the new dataset to the hub.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
|
||||
|
||||
images = get_image_from_lerobot_dataset(dataset)
|
||||
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
|
||||
images = {k: (v * 255).astype("uint8") for k, v in images.items()}
|
||||
|
||||
if args.crop_params_path is None:
|
||||
rois = select_square_roi_for_images(images)
|
||||
else:
|
||||
with open(args.crop_params_path) as f:
|
||||
rois = json.load(f)
|
||||
|
||||
# Print the selected rectangular ROIs
|
||||
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
|
||||
for key, roi in rois.items():
|
||||
print(f"{key}: {roi}")
|
||||
|
||||
new_repo_id = args.repo_id + "_cropped_resized"
|
||||
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
|
||||
|
||||
cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
original_dataset=dataset,
|
||||
crop_params_dict=rois,
|
||||
new_repo_id=new_repo_id,
|
||||
new_dataset_root=new_dataset_root,
|
||||
resize_size=(128, 128),
|
||||
push_to_hub=args.push_to_hub,
|
||||
)
|
||||
|
||||
meta_dir = new_dataset_root / "meta"
|
||||
meta_dir.mkdir(exist_ok=True)
|
||||
|
||||
with open(meta_dir / "crop_params.json", "w") as f:
|
||||
json.dump(rois, f, indent=4)
|
||||
802
lerobot/scripts/server/end_effector_control_utils.py
Normal file
802
lerobot/scripts/server/end_effector_control_utils.py
Normal file
@@ -0,0 +1,802 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||
|
||||
|
||||
class InputController:
|
||||
"""Base class for input controllers that generate motion deltas."""
|
||||
|
||||
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01):
|
||||
"""
|
||||
Initialize the controller.
|
||||
|
||||
Args:
|
||||
x_step_size: Base movement step size in meters
|
||||
y_step_size: Base movement step size in meters
|
||||
z_step_size: Base movement step size in meters
|
||||
"""
|
||||
self.x_step_size = x_step_size
|
||||
self.y_step_size = y_step_size
|
||||
self.z_step_size = z_step_size
|
||||
self.running = True
|
||||
self.episode_end_status = None # None, "success", or "failure"
|
||||
self.intervention_flag = False
|
||||
self.open_gripper_command = False
|
||||
self.close_gripper_command = False
|
||||
|
||||
def start(self):
|
||||
"""Start the controller and initialize resources."""
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
"""Stop the controller and release resources."""
|
||||
pass
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas (dx, dy, dz) in meters."""
|
||||
return 0.0, 0.0, 0.0
|
||||
|
||||
def should_quit(self):
|
||||
"""Return True if the user has requested to quit."""
|
||||
return not self.running
|
||||
|
||||
def update(self):
|
||||
"""Update controller state - call this once per frame."""
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
"""Support for use in 'with' statements."""
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Ensure resources are released when exiting 'with' block."""
|
||||
self.stop()
|
||||
|
||||
def get_episode_end_status(self):
|
||||
"""
|
||||
Get the current episode end status.
|
||||
|
||||
Returns:
|
||||
None if episode should continue, "success" or "failure" otherwise
|
||||
"""
|
||||
status = self.episode_end_status
|
||||
self.episode_end_status = None # Reset after reading
|
||||
return status
|
||||
|
||||
def should_intervene(self):
|
||||
"""Return True if intervention flag was set."""
|
||||
return self.intervention_flag
|
||||
|
||||
def gripper_command(self):
|
||||
"""Return the current gripper command."""
|
||||
if self.open_gripper_command == self.close_gripper_command:
|
||||
return "no-op"
|
||||
elif self.open_gripper_command:
|
||||
return "open"
|
||||
elif self.close_gripper_command:
|
||||
return "close"
|
||||
|
||||
|
||||
class KeyboardController(InputController):
|
||||
"""Generate motion deltas from keyboard input."""
|
||||
|
||||
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01):
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.key_states = {
|
||||
"forward_x": False,
|
||||
"backward_x": False,
|
||||
"forward_y": False,
|
||||
"backward_y": False,
|
||||
"forward_z": False,
|
||||
"backward_z": False,
|
||||
"quit": False,
|
||||
"success": False,
|
||||
"failure": False,
|
||||
}
|
||||
self.listener = None
|
||||
|
||||
def start(self):
|
||||
"""Start the keyboard listener."""
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if key == keyboard.Key.up:
|
||||
self.key_states["forward_x"] = True
|
||||
elif key == keyboard.Key.down:
|
||||
self.key_states["backward_x"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
self.key_states["forward_y"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
self.key_states["backward_y"] = True
|
||||
elif key == keyboard.Key.shift:
|
||||
self.key_states["backward_z"] = True
|
||||
elif key == keyboard.Key.shift_r:
|
||||
self.key_states["forward_z"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
self.key_states["quit"] = True
|
||||
self.running = False
|
||||
return False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = True
|
||||
self.episode_end_status = "success"
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = True
|
||||
self.episode_end_status = "failure"
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def on_release(key):
|
||||
try:
|
||||
if key == keyboard.Key.up:
|
||||
self.key_states["forward_x"] = False
|
||||
elif key == keyboard.Key.down:
|
||||
self.key_states["backward_x"] = False
|
||||
elif key == keyboard.Key.left:
|
||||
self.key_states["forward_y"] = False
|
||||
elif key == keyboard.Key.right:
|
||||
self.key_states["backward_y"] = False
|
||||
elif key == keyboard.Key.shift:
|
||||
self.key_states["backward_z"] = False
|
||||
elif key == keyboard.Key.shift_r:
|
||||
self.key_states["forward_z"] = False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = False
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = False
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
self.listener = keyboard.Listener(on_press=on_press, on_release=on_release)
|
||||
self.listener.start()
|
||||
|
||||
print("Keyboard controls:")
|
||||
print(" Arrow keys: Move in X-Y plane")
|
||||
print(" Shift and Shift_R: Move in Z axis")
|
||||
print(" Enter: End episode with SUCCESS")
|
||||
print(" Backspace: End episode with FAILURE")
|
||||
print(" ESC: Exit")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the keyboard listener."""
|
||||
if self.listener and self.listener.is_alive():
|
||||
self.listener.stop()
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from keyboard state."""
|
||||
delta_x = delta_y = delta_z = 0.0
|
||||
|
||||
if self.key_states["forward_x"]:
|
||||
delta_x += self.x_step_size
|
||||
if self.key_states["backward_x"]:
|
||||
delta_x -= self.x_step_size
|
||||
if self.key_states["forward_y"]:
|
||||
delta_y += self.y_step_size
|
||||
if self.key_states["backward_y"]:
|
||||
delta_y -= self.y_step_size
|
||||
if self.key_states["forward_z"]:
|
||||
delta_z += self.z_step_size
|
||||
if self.key_states["backward_z"]:
|
||||
delta_z -= self.z_step_size
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
def should_quit(self):
|
||||
"""Return True if ESC was pressed."""
|
||||
return self.key_states["quit"]
|
||||
|
||||
def should_save(self):
|
||||
"""Return True if Enter was pressed (save episode)."""
|
||||
return self.key_states["success"] or self.key_states["failure"]
|
||||
|
||||
|
||||
class GamepadController(InputController):
|
||||
"""Generate motion deltas from gamepad input."""
|
||||
|
||||
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01, deadzone=0.1):
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.joystick = None
|
||||
self.intervention_flag = False
|
||||
|
||||
def start(self):
|
||||
"""Initialize pygame and the gamepad."""
|
||||
import pygame
|
||||
|
||||
pygame.init()
|
||||
pygame.joystick.init()
|
||||
|
||||
if pygame.joystick.get_count() == 0:
|
||||
logging.error("No gamepad detected. Please connect a gamepad and try again.")
|
||||
self.running = False
|
||||
return
|
||||
|
||||
self.joystick = pygame.joystick.Joystick(0)
|
||||
self.joystick.init()
|
||||
logging.info(f"Initialized gamepad: {self.joystick.get_name()}")
|
||||
|
||||
print("Gamepad controls:")
|
||||
print(" Left analog stick: Move in X-Y plane")
|
||||
print(" Right analog stick (vertical): Move in Z axis")
|
||||
print(" B/Circle button: Exit")
|
||||
print(" Y/Triangle button: End episode with SUCCESS")
|
||||
print(" A/Cross button: End episode with FAILURE")
|
||||
print(" X/Square button: Rerecord episode")
|
||||
|
||||
def stop(self):
|
||||
"""Clean up pygame resources."""
|
||||
import pygame
|
||||
|
||||
if pygame.joystick.get_init():
|
||||
if self.joystick:
|
||||
self.joystick.quit()
|
||||
pygame.joystick.quit()
|
||||
pygame.quit()
|
||||
|
||||
def update(self):
|
||||
"""Process pygame events to get fresh gamepad readings."""
|
||||
import pygame
|
||||
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.JOYBUTTONDOWN:
|
||||
if event.button == 3:
|
||||
self.episode_end_status = "success"
|
||||
# A button (1) for failure
|
||||
elif event.button == 1:
|
||||
self.episode_end_status = "failure"
|
||||
# X button (0) for rerecord
|
||||
elif event.button == 0:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
|
||||
# RB button (6) for closing gripper
|
||||
elif event.button == 6:
|
||||
self.close_gripper_command = True
|
||||
|
||||
# LT button (7) for opening gripper
|
||||
elif event.button == 7:
|
||||
self.open_gripper_command = True
|
||||
|
||||
# Reset episode status on button release
|
||||
elif event.type == pygame.JOYBUTTONUP:
|
||||
if event.button in [0, 2, 3]:
|
||||
self.episode_end_status = None
|
||||
|
||||
elif event.button == 6:
|
||||
self.close_gripper_command = False
|
||||
|
||||
elif event.button == 7:
|
||||
self.open_gripper_command = False
|
||||
|
||||
# Check for RB button (typically button 5) for intervention flag
|
||||
if self.joystick.get_button(5):
|
||||
self.intervention_flag = True
|
||||
else:
|
||||
self.intervention_flag = False
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from gamepad state."""
|
||||
import pygame
|
||||
|
||||
try:
|
||||
# Read joystick axes
|
||||
# Left stick X and Y (typically axes 0 and 1)
|
||||
x_input = self.joystick.get_axis(0) # Left/Right
|
||||
y_input = self.joystick.get_axis(1) # Up/Down (often inverted)
|
||||
|
||||
# Right stick Y (typically axis 3 or 4)
|
||||
z_input = self.joystick.get_axis(3) # Up/Down for Z
|
||||
|
||||
# Apply deadzone to avoid drift
|
||||
x_input = 0 if abs(x_input) < self.deadzone else x_input
|
||||
y_input = 0 if abs(y_input) < self.deadzone else y_input
|
||||
z_input = 0 if abs(z_input) < self.deadzone else z_input
|
||||
|
||||
# Calculate deltas (note: may need to invert axes depending on controller)
|
||||
delta_x = -y_input * self.y_step_size # Forward/backward
|
||||
delta_y = -x_input * self.x_step_size # Left/right
|
||||
delta_z = -z_input * self.z_step_size # Up/down
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
except pygame.error:
|
||||
logging.error("Error reading gamepad. Is it still connected?")
|
||||
return 0.0, 0.0, 0.0
|
||||
|
||||
|
||||
class GamepadControllerHID(InputController):
|
||||
"""Generate motion deltas from gamepad input using HIDAPI."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x_step_size=0.01,
|
||||
y_step_size=0.01,
|
||||
z_step_size=0.01,
|
||||
deadzone=0.1,
|
||||
vendor_id=0x046D,
|
||||
product_id=0xC219,
|
||||
):
|
||||
"""
|
||||
Initialize the HID gamepad controller.
|
||||
|
||||
Args:
|
||||
step_size: Base movement step size in meters
|
||||
z_scale: Scaling factor for Z-axis movement
|
||||
deadzone: Joystick deadzone to prevent drift
|
||||
vendor_id: USB vendor ID of the gamepad (default: Logitech)
|
||||
product_id: USB product ID of the gamepad (default: RumblePad 2)
|
||||
"""
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.vendor_id = vendor_id
|
||||
self.product_id = product_id
|
||||
self.device = None
|
||||
self.device_info = None
|
||||
|
||||
# Movement values (normalized from -1.0 to 1.0)
|
||||
self.left_x = 0.0
|
||||
self.left_y = 0.0
|
||||
self.right_x = 0.0
|
||||
self.right_y = 0.0
|
||||
|
||||
# Button states
|
||||
self.buttons = {}
|
||||
self.quit_requested = False
|
||||
self.save_requested = False
|
||||
|
||||
def find_device(self):
|
||||
"""Look for the gamepad device by vendor and product ID."""
|
||||
import hid
|
||||
|
||||
devices = hid.enumerate()
|
||||
for device in devices:
|
||||
if device["vendor_id"] == self.vendor_id and device["product_id"] == self.product_id:
|
||||
logging.info(f"Found gamepad: {device.get('product_string', 'Unknown')}")
|
||||
return device
|
||||
|
||||
logging.error(
|
||||
f"No gamepad with vendor ID 0x{self.vendor_id:04X} and product ID 0x{self.product_id:04X} found"
|
||||
)
|
||||
return None
|
||||
|
||||
def start(self):
|
||||
"""Connect to the gamepad using HIDAPI."""
|
||||
import hid
|
||||
|
||||
self.device_info = self.find_device()
|
||||
if not self.device_info:
|
||||
self.running = False
|
||||
return
|
||||
|
||||
try:
|
||||
logging.info(f"Connecting to gamepad at path: {self.device_info['path']}")
|
||||
self.device = hid.device()
|
||||
self.device.open_path(self.device_info["path"])
|
||||
self.device.set_nonblocking(1)
|
||||
|
||||
manufacturer = self.device.get_manufacturer_string()
|
||||
product = self.device.get_product_string()
|
||||
logging.info(f"Connected to {manufacturer} {product}")
|
||||
|
||||
logging.info("Gamepad controls (HID mode):")
|
||||
logging.info(" Left analog stick: Move in X-Y plane")
|
||||
logging.info(" Right analog stick: Move in Z axis (vertical)")
|
||||
logging.info(" Button 1/B/Circle: Exit")
|
||||
logging.info(" Button 2/A/Cross: End episode with SUCCESS")
|
||||
logging.info(" Button 3/X/Square: End episode with FAILURE")
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"Error opening gamepad: {e}")
|
||||
logging.error("You might need to run this with sudo/admin privileges on some systems")
|
||||
self.running = False
|
||||
|
||||
def stop(self):
|
||||
"""Close the HID device connection."""
|
||||
if self.device:
|
||||
self.device.close()
|
||||
self.device = None
|
||||
|
||||
def update(self):
|
||||
"""
|
||||
Read and process the latest gamepad data.
|
||||
Due to an issue with the HIDAPI, we need to read the read the device several times in order to get a stable reading
|
||||
"""
|
||||
for _ in range(10):
|
||||
self._update()
|
||||
|
||||
def _update(self):
|
||||
"""Read and process the latest gamepad data."""
|
||||
if not self.device or not self.running:
|
||||
return
|
||||
|
||||
try:
|
||||
# Read data from the gamepad
|
||||
data = self.device.read(64)
|
||||
# Interpret gamepad data - this will vary by controller model
|
||||
# These offsets are for the Logitech RumblePad 2
|
||||
if data and len(data) >= 8:
|
||||
# Normalize joystick values from 0-255 to -1.0-1.0
|
||||
self.left_x = (data[1] - 128) / 128.0
|
||||
self.left_y = (data[2] - 128) / 128.0
|
||||
self.right_x = (data[3] - 128) / 128.0
|
||||
self.right_y = (data[4] - 128) / 128.0
|
||||
|
||||
# Apply deadzone
|
||||
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
|
||||
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
|
||||
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
|
||||
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
|
||||
|
||||
# Parse button states (byte 5 in the Logitech RumblePad 2)
|
||||
buttons = data[5]
|
||||
|
||||
# Check if RB is pressed then the intervention flag should be set
|
||||
self.intervention_flag = data[6] in [2, 6, 10, 14]
|
||||
|
||||
# Check if RT is pressed
|
||||
self.open_gripper_command = data[6] in [8, 10, 12]
|
||||
|
||||
# Check if LT is pressed
|
||||
self.close_gripper_command = data[6] in [4, 6, 12]
|
||||
|
||||
# Check if Y/Triangle button (bit 7) is pressed for saving
|
||||
# Check if X/Square button (bit 5) is pressed for failure
|
||||
# Check if A/Cross button (bit 4) is pressed for rerecording
|
||||
if buttons & 1 << 7:
|
||||
self.episode_end_status = "success"
|
||||
elif buttons & 1 << 5:
|
||||
self.episode_end_status = "failure"
|
||||
elif buttons & 1 << 4:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"Error reading from gamepad: {e}")
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from gamepad state."""
|
||||
# Calculate deltas - invert as needed based on controller orientation
|
||||
delta_x = -self.left_y * self.x_step_size # Forward/backward
|
||||
delta_y = -self.left_x * self.y_step_size # Left/right
|
||||
delta_z = -self.right_y * self.z_step_size # Up/down
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
def should_quit(self):
|
||||
"""Return True if quit button was pressed."""
|
||||
return self.quit_requested
|
||||
|
||||
def should_save(self):
|
||||
"""Return True if save button was pressed."""
|
||||
return self.save_requested
|
||||
|
||||
|
||||
def test_forward_kinematics(robot, fps=10):
|
||||
logging.info("Testing Forward Kinematics")
|
||||
timestep = time.perf_counter()
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
robot.teleop_step()
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
logging.info(f"EE Position: {ee_pos[:3, 3]}")
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def test_inverse_kinematics(robot, fps=10):
|
||||
logging.info("Testing Inverse Kinematics")
|
||||
timestep = time.perf_counter()
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
|
||||
desired_ee_pos = ee_pos
|
||||
target_joint_state = RobotKinematics.ik(joint_positions, desired_ee_pos, position_only=True)
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
logging.info(f"Target Joint State: {target_joint_state}")
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_inverse_kinematics_with_leader(robot, fps=10):
|
||||
logging.info("Testing Inverse Kinematics")
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
timestep = time.perf_counter()
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
|
||||
|
||||
desired_ee_pos = leader_ee
|
||||
target_joint_state = kinematics.ik(
|
||||
joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
|
||||
)
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}")
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
|
||||
logging.info("Testing Delta End-Effector Control")
|
||||
timestep = time.perf_counter()
|
||||
|
||||
# Initial position capture
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
|
||||
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||
initial_leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
|
||||
|
||||
desired_ee_pos = np.diag(np.ones(4))
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
fixed_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
|
||||
# Get leader state for teleoperation
|
||||
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
|
||||
|
||||
# Get current state
|
||||
# obs = robot.capture_observation()
|
||||
# joint_positions = obs["observation.state"].cpu().numpy()
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
# Calculate delta between leader and follower end-effectors
|
||||
# Scaling factor can be adjusted for sensitivity
|
||||
scaling_factor = 1.0
|
||||
ee_delta = -np.clip((leader_ee - initial_leader_ee) * scaling_factor, -0.05, 0.05)
|
||||
|
||||
# Apply delta to current position
|
||||
desired_ee_pos[0, 3] = fixed_ee_pos[0, 3] # current_ee_pos[0, 3] + ee_delta[0, 3] * 0
|
||||
desired_ee_pos[1, 3] = fixed_ee_pos[1, 3] # current_ee_pos[1, 3] + ee_delta[1, 3] * 0
|
||||
desired_ee_pos[2, 3] = current_ee_pos[2, 3] - ee_delta[2, 3]
|
||||
|
||||
# Compute joint targets via inverse kinematics
|
||||
target_joint_state = kinematics.ik(
|
||||
joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
|
||||
)
|
||||
|
||||
initial_leader_ee = leader_ee.copy()
|
||||
|
||||
# Send command to robot
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
|
||||
# Logging
|
||||
logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}")
|
||||
logging.info(f"Delta EE: {ee_delta[:3, 3]}")
|
||||
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None, fk_func=None):
|
||||
"""
|
||||
Control a robot using delta end-effector movements from any input controller.
|
||||
|
||||
Args:
|
||||
robot: Robot instance to control
|
||||
controller: InputController instance (keyboard, gamepad, etc.)
|
||||
fps: Control frequency in Hz
|
||||
bounds: Optional position limits
|
||||
fk_func: Forward kinematics function to use
|
||||
"""
|
||||
if fk_func is None:
|
||||
fk_func = RobotKinematics.fk_gripper_tip
|
||||
|
||||
logging.info(f"Testing Delta End-Effector Control with {controller.__class__.__name__}")
|
||||
|
||||
# Initial position capture
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
# Initialize desired position with current position
|
||||
desired_ee_pos = np.eye(4) # Identity matrix
|
||||
|
||||
timestep = time.perf_counter()
|
||||
with controller:
|
||||
while not controller.should_quit() and time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
|
||||
# Process input events
|
||||
controller.update()
|
||||
|
||||
# Get current robot state
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
# Get movement deltas from the controller
|
||||
delta_x, delta_y, delta_z = controller.get_deltas()
|
||||
|
||||
# Update desired position
|
||||
desired_ee_pos[0, 3] = current_ee_pos[0, 3] + delta_x
|
||||
desired_ee_pos[1, 3] = current_ee_pos[1, 3] + delta_y
|
||||
desired_ee_pos[2, 3] = current_ee_pos[2, 3] + delta_z
|
||||
|
||||
# Apply bounds if provided
|
||||
if bounds is not None:
|
||||
desired_ee_pos[:3, 3] = np.clip(desired_ee_pos[:3, 3], bounds["min"], bounds["max"])
|
||||
|
||||
# Only send commands if there's actual movement
|
||||
if any(abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]):
|
||||
# Compute joint targets via inverse kinematics
|
||||
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
|
||||
|
||||
# Send command to robot
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_gym_env(env, controller, fps: int = 30):
|
||||
"""
|
||||
Control a robot through a gym environment using keyboard inputs.
|
||||
|
||||
Args:
|
||||
env: A gym environment created with make_robot_env
|
||||
fps: Target control frequency
|
||||
"""
|
||||
|
||||
logging.info("Testing Keyboard Control of Gym Environment")
|
||||
print("Keyboard controls:")
|
||||
print(" Arrow keys: Move in X-Y plane")
|
||||
print(" Shift and Shift_R: Move in Z axis")
|
||||
print(" ESC: Exit")
|
||||
|
||||
# Reset the environment to get initial observation
|
||||
obs, info = env.reset()
|
||||
|
||||
try:
|
||||
with controller:
|
||||
while not controller.should_quit():
|
||||
loop_start_time = time.perf_counter()
|
||||
|
||||
# Process input events
|
||||
controller.update()
|
||||
|
||||
# Get movement deltas from the controller
|
||||
delta_x, delta_y, delta_z = controller.get_deltas()
|
||||
|
||||
# Create the action vector
|
||||
action = np.array([delta_x, delta_y, delta_z])
|
||||
|
||||
# Skip if no movement
|
||||
if any(abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]):
|
||||
# Step the environment - pass action as a tensor with intervention flag
|
||||
action_tensor = torch.from_numpy(action.astype(np.float32))
|
||||
obs, reward, terminated, truncated, info = env.step((action_tensor, False))
|
||||
|
||||
# Log information
|
||||
logging.info(f"Action: [{delta_x:.4f}, {delta_y:.4f}, {delta_z:.4f}]")
|
||||
logging.info(f"Reward: {reward}")
|
||||
|
||||
# Reset if episode ended
|
||||
if terminated or truncated:
|
||||
logging.info("Episode ended, resetting environment")
|
||||
obs, info = env.reset()
|
||||
|
||||
# Maintain target frame rate
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
finally:
|
||||
# Close the environment
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from lerobot.common.envs.configs import EEActionSpaceConfig, EnvTransformConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
||||
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
||||
|
||||
init_logging()
|
||||
|
||||
parser = argparse.ArgumentParser(description="Test end-effector control")
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
default="keyboard",
|
||||
choices=[
|
||||
"keyboard",
|
||||
"gamepad",
|
||||
"keyboard_gym",
|
||||
"gamepad_gym",
|
||||
"leader_delta",
|
||||
"leader",
|
||||
],
|
||||
help="Control mode to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-type",
|
||||
type=str,
|
||||
default="so100",
|
||||
help="Robot type (so100, koch, aloha, etc.)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
robot_config = RobotConfig.get_choice_class(args.robot_type)(mock=False)
|
||||
robot = make_robot_from_config(robot_config)
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
# Example bounds
|
||||
bounds = {
|
||||
"max": np.array([0.32170487, 0.201285, 0.10273342]),
|
||||
"min": np.array([0.16631757, -0.08237468, 0.03364977]),
|
||||
}
|
||||
|
||||
try:
|
||||
# Determine controller type based on mode prefix
|
||||
controller = None
|
||||
if args.mode.startswith("keyboard"):
|
||||
controller = KeyboardController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
|
||||
elif args.mode.startswith("gamepad"):
|
||||
if sys.platform == "darwin":
|
||||
controller = GamepadControllerHID(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
|
||||
else:
|
||||
controller = GamepadController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
|
||||
|
||||
# Handle mode categories
|
||||
if args.mode in ["keyboard", "gamepad"]:
|
||||
# Direct robot control modes
|
||||
teleoperate_delta_inverse_kinematics(robot, controller, bounds=bounds, fps=10)
|
||||
|
||||
elif args.mode in ["keyboard_gym", "gamepad_gym"]:
|
||||
# Gym environment control modes
|
||||
cfg = HILSerlRobotEnvConfig(robot=robot_config, wrapper=EnvTransformConfig())
|
||||
cfg.wrapper.ee_action_space_params = EEActionSpaceConfig(
|
||||
x_step_size=0.03, y_step_size=0.03, z_step_size=0.03, bounds=bounds
|
||||
)
|
||||
cfg.wrapper.ee_action_space_params.use_gamepad = False
|
||||
cfg.device = "cpu"
|
||||
env = make_robot_env(cfg, robot)
|
||||
teleoperate_gym_env(env, controller, fps=cfg.fps)
|
||||
|
||||
elif args.mode == "leader_delta":
|
||||
# Leader-follower modes don't use controllers
|
||||
teleoperate_delta_inverse_kinematics_with_leader(robot)
|
||||
|
||||
elif args.mode == "leader":
|
||||
teleoperate_inverse_kinematics_with_leader(robot)
|
||||
|
||||
finally:
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
135
lerobot/scripts/server/find_joint_limits.py
Normal file
135
lerobot/scripts/server/find_joint_limits.py
Normal file
@@ -0,0 +1,135 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.robot_devices.control_utils import is_headless
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
||||
from lerobot.configs import parser
|
||||
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||
|
||||
|
||||
def find_joint_bounds(
|
||||
robot,
|
||||
control_time_s=30,
|
||||
display_cameras=False,
|
||||
):
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
start_episode_t = time.perf_counter()
|
||||
pos_list = []
|
||||
while True:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
|
||||
# Wait for 5 seconds to stabilize the robot initial position
|
||||
if time.perf_counter() - start_episode_t < 5:
|
||||
continue
|
||||
|
||||
pos_list.append(robot.follower_arms["main"].read("Present_Position"))
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
|
||||
if time.perf_counter() - start_episode_t > control_time_s:
|
||||
max = np.max(np.stack(pos_list), 0)
|
||||
min = np.min(np.stack(pos_list), 0)
|
||||
print(f"Max angle position per joint {max}")
|
||||
print(f"Min angle position per joint {min}")
|
||||
break
|
||||
|
||||
|
||||
def find_ee_bounds(
|
||||
robot,
|
||||
control_time_s=30,
|
||||
display_cameras=False,
|
||||
):
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
start_episode_t = time.perf_counter()
|
||||
ee_list = []
|
||||
while True:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
|
||||
# Wait for 5 seconds to stabilize the robot initial position
|
||||
if time.perf_counter() - start_episode_t < 5:
|
||||
continue
|
||||
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
print(f"Joint positions: {joint_positions}")
|
||||
ee_list.append(kinematics.fk_gripper_tip(joint_positions)[:3, 3])
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
|
||||
if time.perf_counter() - start_episode_t > control_time_s:
|
||||
max = np.max(np.stack(ee_list), 0)
|
||||
min = np.min(np.stack(ee_list), 0)
|
||||
print(f"Max ee position {max}")
|
||||
print(f"Min ee position {min}")
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Create argparse for script-specific arguments
|
||||
parser = argparse.ArgumentParser(add_help=False) # Set add_help=False to avoid conflict
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
default="joint",
|
||||
choices=["joint", "ee"],
|
||||
help="Mode to run the script in. Can be 'joint' or 'ee'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--control-time-s",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Time step to use for control.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-type",
|
||||
type=str,
|
||||
default="so100",
|
||||
help="Robot type (so100, koch, aloha, etc.)",
|
||||
)
|
||||
|
||||
# Only parse known args, leaving robot config args for Hydra if used
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create robot with the appropriate config
|
||||
robot_config = RobotConfig.get_choice_class(args.robot_type)(mock=False)
|
||||
robot = make_robot_from_config(robot_config)
|
||||
|
||||
if args.mode == "joint":
|
||||
find_joint_bounds(robot, args.control_time_s)
|
||||
elif args.mode == "ee":
|
||||
find_ee_bounds(robot, args.control_time_s)
|
||||
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
2202
lerobot/scripts/server/gym_manipulator.py
Normal file
2202
lerobot/scripts/server/gym_manipulator.py
Normal file
File diff suppressed because it is too large
Load Diff
54
lerobot/scripts/server/hilserl.proto
Normal file
54
lerobot/scripts/server/hilserl.proto
Normal file
@@ -0,0 +1,54 @@
|
||||
// Copyright 2024 The HuggingFace Inc. team.
|
||||
// All rights reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package hil_serl;
|
||||
|
||||
// LearnerService: the Actor calls this to push transitions.
|
||||
// The Learner implements this service.
|
||||
service LearnerService {
|
||||
// Actor -> Learner to store transitions
|
||||
rpc SendInteractionMessage(InteractionMessage) returns (Empty);
|
||||
rpc StreamParameters(Empty) returns (stream Parameters);
|
||||
rpc SendTransitions(stream Transition) returns (Empty);
|
||||
rpc SendInteractions(stream InteractionMessage) returns (Empty);
|
||||
rpc Ready(Empty) returns (Empty);
|
||||
}
|
||||
|
||||
enum TransferState {
|
||||
TRANSFER_UNKNOWN = 0;
|
||||
TRANSFER_BEGIN = 1;
|
||||
TRANSFER_MIDDLE = 2;
|
||||
TRANSFER_END = 3;
|
||||
}
|
||||
|
||||
// Messages
|
||||
message Transition {
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message Parameters {
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message InteractionMessage {
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message Empty {}
|
||||
46
lerobot/scripts/server/hilserl_pb2.py
Normal file
46
lerobot/scripts/server/hilserl_pb2.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: hilserl.proto
|
||||
# Protobuf Python Version: 5.29.0
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||
_runtime_version.Domain.PUBLIC,
|
||||
5,
|
||||
29,
|
||||
0,
|
||||
'',
|
||||
'hilserl.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"K\n\nTransition\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"K\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x12InteractionMessage\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xc2\x02\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12:\n\x0fSendTransitions\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty(\x01\x12\x43\n\x10SendInteractions\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty(\x01\x12)\n\x05Ready\x12\x0f.hil_serl.Empty\x1a\x0f.hil_serl.Emptyb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_TRANSFERSTATE']._serialized_start=275
|
||||
_globals['_TRANSFERSTATE']._serialized_end=371
|
||||
_globals['_TRANSITION']._serialized_start=27
|
||||
_globals['_TRANSITION']._serialized_end=102
|
||||
_globals['_PARAMETERS']._serialized_start=104
|
||||
_globals['_PARAMETERS']._serialized_end=179
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_start=181
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_end=264
|
||||
_globals['_EMPTY']._serialized_start=266
|
||||
_globals['_EMPTY']._serialized_end=273
|
||||
_globals['_LEARNERSERVICE']._serialized_start=374
|
||||
_globals['_LEARNERSERVICE']._serialized_end=696
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
276
lerobot/scripts/server/hilserl_pb2_grpc.py
Normal file
276
lerobot/scripts/server/hilserl_pb2_grpc.py
Normal file
@@ -0,0 +1,276 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
import hilserl_pb2 as hilserl__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.70.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
_version_not_supported = False
|
||||
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
except ImportError:
|
||||
_version_not_supported = True
|
||||
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ f' but the generated code in hilserl_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
)
|
||||
|
||||
|
||||
class LearnerServiceStub(object):
|
||||
"""LearnerService: the Actor calls this to push transitions.
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.SendInteractionMessage = channel.unary_unary(
|
||||
'/hil_serl.LearnerService/SendInteractionMessage',
|
||||
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.StreamParameters = channel.unary_stream(
|
||||
'/hil_serl.LearnerService/StreamParameters',
|
||||
request_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Parameters.FromString,
|
||||
_registered_method=True)
|
||||
self.SendTransitions = channel.stream_unary(
|
||||
'/hil_serl.LearnerService/SendTransitions',
|
||||
request_serializer=hilserl__pb2.Transition.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.SendInteractions = channel.stream_unary(
|
||||
'/hil_serl.LearnerService/SendInteractions',
|
||||
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.Ready = channel.unary_unary(
|
||||
'/hil_serl.LearnerService/Ready',
|
||||
request_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class LearnerServiceServicer(object):
|
||||
"""LearnerService: the Actor calls this to push transitions.
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
def SendInteractionMessage(self, request, context):
|
||||
"""Actor -> Learner to store transitions
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def StreamParameters(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendTransitions(self, request_iterator, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendInteractions(self, request_iterator, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Ready(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_LearnerServiceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'SendInteractionMessage': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendInteractionMessage,
|
||||
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'StreamParameters': grpc.unary_stream_rpc_method_handler(
|
||||
servicer.StreamParameters,
|
||||
request_deserializer=hilserl__pb2.Empty.FromString,
|
||||
response_serializer=hilserl__pb2.Parameters.SerializeToString,
|
||||
),
|
||||
'SendTransitions': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.SendTransitions,
|
||||
request_deserializer=hilserl__pb2.Transition.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'SendInteractions': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.SendInteractions,
|
||||
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'Ready': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Ready,
|
||||
request_deserializer=hilserl__pb2.Empty.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'hil_serl.LearnerService', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('hil_serl.LearnerService', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class LearnerService(object):
|
||||
"""LearnerService: the Actor calls this to push transitions.
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def SendInteractionMessage(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/hil_serl.LearnerService/SendInteractionMessage',
|
||||
hilserl__pb2.InteractionMessage.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def StreamParameters(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_stream(
|
||||
request,
|
||||
target,
|
||||
'/hil_serl.LearnerService/StreamParameters',
|
||||
hilserl__pb2.Empty.SerializeToString,
|
||||
hilserl__pb2.Parameters.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendTransitions(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.stream_unary(
|
||||
request_iterator,
|
||||
target,
|
||||
'/hil_serl.LearnerService/SendTransitions',
|
||||
hilserl__pb2.Transition.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendInteractions(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.stream_unary(
|
||||
request_iterator,
|
||||
target,
|
||||
'/hil_serl.LearnerService/SendInteractions',
|
||||
hilserl__pb2.InteractionMessage.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def Ready(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/hil_serl.LearnerService/Ready',
|
||||
hilserl__pb2.Empty.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
546
lerobot/scripts/server/kinematics.py
Normal file
546
lerobot/scripts/server/kinematics.py
Normal file
@@ -0,0 +1,546 @@
|
||||
# ruff: noqa: N806, N815, N803
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
|
||||
def skew_symmetric(w):
|
||||
"""Creates the skew-symmetric matrix from a 3D vector."""
|
||||
return np.array([[0, -w[2], w[1]], [w[2], 0, -w[0]], [-w[1], w[0], 0]])
|
||||
|
||||
|
||||
def rodrigues_rotation(w, theta):
|
||||
"""Computes the rotation matrix using Rodrigues' formula."""
|
||||
w_hat = skew_symmetric(w)
|
||||
return np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
|
||||
|
||||
|
||||
def screw_axis_to_transform(S, theta):
|
||||
"""Converts a screw axis to a 4x4 transformation matrix."""
|
||||
S_w = S[:3]
|
||||
S_v = S[3:]
|
||||
if np.allclose(S_w, 0) and np.linalg.norm(S_v) == 1: # Pure translation
|
||||
T = np.eye(4)
|
||||
T[:3, 3] = S_v * theta
|
||||
elif np.linalg.norm(S_w) == 1: # Rotation and translation
|
||||
w_hat = skew_symmetric(S_w)
|
||||
R = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
|
||||
t = (np.eye(3) * theta + (1 - np.cos(theta)) * w_hat + (theta - np.sin(theta)) * w_hat @ w_hat) @ S_v
|
||||
T = np.eye(4)
|
||||
T[:3, :3] = R
|
||||
T[:3, 3] = t
|
||||
else:
|
||||
raise ValueError("Invalid screw axis parameters")
|
||||
return T
|
||||
|
||||
|
||||
def pose_difference_se3(pose1, pose2):
|
||||
"""
|
||||
Calculates the SE(3) difference between two 4x4 homogeneous transformation matrices.
|
||||
SE(3) (Special Euclidean Group) represents rigid body transformations in 3D space, combining rotation (SO(3)) and translation.
|
||||
Each 4x4 matrix has the following structure, a 3x3 rotation matrix in the top-left and a 3x1 translation vector in the top-right:
|
||||
|
||||
[R11 R12 R13 tx]
|
||||
[R21 R22 R23 ty]
|
||||
[R31 R32 R33 tz]
|
||||
[ 0 0 0 1]
|
||||
|
||||
where Rij is the 3x3 rotation matrix and [tx,ty,tz] is the translation vector.
|
||||
|
||||
pose1 - pose2
|
||||
|
||||
Args:
|
||||
pose1: A 4x4 numpy array representing the first pose.
|
||||
pose2: A 4x4 numpy array representing the second pose.
|
||||
|
||||
Returns:
|
||||
A tuple (translation_diff, rotation_diff) where:
|
||||
- translation_diff is a 3x1 numpy array representing the translational difference.
|
||||
- rotation_diff is a 3x1 numpy array representing the rotational difference in axis-angle representation.
|
||||
"""
|
||||
|
||||
# Extract rotation matrices from poses
|
||||
R1 = pose1[:3, :3]
|
||||
R2 = pose2[:3, :3]
|
||||
|
||||
# Calculate translational difference
|
||||
translation_diff = pose1[:3, 3] - pose2[:3, 3]
|
||||
|
||||
# Calculate rotational difference using scipy's Rotation library
|
||||
R_diff = Rotation.from_matrix(R1 @ R2.T)
|
||||
rotation_diff = R_diff.as_rotvec() # Convert to axis-angle representation
|
||||
|
||||
return np.concatenate([translation_diff, rotation_diff])
|
||||
|
||||
|
||||
def se3_error(target_pose, current_pose):
|
||||
pos_error = target_pose[:3, 3] - current_pose[:3, 3]
|
||||
R_target = target_pose[:3, :3]
|
||||
R_current = current_pose[:3, :3]
|
||||
R_error = R_target @ R_current.T
|
||||
rot_error = Rotation.from_matrix(R_error).as_rotvec()
|
||||
return np.concatenate([pos_error, rot_error])
|
||||
|
||||
|
||||
class RobotKinematics:
|
||||
"""Robot kinematics class supporting multiple robot models."""
|
||||
|
||||
# Robot measurements dictionary
|
||||
ROBOT_MEASUREMENTS = {
|
||||
"koch": {
|
||||
"gripper": [0.239, -0.001, 0.024],
|
||||
"wrist": [0.209, 0, 0.024],
|
||||
"forearm": [0.108, 0, 0.02],
|
||||
"humerus": [0, 0, 0.036],
|
||||
"shoulder": [0, 0, 0],
|
||||
"base": [0, 0, 0.02],
|
||||
},
|
||||
"so100": {
|
||||
"gripper": [0.320, 0, 0.050],
|
||||
"wrist": [0.278, 0, 0.050],
|
||||
"forearm": [0.143, 0, 0.044],
|
||||
"humerus": [0.031, 0, 0.072],
|
||||
"shoulder": [0, 0, 0],
|
||||
"base": [0, 0, 0.02],
|
||||
},
|
||||
"moss": {
|
||||
"gripper": [0.246, 0.013, 0.111],
|
||||
"wrist": [0.245, 0.002, 0.064],
|
||||
"forearm": [0.122, 0, 0.064],
|
||||
"humerus": [0.001, 0.001, 0.063],
|
||||
"shoulder": [0, 0, 0],
|
||||
"base": [0, 0, 0.02],
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, robot_type="so100"):
|
||||
"""Initialize kinematics for the specified robot type.
|
||||
|
||||
Args:
|
||||
robot_type: String specifying the robot model ("koch", "so100", or "moss")
|
||||
"""
|
||||
if robot_type not in self.ROBOT_MEASUREMENTS:
|
||||
raise ValueError(
|
||||
f"Unknown robot type: {robot_type}. Available types: {list(self.ROBOT_MEASUREMENTS.keys())}"
|
||||
)
|
||||
|
||||
self.robot_type = robot_type
|
||||
self.measurements = self.ROBOT_MEASUREMENTS[robot_type]
|
||||
|
||||
# Initialize all transformation matrices and screw axes
|
||||
self._setup_transforms()
|
||||
|
||||
def _create_translation_matrix(self, x=0, y=0, z=0):
|
||||
"""Create a 4x4 translation matrix."""
|
||||
return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]])
|
||||
|
||||
def _setup_transforms(self):
|
||||
"""Setup all transformation matrices and screw axes for the robot."""
|
||||
# Set up rotation matrices (constant across robot types)
|
||||
|
||||
# Gripper orientation
|
||||
self.gripper_X0 = np.array(
|
||||
[
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 1, 0],
|
||||
[0, -1, 0, 0],
|
||||
[0, 0, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
# Wrist orientation
|
||||
self.wrist_X0 = np.array(
|
||||
[
|
||||
[0, -1, 0, 0],
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
# Base orientation
|
||||
self.base_X0 = np.array(
|
||||
[
|
||||
[0, 0, 1, 0],
|
||||
[1, 0, 0, 0],
|
||||
[0, 1, 0, 0],
|
||||
[0, 0, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
# Gripper
|
||||
# Screw axis of gripper frame wrt base frame
|
||||
self.S_BG = np.array(
|
||||
[
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
self.measurements["gripper"][2],
|
||||
-self.measurements["gripper"][1],
|
||||
]
|
||||
)
|
||||
|
||||
# Gripper origin to centroid transform
|
||||
self.X_GoGc = self._create_translation_matrix(x=0.07)
|
||||
|
||||
# Gripper origin to tip transform
|
||||
self.X_GoGt = self._create_translation_matrix(x=0.12)
|
||||
|
||||
# 0-position gripper frame pose wrt base
|
||||
self.X_BoGo = self._create_translation_matrix(
|
||||
x=self.measurements["gripper"][0],
|
||||
y=self.measurements["gripper"][1],
|
||||
z=self.measurements["gripper"][2],
|
||||
)
|
||||
|
||||
# Wrist
|
||||
# Screw axis of wrist frame wrt base frame
|
||||
self.S_BR = np.array([0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]])
|
||||
|
||||
# 0-position origin to centroid transform
|
||||
self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002)
|
||||
|
||||
# 0-position wrist frame pose wrt base
|
||||
self.X_BR = self._create_translation_matrix(
|
||||
x=self.measurements["wrist"][0],
|
||||
y=self.measurements["wrist"][1],
|
||||
z=self.measurements["wrist"][2],
|
||||
)
|
||||
|
||||
# Forearm
|
||||
# Screw axis of forearm frame wrt base frame
|
||||
self.S_BF = np.array(
|
||||
[
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
-self.measurements["forearm"][2],
|
||||
0,
|
||||
self.measurements["forearm"][0],
|
||||
]
|
||||
)
|
||||
|
||||
# Forearm origin + centroid transform
|
||||
self.X_FoFc = self._create_translation_matrix(x=0.036) # spellchecker:disable-line
|
||||
|
||||
# 0-position forearm frame pose wrt base
|
||||
self.X_BF = self._create_translation_matrix(
|
||||
x=self.measurements["forearm"][0],
|
||||
y=self.measurements["forearm"][1],
|
||||
z=self.measurements["forearm"][2],
|
||||
)
|
||||
|
||||
# Humerus
|
||||
# Screw axis of humerus frame wrt base frame
|
||||
self.S_BH = np.array(
|
||||
[
|
||||
0,
|
||||
-1,
|
||||
0,
|
||||
self.measurements["humerus"][2],
|
||||
0,
|
||||
-self.measurements["humerus"][0],
|
||||
]
|
||||
)
|
||||
|
||||
# Humerus origin to centroid transform
|
||||
self.X_HoHc = self._create_translation_matrix(x=0.0475)
|
||||
|
||||
# 0-position humerus frame pose wrt base
|
||||
self.X_BH = self._create_translation_matrix(
|
||||
x=self.measurements["humerus"][0],
|
||||
y=self.measurements["humerus"][1],
|
||||
z=self.measurements["humerus"][2],
|
||||
)
|
||||
|
||||
# Shoulder
|
||||
# Screw axis of shoulder frame wrt Base frame
|
||||
self.S_BS = np.array([0, 0, -1, 0, 0, 0])
|
||||
|
||||
# Shoulder origin to centroid transform
|
||||
self.X_SoSc = self._create_translation_matrix(x=-0.017, z=0.0235)
|
||||
|
||||
# 0-position shoulder frame pose wrt base
|
||||
self.X_BS = self._create_translation_matrix(
|
||||
x=self.measurements["shoulder"][0],
|
||||
y=self.measurements["shoulder"][1],
|
||||
z=self.measurements["shoulder"][2],
|
||||
)
|
||||
|
||||
# Base
|
||||
# Base origin to centroid transform
|
||||
self.X_BoBc = self._create_translation_matrix(y=0.015)
|
||||
|
||||
# World to base transform
|
||||
self.X_WoBo = self._create_translation_matrix(
|
||||
x=self.measurements["base"][0],
|
||||
y=self.measurements["base"][1],
|
||||
z=self.measurements["base"][2],
|
||||
)
|
||||
|
||||
# Pre-compute gripper post-multiplication matrix
|
||||
self._fk_gripper_post = self.X_GoGc @ self.X_BoGo @ self.gripper_X0
|
||||
|
||||
def fk_base(self):
|
||||
"""Forward kinematics for the base frame."""
|
||||
return self.X_WoBo @ self.X_BoBc @ self.base_X0
|
||||
|
||||
def fk_shoulder(self, robot_pos_deg):
|
||||
"""Forward kinematics for the shoulder frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
return self.X_WoBo @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) @ self.X_SoSc @ self.X_BS
|
||||
|
||||
def fk_humerus(self, robot_pos_deg):
|
||||
"""Forward kinematics for the humerus frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||
@ self.X_HoHc
|
||||
@ self.X_BH
|
||||
)
|
||||
|
||||
def fk_forearm(self, robot_pos_deg):
|
||||
"""Forward kinematics for the forearm frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
|
||||
@ self.X_FoFc # spellchecker:disable-line
|
||||
@ self.X_BF
|
||||
)
|
||||
|
||||
def fk_wrist(self, robot_pos_deg):
|
||||
"""Forward kinematics for the wrist frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
|
||||
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
|
||||
@ self.X_RoRc
|
||||
@ self.X_BR
|
||||
@ self.wrist_X0
|
||||
)
|
||||
|
||||
def fk_gripper(self, robot_pos_deg):
|
||||
"""Forward kinematics for the gripper frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
|
||||
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
|
||||
@ screw_axis_to_transform(self.S_BG, robot_pos_rad[4])
|
||||
@ self._fk_gripper_post
|
||||
)
|
||||
|
||||
def fk_gripper_tip(self, robot_pos_deg):
|
||||
"""Forward kinematics for the gripper tip frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
|
||||
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
|
||||
@ screw_axis_to_transform(self.S_BG, robot_pos_rad[4])
|
||||
@ self.X_GoGt
|
||||
@ self.X_BoGo
|
||||
@ self.gripper_X0
|
||||
)
|
||||
|
||||
def compute_jacobian(self, robot_pos_deg, fk_func=None):
|
||||
"""Finite differences to compute the Jacobian.
|
||||
J(i, j) represents how the ith component of the end-effector's velocity changes wrt a small change
|
||||
in the jth joint's velocity.
|
||||
|
||||
Args:
|
||||
robot_pos_deg: Current joint positions in degrees
|
||||
fk_func: Forward kinematics function to use (defaults to fk_gripper)
|
||||
"""
|
||||
if fk_func is None:
|
||||
fk_func = self.fk_gripper
|
||||
|
||||
eps = 1e-8
|
||||
jac = np.zeros(shape=(6, 5))
|
||||
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
|
||||
for el_ix in range(len(robot_pos_deg[:-1])):
|
||||
delta *= 0
|
||||
delta[el_ix] = eps / 2
|
||||
Sdot = (
|
||||
pose_difference_se3(
|
||||
fk_func(robot_pos_deg[:-1] + delta),
|
||||
fk_func(robot_pos_deg[:-1] - delta),
|
||||
)
|
||||
/ eps
|
||||
)
|
||||
jac[:, el_ix] = Sdot
|
||||
return jac
|
||||
|
||||
def compute_positional_jacobian(self, robot_pos_deg, fk_func=None):
|
||||
"""Finite differences to compute the positional Jacobian.
|
||||
J(i, j) represents how the ith component of the end-effector's position changes wrt a small change
|
||||
in the jth joint's velocity.
|
||||
|
||||
Args:
|
||||
robot_pos_deg: Current joint positions in degrees
|
||||
fk_func: Forward kinematics function to use (defaults to fk_gripper)
|
||||
"""
|
||||
if fk_func is None:
|
||||
fk_func = self.fk_gripper
|
||||
|
||||
eps = 1e-8
|
||||
jac = np.zeros(shape=(3, 5))
|
||||
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
|
||||
for el_ix in range(len(robot_pos_deg[:-1])):
|
||||
delta *= 0
|
||||
delta[el_ix] = eps / 2
|
||||
Sdot = (
|
||||
fk_func(robot_pos_deg[:-1] + delta)[:3, 3] - fk_func(robot_pos_deg[:-1] - delta)[:3, 3]
|
||||
) / eps
|
||||
jac[:, el_ix] = Sdot
|
||||
return jac
|
||||
|
||||
def ik(self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None):
|
||||
"""Inverse kinematics using gradient descent.
|
||||
|
||||
Args:
|
||||
current_joint_state: Initial joint positions in degrees
|
||||
desired_ee_pose: Target end-effector pose as a 4x4 transformation matrix
|
||||
position_only: If True, only match end-effector position, not orientation
|
||||
fk_func: Forward kinematics function to use (defaults to fk_gripper)
|
||||
|
||||
Returns:
|
||||
Joint positions in degrees that achieve the desired end-effector pose
|
||||
"""
|
||||
if fk_func is None:
|
||||
fk_func = self.fk_gripper
|
||||
|
||||
# Do gradient descent.
|
||||
max_iterations = 5
|
||||
learning_rate = 1
|
||||
for _ in range(max_iterations):
|
||||
current_ee_pose = fk_func(current_joint_state)
|
||||
if not position_only:
|
||||
error = se3_error(desired_ee_pose, current_ee_pose)
|
||||
jac = self.compute_jacobian(current_joint_state, fk_func)
|
||||
else:
|
||||
error = desired_ee_pose[:3, 3] - current_ee_pose[:3, 3]
|
||||
jac = self.compute_positional_jacobian(current_joint_state, fk_func)
|
||||
delta_angles = np.linalg.pinv(jac) @ error
|
||||
current_joint_state[:-1] += learning_rate * delta_angles
|
||||
|
||||
if np.linalg.norm(error) < 5e-3:
|
||||
return current_joint_state
|
||||
return current_joint_state
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
|
||||
def run_test(robot_type):
|
||||
"""Run test suite for a specific robot type."""
|
||||
print(f"\n--- Testing {robot_type.upper()} Robot ---")
|
||||
|
||||
# Initialize kinematics for this robot
|
||||
robot = RobotKinematics(robot_type)
|
||||
|
||||
# Test 1: Forward kinematics consistency
|
||||
print("Test 1: Forward kinematics consistency")
|
||||
test_angles = np.array([30, 45, -30, 20, 10, 0]) # Example joint angles in degrees
|
||||
|
||||
# Calculate FK for different joints
|
||||
shoulder_pose = robot.fk_shoulder(test_angles)
|
||||
humerus_pose = robot.fk_humerus(test_angles)
|
||||
forearm_pose = robot.fk_forearm(test_angles)
|
||||
wrist_pose = robot.fk_wrist(test_angles)
|
||||
gripper_pose = robot.fk_gripper(test_angles)
|
||||
gripper_tip_pose = robot.fk_gripper_tip(test_angles)
|
||||
|
||||
# Check that poses form a consistent kinematic chain (positions should be progressively further from origin)
|
||||
distances = [
|
||||
np.linalg.norm(shoulder_pose[:3, 3]),
|
||||
np.linalg.norm(humerus_pose[:3, 3]),
|
||||
np.linalg.norm(forearm_pose[:3, 3]),
|
||||
np.linalg.norm(wrist_pose[:3, 3]),
|
||||
np.linalg.norm(gripper_pose[:3, 3]),
|
||||
np.linalg.norm(gripper_tip_pose[:3, 3]),
|
||||
]
|
||||
|
||||
# Check if distances generally increase along the chain
|
||||
is_consistent = all(distances[i] <= distances[i + 1] for i in range(len(distances) - 1))
|
||||
print(f" Pose distances from origin: {[round(d, 3) for d in distances]}")
|
||||
print(f" Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}")
|
||||
|
||||
# Test 2: Jacobian computation
|
||||
print("Test 2: Jacobian computation")
|
||||
jacobian = robot.compute_jacobian(test_angles)
|
||||
positional_jacobian = robot.compute_positional_jacobian(test_angles)
|
||||
|
||||
# Check shapes
|
||||
jacobian_shape_ok = jacobian.shape == (6, 5)
|
||||
pos_jacobian_shape_ok = positional_jacobian.shape == (3, 5)
|
||||
|
||||
print(f" Jacobian shape: {'PASSED' if jacobian_shape_ok else 'FAILED'}")
|
||||
print(f" Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}")
|
||||
|
||||
# Test 3: Inverse kinematics
|
||||
print("Test 3: Inverse kinematics (position only)")
|
||||
|
||||
# Generate target pose from known joint angles
|
||||
original_angles = np.array([10, 20, 30, -10, 5, 0])
|
||||
target_pose = robot.fk_gripper(original_angles)
|
||||
|
||||
# Start IK from a different position
|
||||
initial_guess = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
|
||||
|
||||
# Measure IK performance
|
||||
start_time = time.time()
|
||||
computed_angles = robot.ik(initial_guess.copy(), target_pose)
|
||||
ik_time = time.time() - start_time
|
||||
|
||||
# Compute resulting pose from IK solution
|
||||
result_pose = robot.fk_gripper(computed_angles)
|
||||
|
||||
# Calculate position error
|
||||
pos_error = np.linalg.norm(target_pose[:3, 3] - result_pose[:3, 3])
|
||||
passed = pos_error < 0.01 # Accept errors less than 1cm
|
||||
|
||||
print(f" IK computation time: {ik_time:.4f} seconds")
|
||||
print(f" Position error: {pos_error:.4f}")
|
||||
print(f" IK position accuracy: {'PASSED' if passed else 'FAILED'}")
|
||||
|
||||
return is_consistent and jacobian_shape_ok and pos_jacobian_shape_ok and passed
|
||||
|
||||
# Run tests for all robot types
|
||||
results = {}
|
||||
for robot_type in ["koch", "so100", "moss"]:
|
||||
results[robot_type] = run_test(robot_type)
|
||||
|
||||
# Print overall summary
|
||||
print("\n=== Test Summary ===")
|
||||
all_passed = all(results.values())
|
||||
for robot_type, passed in results.items():
|
||||
print(f"{robot_type.upper()}: {'PASSED' if passed else 'FAILED'}")
|
||||
print(f"\nOverall: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}")
|
||||
1232
lerobot/scripts/server/learner_server.py
Normal file
1232
lerobot/scripts/server/learner_server.py
Normal file
File diff suppressed because it is too large
Load Diff
80
lerobot/scripts/server/learner_service.py
Normal file
80
lerobot/scripts/server/learner_service.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import logging
|
||||
from multiprocessing import Event, Queue
|
||||
|
||||
import hilserl_pb2 # type: ignore
|
||||
import hilserl_pb2_grpc # type: ignore
|
||||
|
||||
from lerobot.scripts.server.network_utils import receive_bytes_in_chunks, send_bytes_in_chunks
|
||||
|
||||
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
|
||||
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
|
||||
SHUTDOWN_TIMEOUT = 10
|
||||
|
||||
|
||||
class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
|
||||
def __init__(
|
||||
self,
|
||||
shutdown_event: Event, # type: ignore
|
||||
parameters_queue: Queue,
|
||||
seconds_between_pushes: float,
|
||||
transition_queue: Queue,
|
||||
interaction_message_queue: Queue,
|
||||
):
|
||||
self.shutdown_event = shutdown_event
|
||||
self.parameters_queue = parameters_queue
|
||||
self.seconds_between_pushes = seconds_between_pushes
|
||||
self.transition_queue = transition_queue
|
||||
self.interaction_message_queue = interaction_message_queue
|
||||
|
||||
def StreamParameters(self, request, context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to stream parameters from the Actor")
|
||||
|
||||
while not self.shutdown_event.is_set():
|
||||
logging.info("[LEARNER] Push parameters to the Actor")
|
||||
buffer = self.parameters_queue.get()
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
buffer,
|
||||
hilserl_pb2.Parameters,
|
||||
log_prefix="[LEARNER] Sending parameters",
|
||||
silent=True,
|
||||
)
|
||||
|
||||
logging.info("[LEARNER] Parameters sent")
|
||||
|
||||
self.shutdown_event.wait(self.seconds_between_pushes)
|
||||
|
||||
logging.info("[LEARNER] Stream parameters finished")
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
def SendTransitions(self, request_iterator, _context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to receive transitions from the Actor")
|
||||
|
||||
receive_bytes_in_chunks(
|
||||
request_iterator,
|
||||
self.transition_queue,
|
||||
self.shutdown_event,
|
||||
log_prefix="[LEARNER] transitions",
|
||||
)
|
||||
|
||||
logging.debug("[LEARNER] Finished receiving transitions")
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
def SendInteractions(self, request_iterator, _context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to receive interactions from the Actor")
|
||||
|
||||
receive_bytes_in_chunks(
|
||||
request_iterator,
|
||||
self.interaction_message_queue,
|
||||
self.shutdown_event,
|
||||
log_prefix="[LEARNER] interactions",
|
||||
)
|
||||
|
||||
logging.debug("[LEARNER] Finished receiving interactions")
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
def Ready(self, request, context): # noqa: N802
|
||||
return hilserl_pb2.Empty()
|
||||
142
lerobot/scripts/server/network_utils.py
Normal file
142
lerobot/scripts/server/network_utils.py
Normal file
@@ -0,0 +1,142 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import logging
|
||||
import pickle # nosec B403: Safe usage for internal serialization only
|
||||
from multiprocessing import Event, Queue
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.scripts.server import hilserl_pb2
|
||||
from lerobot.scripts.server.utils import Transition
|
||||
|
||||
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
|
||||
|
||||
|
||||
def bytes_buffer_size(buffer: io.BytesIO) -> int:
|
||||
buffer.seek(0, io.SEEK_END)
|
||||
result = buffer.tell()
|
||||
buffer.seek(0)
|
||||
return result
|
||||
|
||||
|
||||
def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True):
|
||||
buffer = io.BytesIO(buffer)
|
||||
size_in_bytes = bytes_buffer_size(buffer)
|
||||
|
||||
sent_bytes = 0
|
||||
|
||||
logging_method = logging.info if not silent else logging.debug
|
||||
|
||||
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
|
||||
|
||||
while sent_bytes < size_in_bytes:
|
||||
transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
|
||||
|
||||
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
|
||||
transfer_state = hilserl_pb2.TransferState.TRANSFER_END
|
||||
elif sent_bytes == 0:
|
||||
transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
|
||||
|
||||
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
|
||||
chunk = buffer.read(size_to_read)
|
||||
|
||||
yield message_class(transfer_state=transfer_state, data=chunk)
|
||||
sent_bytes += size_to_read
|
||||
logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")
|
||||
|
||||
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
|
||||
|
||||
|
||||
def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore
|
||||
bytes_buffer = io.BytesIO()
|
||||
step = 0
|
||||
|
||||
logging.info(f"{log_prefix} Starting receiver")
|
||||
for item in iterator:
|
||||
logging.debug(f"{log_prefix} Received item")
|
||||
if shutdown_event.is_set():
|
||||
logging.info(f"{log_prefix} Shutting down receiver")
|
||||
return
|
||||
|
||||
if item.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
|
||||
bytes_buffer.seek(0)
|
||||
bytes_buffer.truncate(0)
|
||||
bytes_buffer.write(item.data)
|
||||
logging.debug(f"{log_prefix} Received data at step 0")
|
||||
step = 0
|
||||
continue
|
||||
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE:
|
||||
bytes_buffer.write(item.data)
|
||||
step += 1
|
||||
logging.debug(f"{log_prefix} Received data at step {step}")
|
||||
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
|
||||
bytes_buffer.write(item.data)
|
||||
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
|
||||
|
||||
queue.put(bytes_buffer.getvalue())
|
||||
|
||||
bytes_buffer.seek(0)
|
||||
bytes_buffer.truncate(0)
|
||||
step = 0
|
||||
|
||||
logging.debug(f"{log_prefix} Queue updated")
|
||||
|
||||
|
||||
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
|
||||
"""Convert model state dict to flat array for transmission"""
|
||||
buffer = io.BytesIO()
|
||||
|
||||
torch.save(state_dict, buffer)
|
||||
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
|
||||
buffer = io.BytesIO(buffer)
|
||||
buffer.seek(0)
|
||||
return torch.load(buffer, weights_only=False) # nosec B614: Using weights_only=False relies on pickle which has security implications.
|
||||
# This is currently safe as we only deserialize trusted internal data.
|
||||
# TODO: Verify if weights_only=True would work for our use case (safer default in torch 2.6+)
|
||||
|
||||
|
||||
def python_object_to_bytes(python_object: Any) -> bytes:
|
||||
return pickle.dumps(python_object)
|
||||
|
||||
|
||||
def bytes_to_python_object(buffer: bytes) -> Any:
|
||||
buffer = io.BytesIO(buffer)
|
||||
buffer.seek(0)
|
||||
obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load
|
||||
# Add validation checks here
|
||||
return obj
|
||||
|
||||
|
||||
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
|
||||
buffer = io.BytesIO(buffer)
|
||||
buffer.seek(0)
|
||||
transitions = torch.load(buffer, weights_only=False) # nosec B614: Safe usage of torch.load
|
||||
# Add validation checks here
|
||||
return transitions
|
||||
|
||||
|
||||
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
|
||||
buffer = io.BytesIO()
|
||||
torch.save(transitions, buffer)
|
||||
return buffer.getvalue()
|
||||
141
lerobot/scripts/server/utils.py
Normal file
141
lerobot/scripts/server/utils.py
Normal file
@@ -0,0 +1,141 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
from queue import Empty
|
||||
from typing import TypedDict
|
||||
|
||||
import torch
|
||||
from torch.multiprocessing import Queue
|
||||
|
||||
shutdown_event_counter = 0
|
||||
|
||||
|
||||
def setup_process_handlers(use_threads: bool) -> any:
|
||||
if use_threads:
|
||||
from threading import Event
|
||||
else:
|
||||
from multiprocessing import Event
|
||||
|
||||
shutdown_event = Event()
|
||||
|
||||
# Define signal handler
|
||||
def signal_handler(signum, frame):
|
||||
logging.info("Shutdown signal received. Cleaning up...")
|
||||
shutdown_event.set()
|
||||
global shutdown_event_counter
|
||||
shutdown_event_counter += 1
|
||||
|
||||
if shutdown_event_counter > 1:
|
||||
logging.info("Force shutdown")
|
||||
sys.exit(1)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
||||
signal.signal(signal.SIGTERM, signal_handler) # Termination request (kill)
|
||||
signal.signal(signal.SIGHUP, signal_handler) # Terminal closed/Hangup
|
||||
signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
logging.info("Shutdown signal received. Cleaning up...")
|
||||
shutdown_event.set()
|
||||
|
||||
return shutdown_event
|
||||
|
||||
|
||||
def get_last_item_from_queue(queue: Queue):
|
||||
item = queue.get()
|
||||
counter = 1
|
||||
|
||||
# Drain queue and keep only the most recent parameters
|
||||
try:
|
||||
while True:
|
||||
item = queue.get_nowait()
|
||||
counter += 1
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
logging.debug(f"Drained {counter} items from queue")
|
||||
|
||||
return item
|
||||
|
||||
|
||||
class Transition(TypedDict):
|
||||
state: dict[str, torch.Tensor]
|
||||
action: torch.Tensor
|
||||
reward: float
|
||||
next_state: dict[str, torch.Tensor]
|
||||
done: bool
|
||||
truncated: bool
|
||||
complementary_info: dict[str, torch.Tensor | float | int] | None = None
|
||||
|
||||
|
||||
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
|
||||
device = torch.device(device)
|
||||
non_blocking = device.type == "cuda"
|
||||
|
||||
# Move state tensors to device
|
||||
transition["state"] = {
|
||||
key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
|
||||
}
|
||||
|
||||
# Move action to device
|
||||
transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
|
||||
|
||||
# Move reward and done if they are tensors
|
||||
if isinstance(transition["reward"], torch.Tensor):
|
||||
transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)
|
||||
|
||||
if isinstance(transition["done"], torch.Tensor):
|
||||
transition["done"] = transition["done"].to(device, non_blocking=non_blocking)
|
||||
|
||||
if isinstance(transition["truncated"], torch.Tensor):
|
||||
transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)
|
||||
|
||||
# Move next_state tensors to device
|
||||
transition["next_state"] = {
|
||||
key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
|
||||
}
|
||||
|
||||
# Move complementary_info tensors if present
|
||||
if transition.get("complementary_info") is not None:
|
||||
for key, val in transition["complementary_info"].items():
|
||||
if isinstance(val, torch.Tensor):
|
||||
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
|
||||
elif isinstance(val, (int, float, bool)):
|
||||
transition["complementary_info"][key] = torch.tensor(val, device=device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
|
||||
return transition
|
||||
|
||||
|
||||
def move_state_dict_to_device(state_dict, device="cpu"):
|
||||
"""
|
||||
Recursively move all tensors in a (potentially) nested
|
||||
dict/list/tuple structure to the CPU.
|
||||
"""
|
||||
if isinstance(state_dict, torch.Tensor):
|
||||
return state_dict.to(device)
|
||||
elif isinstance(state_dict, dict):
|
||||
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
|
||||
elif isinstance(state_dict, list):
|
||||
return [move_state_dict_to_device(v, device=device) for v in state_dict]
|
||||
elif isinstance(state_dict, tuple):
|
||||
return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
|
||||
else:
|
||||
return state_dict
|
||||
@@ -84,6 +84,7 @@ dora = [
|
||||
]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
||||
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
||||
hilserl = ["transformers>=4.48", "gym-hil>=0.1.3", "protobuf>=5.29.3", "grpcio>=1.70.0"]
|
||||
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
||||
pi0 = ["transformers>=4.48.0"]
|
||||
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
|
||||
@@ -98,13 +99,40 @@ umi = ["imagecodecs>=2024.1.1"]
|
||||
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
|
||||
xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]
|
||||
|
||||
|
||||
[tool.poetry]
|
||||
requires-poetry = ">=2.1"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 110
|
||||
target-version = "py310"
|
||||
exclude = ["tests/artifacts/**/*.safetensors"]
|
||||
exclude = [
|
||||
"tests/data",
|
||||
".bzr",
|
||||
".direnv",
|
||||
".eggs",
|
||||
".git",
|
||||
".git-rewrite",
|
||||
".hg",
|
||||
".mypy_cache",
|
||||
".nox",
|
||||
".pants.d",
|
||||
".pytype",
|
||||
".ruff_cache",
|
||||
".svn",
|
||||
".tox",
|
||||
".venv",
|
||||
"__pypackages__",
|
||||
"_build",
|
||||
"buck-out",
|
||||
"build",
|
||||
"dist",
|
||||
"node_modules",
|
||||
"venv",
|
||||
"*_pb2.py",
|
||||
"*_pb2_grpc.py",
|
||||
]
|
||||
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
||||
|
||||
@@ -14,9 +14,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import traceback
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from serial import SerialException
|
||||
|
||||
from lerobot import available_cameras, available_motors, available_robots
|
||||
@@ -86,3 +88,19 @@ def patch_builtins_input(monkeypatch):
|
||||
print(text)
|
||||
|
||||
monkeypatch.setattr("builtins.input", print_text)
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--seed",
|
||||
action="store",
|
||||
default="42",
|
||||
help="Set random seed for reproducibility",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_random_seed(request):
|
||||
seed = int(request.config.getoption("--seed"))
|
||||
random.seed(seed) # Python random
|
||||
torch.manual_seed(seed) # PyTorch
|
||||
|
||||
@@ -21,6 +21,7 @@ from lerobot.common.constants import (
|
||||
from lerobot.common.optim.optimizers import (
|
||||
AdamConfig,
|
||||
AdamWConfig,
|
||||
MultiAdamConfig,
|
||||
SGDConfig,
|
||||
load_optimizer_state,
|
||||
save_optimizer_state,
|
||||
@@ -33,13 +34,21 @@ from lerobot.common.optim.optimizers import (
|
||||
(AdamConfig, torch.optim.Adam),
|
||||
(AdamWConfig, torch.optim.AdamW),
|
||||
(SGDConfig, torch.optim.SGD),
|
||||
(MultiAdamConfig, dict),
|
||||
],
|
||||
)
|
||||
def test_optimizer_build(config_cls, expected_class, model_params):
|
||||
config = config_cls()
|
||||
optimizer = config.build(model_params)
|
||||
assert isinstance(optimizer, expected_class)
|
||||
assert optimizer.defaults["lr"] == config.lr
|
||||
if config_cls == MultiAdamConfig:
|
||||
params_dict = {"default": model_params}
|
||||
optimizer = config.build(params_dict)
|
||||
assert isinstance(optimizer, expected_class)
|
||||
assert isinstance(optimizer["default"], torch.optim.Adam)
|
||||
assert optimizer["default"].defaults["lr"] == config.lr
|
||||
else:
|
||||
optimizer = config.build(model_params)
|
||||
assert isinstance(optimizer, expected_class)
|
||||
assert optimizer.defaults["lr"] == config.lr
|
||||
|
||||
|
||||
def test_save_optimizer_state(optimizer, tmp_path):
|
||||
@@ -54,3 +63,180 @@ def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path):
|
||||
loaded_optimizer = load_optimizer_state(loaded_optimizer, tmp_path)
|
||||
|
||||
torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_params_dict():
|
||||
return {
|
||||
"actor": [torch.nn.Parameter(torch.randn(10, 10))],
|
||||
"critic": [torch.nn.Parameter(torch.randn(5, 5))],
|
||||
"temperature": [torch.nn.Parameter(torch.randn(3, 3))],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_params, expected_values",
|
||||
[
|
||||
# Test 1: Basic configuration with different learning rates
|
||||
(
|
||||
{
|
||||
"lr": 1e-3,
|
||||
"weight_decay": 1e-4,
|
||||
"optimizer_groups": {
|
||||
"actor": {"lr": 1e-4},
|
||||
"critic": {"lr": 5e-4},
|
||||
"temperature": {"lr": 2e-3},
|
||||
},
|
||||
},
|
||||
{
|
||||
"actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
|
||||
"critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
|
||||
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
|
||||
},
|
||||
),
|
||||
# Test 2: Different weight decays and beta values
|
||||
(
|
||||
{
|
||||
"lr": 1e-3,
|
||||
"weight_decay": 1e-4,
|
||||
"optimizer_groups": {
|
||||
"actor": {"lr": 1e-4, "weight_decay": 1e-5},
|
||||
"critic": {"lr": 5e-4, "weight_decay": 1e-6},
|
||||
"temperature": {"lr": 2e-3, "betas": (0.95, 0.999)},
|
||||
},
|
||||
},
|
||||
{
|
||||
"actor": {"lr": 1e-4, "weight_decay": 1e-5, "betas": (0.9, 0.999)},
|
||||
"critic": {"lr": 5e-4, "weight_decay": 1e-6, "betas": (0.9, 0.999)},
|
||||
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.95, 0.999)},
|
||||
},
|
||||
),
|
||||
# Test 3: Epsilon parameter customization
|
||||
(
|
||||
{
|
||||
"lr": 1e-3,
|
||||
"weight_decay": 1e-4,
|
||||
"optimizer_groups": {
|
||||
"actor": {"lr": 1e-4, "eps": 1e-6},
|
||||
"critic": {"lr": 5e-4, "eps": 1e-7},
|
||||
"temperature": {"lr": 2e-3, "eps": 1e-8},
|
||||
},
|
||||
},
|
||||
{
|
||||
"actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-6},
|
||||
"critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-7},
|
||||
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-8},
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_multi_adam_configuration(base_params_dict, config_params, expected_values):
|
||||
# Create config with the given parameters
|
||||
config = MultiAdamConfig(**config_params)
|
||||
optimizers = config.build(base_params_dict)
|
||||
|
||||
# Verify optimizer count and keys
|
||||
assert len(optimizers) == len(expected_values)
|
||||
assert set(optimizers.keys()) == set(expected_values.keys())
|
||||
|
||||
# Check that all optimizers are Adam instances
|
||||
for opt in optimizers.values():
|
||||
assert isinstance(opt, torch.optim.Adam)
|
||||
|
||||
# Verify hyperparameters for each optimizer
|
||||
for name, expected in expected_values.items():
|
||||
optimizer = optimizers[name]
|
||||
for param, value in expected.items():
|
||||
assert optimizer.defaults[param] == value
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def multi_optimizers(base_params_dict):
|
||||
config = MultiAdamConfig(
|
||||
lr=1e-3,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": 1e-4},
|
||||
"critic": {"lr": 5e-4},
|
||||
"temperature": {"lr": 2e-3},
|
||||
},
|
||||
)
|
||||
return config.build(base_params_dict)
|
||||
|
||||
|
||||
def test_save_multi_optimizer_state(multi_optimizers, tmp_path):
|
||||
# Save optimizer states
|
||||
save_optimizer_state(multi_optimizers, tmp_path)
|
||||
|
||||
# Verify that directories were created for each optimizer
|
||||
for name in multi_optimizers:
|
||||
assert (tmp_path / name).is_dir()
|
||||
assert (tmp_path / name / OPTIMIZER_STATE).is_file()
|
||||
assert (tmp_path / name / OPTIMIZER_PARAM_GROUPS).is_file()
|
||||
|
||||
|
||||
def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers, tmp_path):
|
||||
# Option 1: Add a minimal backward pass to populate optimizer states
|
||||
for name, params in base_params_dict.items():
|
||||
if name in multi_optimizers:
|
||||
# Create a dummy loss and do backward
|
||||
dummy_loss = params[0].sum()
|
||||
dummy_loss.backward()
|
||||
# Perform an optimization step
|
||||
multi_optimizers[name].step()
|
||||
# Zero gradients for next steps
|
||||
multi_optimizers[name].zero_grad()
|
||||
|
||||
# Save optimizer states
|
||||
save_optimizer_state(multi_optimizers, tmp_path)
|
||||
|
||||
# Create new optimizers with the same config
|
||||
config = MultiAdamConfig(
|
||||
lr=1e-3,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": 1e-4},
|
||||
"critic": {"lr": 5e-4},
|
||||
"temperature": {"lr": 2e-3},
|
||||
},
|
||||
)
|
||||
new_optimizers = config.build(base_params_dict)
|
||||
|
||||
# Load optimizer states
|
||||
loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
|
||||
|
||||
# Verify state dictionaries match
|
||||
for name in multi_optimizers:
|
||||
torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict())
|
||||
|
||||
|
||||
def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
|
||||
"""Test saving and loading optimizer states even when the state is empty (no backward pass)."""
|
||||
# Create config and build optimizers
|
||||
config = MultiAdamConfig(
|
||||
lr=1e-3,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": 1e-4},
|
||||
"critic": {"lr": 5e-4},
|
||||
"temperature": {"lr": 2e-3},
|
||||
},
|
||||
)
|
||||
optimizers = config.build(base_params_dict)
|
||||
|
||||
# Save optimizer states without any backward pass (empty state)
|
||||
save_optimizer_state(optimizers, tmp_path)
|
||||
|
||||
# Create new optimizers with the same config
|
||||
new_optimizers = config.build(base_params_dict)
|
||||
|
||||
# Load optimizer states
|
||||
loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
|
||||
|
||||
# Verify hyperparameters match even with empty state
|
||||
for name, optimizer in optimizers.items():
|
||||
assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"]
|
||||
assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"]
|
||||
assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"]
|
||||
|
||||
# Verify state dictionaries match (they will be empty)
|
||||
torch.testing.assert_close(
|
||||
optimizer.state_dict()["param_groups"], loaded_optimizers[name].state_dict()["param_groups"]
|
||||
)
|
||||
|
||||
123
tests/policies/hilserl/test_modeling_classifier.py
Normal file
123
tests/policies/hilserl/test_modeling_classifier.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import torch
|
||||
|
||||
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import ClassifierOutput
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
def test_classifier_output():
|
||||
output = ClassifierOutput(
|
||||
logits=torch.tensor([1, 2, 3]),
|
||||
probabilities=torch.tensor([0.1, 0.2, 0.3]),
|
||||
hidden_states=None,
|
||||
)
|
||||
|
||||
assert (
|
||||
f"{output}"
|
||||
== "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)"
|
||||
)
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_binary_classifier_with_default_params():
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig()
|
||||
config.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"REWARD": NormalizationMode.IDENTITY,
|
||||
}
|
||||
config.num_cameras = 1
|
||||
classifier = Classifier(config)
|
||||
|
||||
batch_size = 10
|
||||
|
||||
input = {
|
||||
"observation.image": torch.rand((batch_size, 3, 128, 128)),
|
||||
"next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
|
||||
}
|
||||
|
||||
images, labels = classifier.extract_images_and_labels(input)
|
||||
assert len(images) == 1
|
||||
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
|
||||
assert labels.shape == torch.Size([batch_size])
|
||||
|
||||
output = classifier.predict(images)
|
||||
|
||||
assert output is not None
|
||||
assert output.logits.size() == torch.Size([batch_size])
|
||||
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
||||
assert output.probabilities.shape == torch.Size([batch_size])
|
||||
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
||||
assert output.hidden_states.shape == torch.Size([batch_size, 256])
|
||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_multiclass_classifier():
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
num_classes = 5
|
||||
config = RewardClassifierConfig()
|
||||
config.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
|
||||
}
|
||||
config.num_cameras = 1
|
||||
config.num_classes = num_classes
|
||||
classifier = Classifier(config)
|
||||
|
||||
batch_size = 10
|
||||
|
||||
input = {
|
||||
"observation.image": torch.rand((batch_size, 3, 128, 128)),
|
||||
"next.reward": torch.rand((batch_size, num_classes)),
|
||||
}
|
||||
|
||||
images, labels = classifier.extract_images_and_labels(input)
|
||||
assert len(images) == 1
|
||||
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
|
||||
assert labels.shape == torch.Size([batch_size, num_classes])
|
||||
|
||||
output = classifier.predict(images)
|
||||
|
||||
assert output is not None
|
||||
assert output.logits.shape == torch.Size([batch_size, num_classes])
|
||||
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
||||
assert output.probabilities.shape == torch.Size([batch_size, num_classes])
|
||||
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
||||
assert output.hidden_states.shape == torch.Size([batch_size, 256])
|
||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_default_device():
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig()
|
||||
assert config.device == "cpu"
|
||||
|
||||
classifier = Classifier(config)
|
||||
for p in classifier.parameters():
|
||||
assert p.device == torch.device("cpu")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_explicit_device_setup():
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig(device="cpu")
|
||||
assert config.device == "cpu"
|
||||
|
||||
classifier = Classifier(config)
|
||||
for p in classifier.parameters():
|
||||
assert p.device == torch.device("cpu")
|
||||
217
tests/policies/test_sac_config.py
Normal file
217
tests/policies/test_sac_config.py
Normal file
@@ -0,0 +1,217 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.policies.sac.configuration_sac import (
|
||||
ActorLearnerConfig,
|
||||
ActorNetworkConfig,
|
||||
ConcurrencyConfig,
|
||||
CriticNetworkConfig,
|
||||
PolicyConfig,
|
||||
SACConfig,
|
||||
)
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
def test_sac_config_default_initialization():
|
||||
config = SACConfig()
|
||||
|
||||
assert config.normalization_mapping == {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ENV": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
assert config.dataset_stats == {
|
||||
"observation.image": {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
"observation.state": {
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
"action": {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
}
|
||||
|
||||
# Basic parameters
|
||||
assert config.device == "cpu"
|
||||
assert config.storage_device == "cpu"
|
||||
assert config.discount == 0.99
|
||||
assert config.temperature_init == 1.0
|
||||
assert config.num_critics == 2
|
||||
|
||||
# Architecture specifics
|
||||
assert config.vision_encoder_name is None
|
||||
assert config.freeze_vision_encoder is True
|
||||
assert config.image_encoder_hidden_dim == 32
|
||||
assert config.shared_encoder is True
|
||||
assert config.num_discrete_actions is None
|
||||
assert config.image_embedding_pooling_dim == 8
|
||||
|
||||
# Training parameters
|
||||
assert config.online_steps == 1000000
|
||||
assert config.online_env_seed == 10000
|
||||
assert config.online_buffer_capacity == 100000
|
||||
assert config.offline_buffer_capacity == 100000
|
||||
assert config.async_prefetch is False
|
||||
assert config.online_step_before_learning == 100
|
||||
assert config.policy_update_freq == 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
assert config.num_subsample_critics is None
|
||||
assert config.critic_lr == 3e-4
|
||||
assert config.actor_lr == 3e-4
|
||||
assert config.temperature_lr == 3e-4
|
||||
assert config.critic_target_update_weight == 0.005
|
||||
assert config.utd_ratio == 1
|
||||
assert config.state_encoder_hidden_dim == 256
|
||||
assert config.latent_dim == 256
|
||||
assert config.target_entropy is None
|
||||
assert config.use_backup_entropy is True
|
||||
assert config.grad_clip_norm == 40.0
|
||||
|
||||
# Dataset stats defaults
|
||||
expected_dataset_stats = {
|
||||
"observation.image": {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
"observation.state": {
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
"action": {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
}
|
||||
assert config.dataset_stats == expected_dataset_stats
|
||||
|
||||
# Critic network configuration
|
||||
assert config.critic_network_kwargs.hidden_dims == [256, 256]
|
||||
assert config.critic_network_kwargs.activate_final is True
|
||||
assert config.critic_network_kwargs.final_activation is None
|
||||
|
||||
# Actor network configuration
|
||||
assert config.actor_network_kwargs.hidden_dims == [256, 256]
|
||||
assert config.actor_network_kwargs.activate_final is True
|
||||
|
||||
# Policy configuration
|
||||
assert config.policy_kwargs.use_tanh_squash is True
|
||||
assert config.policy_kwargs.log_std_min == 1e-5
|
||||
assert config.policy_kwargs.log_std_max == 10.0
|
||||
assert config.policy_kwargs.init_final == 0.05
|
||||
|
||||
# Discrete critic network configuration
|
||||
assert config.discrete_critic_network_kwargs.hidden_dims == [256, 256]
|
||||
assert config.discrete_critic_network_kwargs.activate_final is True
|
||||
assert config.discrete_critic_network_kwargs.final_activation is None
|
||||
|
||||
# Actor learner configuration
|
||||
assert config.actor_learner_config.learner_host == "127.0.0.1"
|
||||
assert config.actor_learner_config.learner_port == 50051
|
||||
assert config.actor_learner_config.policy_parameters_push_frequency == 4
|
||||
|
||||
# Concurrency configuration
|
||||
assert config.concurrency.actor == "threads"
|
||||
assert config.concurrency.learner == "threads"
|
||||
|
||||
assert isinstance(config.actor_network_kwargs, ActorNetworkConfig)
|
||||
assert isinstance(config.critic_network_kwargs, CriticNetworkConfig)
|
||||
assert isinstance(config.policy_kwargs, PolicyConfig)
|
||||
assert isinstance(config.actor_learner_config, ActorLearnerConfig)
|
||||
assert isinstance(config.concurrency, ConcurrencyConfig)
|
||||
|
||||
|
||||
def test_critic_network_kwargs():
|
||||
config = CriticNetworkConfig()
|
||||
assert config.hidden_dims == [256, 256]
|
||||
assert config.activate_final is True
|
||||
assert config.final_activation is None
|
||||
|
||||
|
||||
def test_actor_network_kwargs():
|
||||
config = ActorNetworkConfig()
|
||||
assert config.hidden_dims == [256, 256]
|
||||
assert config.activate_final is True
|
||||
|
||||
|
||||
def test_policy_kwargs():
|
||||
config = PolicyConfig()
|
||||
assert config.use_tanh_squash is True
|
||||
assert config.log_std_min == 1e-5
|
||||
assert config.log_std_max == 10.0
|
||||
assert config.init_final == 0.05
|
||||
|
||||
|
||||
def test_actor_learner_config():
|
||||
config = ActorLearnerConfig()
|
||||
assert config.learner_host == "127.0.0.1"
|
||||
assert config.learner_port == 50051
|
||||
assert config.policy_parameters_push_frequency == 4
|
||||
|
||||
|
||||
def test_concurrency_config():
|
||||
config = ConcurrencyConfig()
|
||||
assert config.actor == "threads"
|
||||
assert config.learner == "threads"
|
||||
|
||||
|
||||
def test_sac_config_custom_initialization():
|
||||
config = SACConfig(
|
||||
device="cpu",
|
||||
discount=0.95,
|
||||
temperature_init=0.5,
|
||||
num_critics=3,
|
||||
)
|
||||
|
||||
assert config.device == "cpu"
|
||||
assert config.discount == 0.95
|
||||
assert config.temperature_init == 0.5
|
||||
assert config.num_critics == 3
|
||||
|
||||
|
||||
def test_validate_features():
|
||||
config = SACConfig(
|
||||
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_validate_features_missing_observation():
|
||||
config = SACConfig(
|
||||
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError, match="You must provide either 'observation.state' or an image observation"
|
||||
):
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_validate_features_missing_action():
|
||||
config = SACConfig(
|
||||
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
with pytest.raises(ValueError, match="You must provide 'action' in the output features"):
|
||||
config.validate_features()
|
||||
519
tests/policies/test_sac_policy.py
Normal file
519
tests/policies/test_sac_policy.py
Normal file
@@ -0,0 +1,519 @@
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.policies.sac.modeling_sac import MLP, SACPolicy
|
||||
from lerobot.common.utils.random_utils import seeded_context
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
try:
|
||||
import transformers # noqa: F401
|
||||
|
||||
TRANSFORMERS_AVAILABLE = True
|
||||
except ImportError:
|
||||
TRANSFORMERS_AVAILABLE = False
|
||||
|
||||
|
||||
def test_mlp_with_default_args():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
|
||||
|
||||
x = torch.randn(10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (256,)
|
||||
|
||||
|
||||
def test_mlp_with_batch_dim():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
|
||||
x = torch.randn(2, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (2, 256)
|
||||
|
||||
|
||||
def test_forward_with_empty_hidden_dims():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[])
|
||||
x = torch.randn(1, 10)
|
||||
assert mlp(x).shape == (1, 10)
|
||||
|
||||
|
||||
def test_mlp_with_dropout():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1)
|
||||
x = torch.randn(1, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (1, 11)
|
||||
|
||||
drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net)
|
||||
assert drop_out_layers_count == 2
|
||||
|
||||
|
||||
def test_mlp_with_custom_final_activation():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh())
|
||||
x = torch.randn(1, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (1, 256)
|
||||
assert (y >= -1).all() and (y <= 1).all()
|
||||
|
||||
|
||||
def test_sac_policy_with_default_args():
|
||||
with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"):
|
||||
SACPolicy()
|
||||
|
||||
|
||||
def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor:
|
||||
return {
|
||||
"observation.image": torch.randn(batch_size, 3, 84, 84),
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor:
|
||||
return torch.randn(batch_size, action_dim)
|
||||
|
||||
|
||||
def create_default_train_batch(
|
||||
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
|
||||
) -> dict[str, Tensor]:
|
||||
return {
|
||||
"action": create_dummy_action(batch_size, action_dim),
|
||||
"reward": torch.randn(batch_size),
|
||||
"state": create_dummy_state(batch_size, state_dim),
|
||||
"next_state": create_dummy_state(batch_size, state_dim),
|
||||
"done": torch.randn(batch_size),
|
||||
}
|
||||
|
||||
|
||||
def create_train_batch_with_visual_input(
|
||||
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
|
||||
) -> dict[str, Tensor]:
|
||||
return {
|
||||
"action": create_dummy_action(batch_size, action_dim),
|
||||
"reward": torch.randn(batch_size),
|
||||
"state": create_dummy_with_visual_input(batch_size, state_dim),
|
||||
"next_state": create_dummy_with_visual_input(batch_size, state_dim),
|
||||
"done": torch.randn(batch_size),
|
||||
}
|
||||
|
||||
|
||||
def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
"observation.image": torch.randn(batch_size, 3, 84, 84),
|
||||
}
|
||||
|
||||
|
||||
def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]:
|
||||
"""Create optimizers for the SAC policy."""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
# Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient
|
||||
params=[
|
||||
p
|
||||
for n, p in policy.actor.named_parameters()
|
||||
if not policy.config.shared_encoder or not n.startswith("encoder")
|
||||
],
|
||||
lr=policy.config.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters(),
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(
|
||||
params=[policy.log_alpha],
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
|
||||
if has_discrete_action:
|
||||
optimizers["discrete_critic"] = torch.optim.Adam(
|
||||
params=policy.discrete_critic.parameters(),
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
|
||||
return optimizers
|
||||
|
||||
|
||||
def create_default_config(
|
||||
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
|
||||
) -> SACConfig:
|
||||
action_dim = continuous_action_dim
|
||||
if has_discrete_action:
|
||||
action_dim += 1
|
||||
|
||||
config = SACConfig(
|
||||
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
|
||||
dataset_stats={
|
||||
"observation.state": {
|
||||
"min": [0.0] * state_dim,
|
||||
"max": [1.0] * state_dim,
|
||||
},
|
||||
"action": {
|
||||
"min": [0.0] * continuous_action_dim,
|
||||
"max": [1.0] * continuous_action_dim,
|
||||
},
|
||||
},
|
||||
)
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
def create_config_with_visual_input(
|
||||
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
|
||||
) -> SACConfig:
|
||||
config = create_default_config(
|
||||
state_dim=state_dim,
|
||||
continuous_action_dim=continuous_action_dim,
|
||||
has_discrete_action=has_discrete_action,
|
||||
)
|
||||
config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
|
||||
config.dataset_stats["observation.image"] = {
|
||||
"mean": torch.randn(3, 1, 1),
|
||||
"std": torch.randn(3, 1, 1),
|
||||
}
|
||||
|
||||
# Let make tests a little bit faster
|
||||
config.state_encoder_hidden_dim = 32
|
||||
config.latent_dim = 32
|
||||
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int):
|
||||
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
assert temperature_loss.item() is not None
|
||||
assert temperature_loss.shape == ()
|
||||
|
||||
temperature_loss.backward()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = SACPolicy(config=config)
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
assert temperature_loss.item() is not None
|
||||
assert temperature_loss.shape == ()
|
||||
|
||||
temperature_loss.backward()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim
|
||||
)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
# Let's check best candidates for pretrained encoders
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,state_dim,action_dim,vision_encoder_name",
|
||||
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
|
||||
)
|
||||
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
|
||||
def test_sac_policy_with_pretrained_encoder(
|
||||
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
|
||||
):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.vision_encoder_name = vision_encoder_name
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
|
||||
def test_sac_policy_with_shared_encoder():
|
||||
batch_size = 2
|
||||
action_dim = 10
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.shared_encoder = True
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
|
||||
def test_sac_policy_with_discrete_critic():
|
||||
batch_size = 2
|
||||
continuous_action_dim = 9
|
||||
full_action_dim = continuous_action_dim + 1 # the last action is discrete
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(
|
||||
state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True
|
||||
)
|
||||
|
||||
num_discrete_actions = 5
|
||||
config.num_discrete_actions = num_discrete_actions
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim
|
||||
)
|
||||
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy, has_discrete_action=True)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"]
|
||||
assert discrete_critic_loss.item() is not None
|
||||
assert discrete_critic_loss.shape == ()
|
||||
discrete_critic_loss.backward()
|
||||
optimizers["discrete_critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim
|
||||
)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, full_action_dim)
|
||||
|
||||
discrete_actions = selected_action[:, -1].long()
|
||||
discrete_action_values = set(discrete_actions.tolist())
|
||||
|
||||
assert all(action in range(num_discrete_actions) for action in discrete_action_values), (
|
||||
f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})"
|
||||
)
|
||||
|
||||
|
||||
def test_sac_policy_with_default_entropy():
|
||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == -5.0
|
||||
|
||||
|
||||
def test_sac_policy_default_target_entropy_with_discrete_action():
|
||||
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == -3.0
|
||||
|
||||
|
||||
def test_sac_policy_with_predefined_entropy():
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.target_entropy = -3.5
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == pytest.approx(-3.5)
|
||||
|
||||
|
||||
def test_sac_policy_update_temperature():
|
||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||
policy = SACPolicy(config=config)
|
||||
|
||||
assert policy.temperature == pytest.approx(1.0)
|
||||
policy.log_alpha.data = torch.tensor([math.log(0.1)])
|
||||
policy.update_temperature()
|
||||
assert policy.temperature == pytest.approx(0.1)
|
||||
|
||||
|
||||
def test_sac_policy_update_target_network():
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.critic_target_update_weight = 1.0
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
for p in policy.critic_ensemble.parameters():
|
||||
p.data = torch.ones_like(p.data)
|
||||
|
||||
policy.update_target_networks()
|
||||
for p in policy.critic_target.parameters():
|
||||
assert torch.allclose(p.data, torch.ones_like(p.data)), (
|
||||
f"Target network {p.data} is not equal to {torch.ones_like(p.data)}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_critics", [1, 3])
|
||||
def test_sac_policy_with_critics_number_of_heads(num_critics: int):
|
||||
batch_size = 2
|
||||
action_dim = 10
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.num_critics = num_critics
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
assert len(policy.critic_ensemble.critics) == num_critics
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
|
||||
def test_sac_policy_save_and_load(tmp_path):
|
||||
root = tmp_path / "test_sac_save_and_load"
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 10
|
||||
batch_size = 2
|
||||
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = SACPolicy(config=config)
|
||||
policy.eval()
|
||||
policy.save_pretrained(root)
|
||||
loaded_policy = SACPolicy.from_pretrained(root, config=config)
|
||||
loaded_policy.eval()
|
||||
|
||||
batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10)
|
||||
|
||||
with torch.no_grad():
|
||||
with seeded_context(12):
|
||||
# Collect policy values before saving
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
actions = policy.select_action(observation_batch)
|
||||
|
||||
with seeded_context(12):
|
||||
# Collect policy values after loading
|
||||
loaded_cirtic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"]
|
||||
loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"]
|
||||
loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
|
||||
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
|
||||
|
||||
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
||||
for k in policy.state_dict():
|
||||
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
|
||||
|
||||
# Compare values before and after saving and loading
|
||||
# They should be the same
|
||||
assert torch.allclose(cirtic_loss, loaded_cirtic_loss)
|
||||
assert torch.allclose(actor_loss, loaded_actor_loss)
|
||||
assert torch.allclose(temperature_loss, loaded_temperature_loss)
|
||||
assert torch.allclose(actions, loaded_actions)
|
||||
599
tests/server/test_replay_buffer.py
Normal file
599
tests/server/test_replay_buffer.py
Normal file
@@ -0,0 +1,599 @@
|
||||
import sys
|
||||
from typing import Callable, Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.scripts.server.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def state_dims() -> list[str]:
|
||||
return ["observation.image", "observation.state"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def replay_buffer() -> ReplayBuffer:
|
||||
return create_empty_replay_buffer()
|
||||
|
||||
|
||||
def clone_state(state: dict) -> dict:
|
||||
return {k: v.clone() for k, v in state.items()}
|
||||
|
||||
|
||||
def create_empty_replay_buffer(
|
||||
optimize_memory: bool = False,
|
||||
use_drq: bool = False,
|
||||
image_augmentation_function: Optional[Callable] = None,
|
||||
) -> ReplayBuffer:
|
||||
buffer_capacity = 10
|
||||
device = "cpu"
|
||||
return ReplayBuffer(
|
||||
buffer_capacity,
|
||||
device,
|
||||
state_dims(),
|
||||
optimize_memory=optimize_memory,
|
||||
use_drq=use_drq,
|
||||
image_augmentation_function=image_augmentation_function,
|
||||
)
|
||||
|
||||
|
||||
def create_random_image() -> torch.Tensor:
|
||||
return torch.rand(3, 84, 84)
|
||||
|
||||
|
||||
def create_dummy_transition() -> dict:
|
||||
return {
|
||||
"observation.image": create_random_image(),
|
||||
"action": torch.randn(4),
|
||||
"reward": torch.tensor(1.0),
|
||||
"observation.state": torch.randn(
|
||||
10,
|
||||
),
|
||||
"done": torch.tensor(False),
|
||||
"truncated": torch.tensor(False),
|
||||
"complementary_info": {},
|
||||
}
|
||||
|
||||
|
||||
def create_dataset_from_replay_buffer(tmp_path) -> tuple[LeRobotDataset, ReplayBuffer]:
|
||||
dummy_state_1 = create_dummy_state()
|
||||
dummy_action_1 = create_dummy_action()
|
||||
|
||||
dummy_state_2 = create_dummy_state()
|
||||
dummy_action_2 = create_dummy_action()
|
||||
|
||||
dummy_state_3 = create_dummy_state()
|
||||
dummy_action_3 = create_dummy_action()
|
||||
|
||||
dummy_state_4 = create_dummy_state()
|
||||
dummy_action_4 = create_dummy_action()
|
||||
|
||||
replay_buffer = create_empty_replay_buffer()
|
||||
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
|
||||
replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
|
||||
replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)
|
||||
replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True)
|
||||
|
||||
root = tmp_path / "test"
|
||||
return (replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root), replay_buffer)
|
||||
|
||||
|
||||
def create_dummy_state() -> dict:
|
||||
return {
|
||||
"observation.image": create_random_image(),
|
||||
"observation.state": torch.randn(
|
||||
10,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_tensor_memory_consumption(tensor):
|
||||
return tensor.nelement() * tensor.element_size()
|
||||
|
||||
|
||||
def get_tensors_memory_consumption(obj, visited_addresses):
|
||||
total_size = 0
|
||||
|
||||
address = id(obj)
|
||||
if address in visited_addresses:
|
||||
return 0
|
||||
|
||||
visited_addresses.add(address)
|
||||
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return get_tensor_memory_consumption(obj)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
for item in obj:
|
||||
total_size += get_tensors_memory_consumption(item, visited_addresses)
|
||||
elif isinstance(obj, dict):
|
||||
for value in obj.values():
|
||||
total_size += get_tensors_memory_consumption(value, visited_addresses)
|
||||
elif hasattr(obj, "__dict__"):
|
||||
# It's an object, we need to get the size of the attributes
|
||||
for _, attr in vars(obj).items():
|
||||
total_size += get_tensors_memory_consumption(attr, visited_addresses)
|
||||
|
||||
return total_size
|
||||
|
||||
|
||||
def get_object_memory(obj):
|
||||
# Track visited addresses to avoid infinite loops
|
||||
# and cases when two properties point to the same object
|
||||
visited_addresses = set()
|
||||
|
||||
# Get the size of the object in bytes
|
||||
total_size = sys.getsizeof(obj)
|
||||
|
||||
# Get the size of the tensor attributes
|
||||
total_size += get_tensors_memory_consumption(obj, visited_addresses)
|
||||
|
||||
return total_size
|
||||
|
||||
|
||||
def create_dummy_action() -> torch.Tensor:
|
||||
return torch.randn(4)
|
||||
|
||||
|
||||
def dict_properties() -> list:
|
||||
return ["state", "next_state"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_state() -> dict:
|
||||
return create_dummy_state()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def next_dummy_state() -> dict:
|
||||
return create_dummy_state()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_action() -> torch.Tensor:
|
||||
return torch.randn(4)
|
||||
|
||||
|
||||
def test_empty_buffer_sample_raises_error(replay_buffer):
|
||||
assert len(replay_buffer) == 0, "Replay buffer should be empty."
|
||||
assert replay_buffer.capacity == 10, "Replay buffer capacity should be 10."
|
||||
with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"):
|
||||
replay_buffer.sample(1)
|
||||
|
||||
|
||||
def test_zero_capacity_buffer_raises_error():
|
||||
with pytest.raises(ValueError, match="Capacity must be greater than 0."):
|
||||
ReplayBuffer(0, "cpu", ["observation", "next_observation"])
|
||||
|
||||
|
||||
def test_add_transition(replay_buffer, dummy_state, dummy_action):
|
||||
replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)
|
||||
assert len(replay_buffer) == 1, "Replay buffer should have one transition after adding."
|
||||
assert torch.equal(replay_buffer.actions[0], dummy_action), (
|
||||
"Action should be equal to the first transition."
|
||||
)
|
||||
assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the first transition."
|
||||
assert not replay_buffer.dones[0], "Done should be False for the first transition."
|
||||
assert not replay_buffer.truncateds[0], "Truncated should be False for the first transition."
|
||||
|
||||
for dim in state_dims():
|
||||
assert torch.equal(replay_buffer.states[dim][0], dummy_state[dim]), (
|
||||
"Observation should be equal to the first transition."
|
||||
)
|
||||
assert torch.equal(replay_buffer.next_states[dim][0], dummy_state[dim]), (
|
||||
"Next observation should be equal to the first transition."
|
||||
)
|
||||
|
||||
|
||||
def test_add_over_capacity():
|
||||
replay_buffer = ReplayBuffer(2, "cpu", ["observation", "next_observation"])
|
||||
dummy_state_1 = create_dummy_state()
|
||||
dummy_action_1 = create_dummy_action()
|
||||
|
||||
dummy_state_2 = create_dummy_state()
|
||||
dummy_action_2 = create_dummy_action()
|
||||
|
||||
dummy_state_3 = create_dummy_state()
|
||||
dummy_action_3 = create_dummy_action()
|
||||
|
||||
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
|
||||
replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
|
||||
replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)
|
||||
|
||||
assert len(replay_buffer) == 2, "Replay buffer should have 2 transitions after adding 3."
|
||||
|
||||
for dim in state_dims():
|
||||
assert torch.equal(replay_buffer.states[dim][0], dummy_state_3[dim]), (
|
||||
"Observation should be equal to the first transition."
|
||||
)
|
||||
assert torch.equal(replay_buffer.next_states[dim][0], dummy_state_3[dim]), (
|
||||
"Next observation should be equal to the first transition."
|
||||
)
|
||||
|
||||
assert torch.equal(replay_buffer.actions[0], dummy_action_3), (
|
||||
"Action should be equal to the last transition."
|
||||
)
|
||||
assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the last transition."
|
||||
assert replay_buffer.dones[0], "Done should be True for the first transition."
|
||||
assert replay_buffer.truncateds[0], "Truncated should be True for the first transition."
|
||||
|
||||
|
||||
def test_sample_from_empty_buffer(replay_buffer):
|
||||
with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"):
|
||||
replay_buffer.sample(1)
|
||||
|
||||
|
||||
def test_sample_with_1_transition(replay_buffer, dummy_state, next_dummy_state, dummy_action):
|
||||
replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False)
|
||||
got_batch_transition = replay_buffer.sample(1)
|
||||
|
||||
expected_batch_transition = BatchTransition(
|
||||
state=clone_state(dummy_state),
|
||||
action=dummy_action.clone(),
|
||||
reward=1.0,
|
||||
next_state=clone_state(next_dummy_state),
|
||||
done=False,
|
||||
truncated=False,
|
||||
)
|
||||
|
||||
for buffer_property in dict_properties():
|
||||
for k, v in expected_batch_transition[buffer_property].items():
|
||||
got_state = got_batch_transition[buffer_property][k]
|
||||
|
||||
assert got_state.shape[0] == 1, f"{k} should have 1 transition."
|
||||
assert got_state.device.type == "cpu", f"{k} should be on cpu."
|
||||
|
||||
assert torch.equal(got_state[0], v), f"{k} should be equal to the expected batch transition."
|
||||
|
||||
for key, _value in expected_batch_transition.items():
|
||||
if key in dict_properties():
|
||||
continue
|
||||
|
||||
got_value = got_batch_transition[key]
|
||||
|
||||
v_tensor = expected_batch_transition[key]
|
||||
if not isinstance(v_tensor, torch.Tensor):
|
||||
v_tensor = torch.tensor(v_tensor)
|
||||
|
||||
assert got_value.shape[0] == 1, f"{key} should have 1 transition."
|
||||
assert got_value.device.type == "cpu", f"{key} should be on cpu."
|
||||
assert torch.equal(got_value[0], v_tensor), f"{key} should be equal to the expected batch transition."
|
||||
|
||||
|
||||
def test_sample_with_batch_bigger_than_buffer_size(
|
||||
replay_buffer, dummy_state, next_dummy_state, dummy_action
|
||||
):
|
||||
replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False)
|
||||
got_batch_transition = replay_buffer.sample(10)
|
||||
|
||||
expected_batch_transition = BatchTransition(
|
||||
state=dummy_state,
|
||||
action=dummy_action,
|
||||
reward=1.0,
|
||||
next_state=next_dummy_state,
|
||||
done=False,
|
||||
truncated=False,
|
||||
)
|
||||
|
||||
for buffer_property in dict_properties():
|
||||
for k in expected_batch_transition[buffer_property]:
|
||||
got_state = got_batch_transition[buffer_property][k]
|
||||
|
||||
assert got_state.shape[0] == 1, f"{k} should have 1 transition."
|
||||
|
||||
for key in expected_batch_transition:
|
||||
if key in dict_properties():
|
||||
continue
|
||||
|
||||
got_value = got_batch_transition[key]
|
||||
assert got_value.shape[0] == 1, f"{key} should have 1 transition."
|
||||
|
||||
|
||||
def test_sample_batch(replay_buffer):
|
||||
dummy_state_1 = create_dummy_state()
|
||||
dummy_action_1 = create_dummy_action()
|
||||
|
||||
dummy_state_2 = create_dummy_state()
|
||||
dummy_action_2 = create_dummy_action()
|
||||
|
||||
dummy_state_3 = create_dummy_state()
|
||||
dummy_action_3 = create_dummy_action()
|
||||
|
||||
dummy_state_4 = create_dummy_state()
|
||||
dummy_action_4 = create_dummy_action()
|
||||
|
||||
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
|
||||
replay_buffer.add(dummy_state_2, dummy_action_2, 2.0, dummy_state_2, False, False)
|
||||
replay_buffer.add(dummy_state_3, dummy_action_3, 3.0, dummy_state_3, True, True)
|
||||
replay_buffer.add(dummy_state_4, dummy_action_4, 4.0, dummy_state_4, True, True)
|
||||
|
||||
dummy_states = [dummy_state_1, dummy_state_2, dummy_state_3, dummy_state_4]
|
||||
dummy_actions = [dummy_action_1, dummy_action_2, dummy_action_3, dummy_action_4]
|
||||
|
||||
got_batch_transition = replay_buffer.sample(3)
|
||||
|
||||
for buffer_property in dict_properties():
|
||||
for k in got_batch_transition[buffer_property]:
|
||||
got_state = got_batch_transition[buffer_property][k]
|
||||
|
||||
assert got_state.shape[0] == 3, f"{k} should have 3 transition."
|
||||
|
||||
for got_state_item in got_state:
|
||||
assert any(torch.equal(got_state_item, dummy_state[k]) for dummy_state in dummy_states), (
|
||||
f"{k} should be equal to one of the dummy states."
|
||||
)
|
||||
|
||||
for got_action_item in got_batch_transition["action"]:
|
||||
assert any(torch.equal(got_action_item, dummy_action) for dummy_action in dummy_actions), (
|
||||
"Actions should be equal to the dummy actions."
|
||||
)
|
||||
|
||||
for k in got_batch_transition:
|
||||
if k in dict_properties() or k == "complementary_info":
|
||||
continue
|
||||
|
||||
got_value = got_batch_transition[k]
|
||||
assert got_value.shape[0] == 3, f"{k} should have 3 transition."
|
||||
|
||||
|
||||
def test_to_lerobot_dataset_with_empty_buffer(replay_buffer):
|
||||
with pytest.raises(ValueError, match="The replay buffer is empty. Cannot convert to a dataset."):
|
||||
replay_buffer.to_lerobot_dataset("dummy_repo")
|
||||
|
||||
|
||||
def test_to_lerobot_dataset(tmp_path):
|
||||
ds, buffer = create_dataset_from_replay_buffer(tmp_path)
|
||||
|
||||
assert len(ds) == len(buffer), "Dataset should have the same size as the Replay Buffer"
|
||||
assert ds.fps == 1, "FPS should be 1"
|
||||
assert ds.repo_id == "dummy/repo", "The dataset should have `dummy/repo` repo id"
|
||||
|
||||
for dim in state_dims():
|
||||
assert dim in ds.features
|
||||
assert ds.features[dim]["shape"] == buffer.states[dim][0].shape
|
||||
|
||||
assert ds.num_episodes == 2
|
||||
assert ds.num_frames == 4
|
||||
|
||||
for j, value in enumerate(ds):
|
||||
print(torch.equal(value["observation.image"], buffer.next_states["observation.image"][j]))
|
||||
|
||||
for i in range(len(ds)):
|
||||
for feature, value in ds[i].items():
|
||||
if feature == "action":
|
||||
assert torch.equal(value, buffer.actions[i])
|
||||
elif feature == "next.reward":
|
||||
assert torch.equal(value, buffer.rewards[i])
|
||||
elif feature == "next.done":
|
||||
assert torch.equal(value, buffer.dones[i])
|
||||
elif feature == "observation.image":
|
||||
# Tenssor -> numpy is not precise, so we have some diff there
|
||||
# TODO: Check and fix it
|
||||
torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003)
|
||||
elif feature == "observation.state":
|
||||
assert torch.equal(value, buffer.states["observation.state"][i])
|
||||
|
||||
|
||||
def test_from_lerobot_dataset(tmp_path):
|
||||
dummy_state_1 = create_dummy_state()
|
||||
dummy_action_1 = create_dummy_action()
|
||||
|
||||
dummy_state_2 = create_dummy_state()
|
||||
dummy_action_2 = create_dummy_action()
|
||||
|
||||
dummy_state_3 = create_dummy_state()
|
||||
dummy_action_3 = create_dummy_action()
|
||||
|
||||
dummy_state_4 = create_dummy_state()
|
||||
dummy_action_4 = create_dummy_action()
|
||||
|
||||
replay_buffer = create_empty_replay_buffer()
|
||||
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
|
||||
replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
|
||||
replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)
|
||||
replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True)
|
||||
|
||||
root = tmp_path / "test"
|
||||
ds = replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root)
|
||||
|
||||
reconverted_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
ds, state_keys=list(state_dims()), device="cpu", capacity=replay_buffer.capacity, use_drq=False
|
||||
)
|
||||
|
||||
# Check only the part of the buffer that's actually filled with data
|
||||
assert torch.equal(
|
||||
reconverted_buffer.actions[: len(replay_buffer)],
|
||||
replay_buffer.actions[: len(replay_buffer)],
|
||||
), "Actions from converted buffer should be equal to the original replay buffer."
|
||||
assert torch.equal(
|
||||
reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)]
|
||||
), "Rewards from converted buffer should be equal to the original replay buffer."
|
||||
assert torch.equal(
|
||||
reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)]
|
||||
), "Dones from converted buffer should be equal to the original replay buffer."
|
||||
|
||||
# Lerobot DS haven't supported truncateds yet
|
||||
expected_truncateds = torch.zeros(len(replay_buffer)).bool()
|
||||
assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), (
|
||||
"Truncateds from converted buffer should be equal False"
|
||||
)
|
||||
|
||||
assert torch.equal(
|
||||
replay_buffer.states["observation.state"][: len(replay_buffer)],
|
||||
reconverted_buffer.states["observation.state"][: len(replay_buffer)],
|
||||
), "State should be the same after converting to dataset and return back"
|
||||
|
||||
for i in range(4):
|
||||
torch.testing.assert_close(
|
||||
replay_buffer.states["observation.image"][i],
|
||||
reconverted_buffer.states["observation.image"][i],
|
||||
rtol=0.4,
|
||||
atol=0.004,
|
||||
)
|
||||
|
||||
# The 2, 3 frames have done flag, so their values will be equal to the current state
|
||||
for i in range(2):
|
||||
# In the current implementation we take the next state from the `states` and ignore `next_states`
|
||||
next_index = (i + 1) % 4
|
||||
|
||||
torch.testing.assert_close(
|
||||
replay_buffer.states["observation.image"][next_index],
|
||||
reconverted_buffer.next_states["observation.image"][i],
|
||||
rtol=0.4,
|
||||
atol=0.004,
|
||||
)
|
||||
|
||||
for i in range(2, 4):
|
||||
assert torch.equal(
|
||||
replay_buffer.states["observation.state"][i],
|
||||
reconverted_buffer.next_states["observation.state"][i],
|
||||
)
|
||||
|
||||
|
||||
def test_buffer_sample_alignment():
|
||||
# Initialize buffer
|
||||
buffer = ReplayBuffer(capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu")
|
||||
|
||||
# Fill buffer with patterned data
|
||||
for i in range(100):
|
||||
signature = float(i) / 100.0
|
||||
state = {"state_value": torch.tensor([[signature]]).float()}
|
||||
action = torch.tensor([[2.0 * signature]]).float()
|
||||
reward = 3.0 * signature
|
||||
|
||||
is_end = (i + 1) % 10 == 0
|
||||
if is_end:
|
||||
next_state = {"state_value": torch.tensor([[signature]]).float()}
|
||||
done = True
|
||||
else:
|
||||
next_signature = float(i + 1) / 100.0
|
||||
next_state = {"state_value": torch.tensor([[next_signature]]).float()}
|
||||
done = False
|
||||
|
||||
buffer.add(state, action, reward, next_state, done, False)
|
||||
|
||||
# Sample and verify
|
||||
batch = buffer.sample(50)
|
||||
|
||||
for i in range(50):
|
||||
state_sig = batch["state"]["state_value"][i].item()
|
||||
action_val = batch["action"][i].item()
|
||||
reward_val = batch["reward"][i].item()
|
||||
next_state_sig = batch["next_state"]["state_value"][i].item()
|
||||
is_done = batch["done"][i].item() > 0.5
|
||||
|
||||
# Verify relationships
|
||||
assert abs(action_val - 2.0 * state_sig) < 1e-4, (
|
||||
f"Action {action_val} should be 2x state signature {state_sig}"
|
||||
)
|
||||
|
||||
assert abs(reward_val - 3.0 * state_sig) < 1e-4, (
|
||||
f"Reward {reward_val} should be 3x state signature {state_sig}"
|
||||
)
|
||||
|
||||
if is_done:
|
||||
assert abs(next_state_sig - state_sig) < 1e-4, (
|
||||
f"For done states, next_state {next_state_sig} should equal state {state_sig}"
|
||||
)
|
||||
else:
|
||||
# Either it's the next sequential state (+0.01) or same state (for episode boundaries)
|
||||
valid_next = (
|
||||
abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4
|
||||
)
|
||||
assert valid_next, (
|
||||
f"Next state {next_state_sig} should be either state+0.01 or same as state {state_sig}"
|
||||
)
|
||||
|
||||
|
||||
def test_memory_optimization():
|
||||
dummy_state_1 = create_dummy_state()
|
||||
dummy_action_1 = create_dummy_action()
|
||||
|
||||
dummy_state_2 = create_dummy_state()
|
||||
dummy_action_2 = create_dummy_action()
|
||||
|
||||
dummy_state_3 = create_dummy_state()
|
||||
dummy_action_3 = create_dummy_action()
|
||||
|
||||
dummy_state_4 = create_dummy_state()
|
||||
dummy_action_4 = create_dummy_action()
|
||||
|
||||
replay_buffer = create_empty_replay_buffer()
|
||||
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False)
|
||||
replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False)
|
||||
replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False)
|
||||
replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True)
|
||||
|
||||
optimized_replay_buffer = create_empty_replay_buffer(True)
|
||||
optimized_replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False)
|
||||
optimized_replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False)
|
||||
optimized_replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False)
|
||||
optimized_replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, None, True, True)
|
||||
|
||||
assert get_object_memory(optimized_replay_buffer) < get_object_memory(replay_buffer), (
|
||||
"Optimized replay buffer should be smaller than the original replay buffer"
|
||||
)
|
||||
|
||||
|
||||
def test_check_image_augmentations_with_drq_and_dummy_image_augmentation_function(dummy_state, dummy_action):
|
||||
def dummy_image_augmentation_function(x):
|
||||
return torch.ones_like(x) * 10
|
||||
|
||||
replay_buffer = create_empty_replay_buffer(
|
||||
use_drq=True, image_augmentation_function=dummy_image_augmentation_function
|
||||
)
|
||||
|
||||
replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)
|
||||
|
||||
sampled_transitions = replay_buffer.sample(1)
|
||||
assert torch.all(sampled_transitions["state"]["observation.image"] == 10), (
|
||||
"Image augmentations should be applied"
|
||||
)
|
||||
assert torch.all(sampled_transitions["next_state"]["observation.image"] == 10), (
|
||||
"Image augmentations should be applied"
|
||||
)
|
||||
|
||||
|
||||
def test_check_image_augmentations_with_drq_and_default_image_augmentation_function(
|
||||
dummy_state, dummy_action
|
||||
):
|
||||
replay_buffer = create_empty_replay_buffer(use_drq=True)
|
||||
|
||||
replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)
|
||||
|
||||
# Let's check that it doesn't fail and shapes are correct
|
||||
sampled_transitions = replay_buffer.sample(1)
|
||||
assert sampled_transitions["state"]["observation.image"].shape == (1, 3, 84, 84)
|
||||
assert sampled_transitions["next_state"]["observation.image"].shape == (1, 3, 84, 84)
|
||||
|
||||
|
||||
def test_random_crop_vectorized_basic():
|
||||
# Create a batch of 2 images with known patterns
|
||||
batch_size, channels, height, width = 2, 3, 10, 8
|
||||
images = torch.zeros((batch_size, channels, height, width))
|
||||
|
||||
# Fill with unique values for testing
|
||||
for b in range(batch_size):
|
||||
images[b] = b + 1
|
||||
|
||||
crop_size = (6, 4) # Smaller than original
|
||||
cropped = random_crop_vectorized(images, crop_size)
|
||||
|
||||
# Check output shape
|
||||
assert cropped.shape == (batch_size, channels, *crop_size)
|
||||
|
||||
# Check that values are preserved (should be either 1s or 2s for respective batches)
|
||||
assert torch.all(cropped[0] == 1)
|
||||
assert torch.all(cropped[1] == 2)
|
||||
|
||||
|
||||
def test_random_crop_vectorized_invalid_size():
|
||||
images = torch.zeros((2, 3, 10, 8))
|
||||
|
||||
# Test crop size larger than image
|
||||
with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"):
|
||||
random_crop_vectorized(images, (12, 8))
|
||||
|
||||
with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"):
|
||||
random_crop_vectorized(images, (10, 10))
|
||||
Reference in New Issue
Block a user