forked from tangger/lerobot
Compare commits
43 Commits
user/miche ... user/miche

SHA1:
daa1480a91, 71ec721e48, bbb5ba0adf, 844bfcf484, 13441f0d98, 41b377211c, 9ceb68ee90,
d1baa5a82f, 04da4dd3e3, b0e2fcdba7, 1e2a757cd3, ab842ba6ae, 94a7221a94, 00dadcace0,
81a2f2958d, 68b4fb60ad, 96b2b62377, b5c98bbfd3, 58e12cf2e8, d8b5fae622, 67ac81d728,
b5f1ea3140, 4d854a1513, 87da655eab, a8fda9c61a, 55505ff817, 20d31ab8e0, e5b83aab5e,
a9d5f62304, 72e1ed7058, d8e67a2609, 50e12376de, 73aa6c25f3, 380b836eee, eec6796cb8,
25a8597680, b8b368310c, 5097cd900e, bc16e1b497, 8f821ecad0, 4519016e67, 59e2757434,
73b64c3089
checkport.py (18 lines deleted)
@@ -1,18 +0,0 @@
import socket


def check_port(host, port):
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.connect((host, port))
        print(f"Connection successful to {host}:{port}!")
    except Exception as e:
        print(f"Connection failed to {host}:{port}: {e}")
    finally:
        s.close()


if __name__ == "__main__":
    host = "127.0.0.1"  # or "localhost"
    port = 51350
    check_port(host, port)
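If the host and port need to vary (for example to probe the actor/learner port from another machine), a minimal sketch of the same check taking them from the command line; the argument handling here is illustrative and not part of the deleted script:

import socket
import sys

def check_port(host: str, port: int, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within `timeout` seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError as e:
        print(f"Connection failed to {host}:{port}: {e}")
        return False

if __name__ == "__main__":
    # e.g. `python checkport.py 127.0.0.1 50051`
    host = sys.argv[1] if len(sys.argv) > 1 else "127.0.0.1"
    port = int(sys.argv[2]) if len(sys.argv) > 2 else 51350
    print("open" if check_port(host, port) else "closed")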
@@ -74,23 +74,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData

    image_transforms = None
    if cfg.training.image_transforms.enable:
        default_tf = OmegaConf.create(
            {
                "brightness": {"weight": 0.0, "min_max": None},
                "contrast": {"weight": 0.0, "min_max": None},
                "saturation": {"weight": 0.0, "min_max": None},
                "hue": {"weight": 0.0, "min_max": None},
                "sharpness": {"weight": 0.0, "min_max": None},
                "max_num_transforms": None,
                "random_order": False,
                "image_size": None,
                "interpolation": None,
                "image_mean": None,
                "image_std": None,
            }
        )
        cfg_tf = OmegaConf.merge(OmegaConf.create(default_tf), cfg.training.image_transforms)
        cfg_tf = cfg.training.image_transforms
        image_transforms = get_image_transforms(
            brightness_weight=cfg_tf.brightness.weight,
            brightness_min_max=cfg_tf.brightness.min_max,
@@ -104,10 +88,6 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
            sharpness_min_max=cfg_tf.sharpness.min_max,
            max_num_transforms=cfg_tf.max_num_transforms,
            random_order=cfg_tf.random_order,
            image_size=(cfg_tf.image_size.height, cfg_tf.image_size.width) if cfg_tf.image_size else None,
            interpolation=cfg_tf.interpolation,
            image_mean=cfg_tf.image_mean,
            image_std=cfg_tf.image_std,
        )

    if isinstance(cfg.dataset_repo_id, str):
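For reference, the removed default-merge pattern relies on OmegaConf.merge overlaying user values on top of defaults, so unspecified keys fall back safely; a small self-contained sketch (keys mirror the defaults above, values are made up):

from omegaconf import OmegaConf

defaults = OmegaConf.create({"brightness": {"weight": 0.0, "min_max": None}, "max_num_transforms": None})
user_cfg = OmegaConf.create({"brightness": {"weight": 1.0, "min_max": [0.8, 1.2]}})

# merge() keeps every default key and overrides only what the user set
cfg_tf = OmegaConf.merge(defaults, user_cfg)
print(cfg_tf.brightness.min_max)   # [0.8, 1.2]
print(cfg_tf.max_num_transforms)   # None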
@@ -84,8 +84,7 @@ class LeRobotDatasetMetadata:

        # Load metadata
        (self.root / "meta").mkdir(exist_ok=True, parents=True)
        if not self.local_files_only:
            self.pull_from_repo(allow_patterns="meta/")
        self.pull_from_repo(allow_patterns="meta/")
        self.info = load_info(self.root)
        self.stats = load_stats(self.root)
        self.tasks = load_tasks(self.root)
@@ -538,8 +537,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        ]
        files += video_files

        if not self.local_files_only:
            self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
        self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)

    def load_hf_dataset(self) -> datasets.Dataset:
        """hf_dataset contains all the observations, states, actions, rewards, etc."""
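pull_from_repo filters what is fetched from the Hub by pattern; a rough sketch of what such a wrapper could look like on top of huggingface_hub's snapshot_download (the wrapper body is an assumption, not taken from this diff):

from huggingface_hub import snapshot_download

def pull_from_repo(repo_id: str, root: str, allow_patterns=None, ignore_patterns=None) -> str:
    # Download only the matching files of a dataset repo into `root`.
    return snapshot_download(
        repo_id,
        repo_type="dataset",
        local_dir=root,
        allow_patterns=allow_patterns,    # e.g. "meta/" to fetch only the metadata
        ignore_patterns=ignore_patterns,
    )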
@@ -150,10 +150,6 @@ def get_image_transforms(
    sharpness_min_max: tuple[float, float] | None = None,
    max_num_transforms: int | None = None,
    random_order: bool = False,
    interpolation: str | None = None,
    image_size: tuple[int, int] | None = None,
    image_mean: list[float] | None = None,
    image_std: list[float] | None = None,
):
    def check_value(name, weight, min_max):
        if min_max is not None:
@@ -174,18 +170,6 @@ def get_image_transforms(

    weights = []
    transforms = []
    if image_size is not None:
        interpolations = [interpolation.value for interpolation in v2.InterpolationMode]
        if interpolation is None:
            # Use BICUBIC as default interpolation
            interpolation_mode = v2.InterpolationMode.BICUBIC
        elif interpolation in interpolations:
            interpolation_mode = v2.InterpolationMode(interpolation)
        else:
            raise ValueError("The interpolation passed is not supported")
        # Weight for resizing is always 1
        weights.append(1.0)
        transforms.append(v2.Resize(size=(image_size[0], image_size[1]), interpolation=interpolation_mode))
    if brightness_min_max is not None and brightness_weight > 0.0:
        weights.append(brightness_weight)
        transforms.append(v2.ColorJitter(brightness=brightness_min_max))
@@ -201,15 +185,6 @@ def get_image_transforms(
    if sharpness_min_max is not None and sharpness_weight > 0.0:
        weights.append(sharpness_weight)
        transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
    if image_mean is not None and image_std is not None:
        # Weight for normalization is always 1
        weights.append(1.0)
        transforms.append(
            v2.Normalize(
                mean=image_mean,
                std=image_std,
            )
        )

    n_subset = len(transforms)
    if max_num_transforms is not None:

@@ -275,7 +275,6 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
            hf_features[key] = datasets.Sequence(
                length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
            )
        # TODO: (alibers, azouitine) Add support for ft["shap"] == 0 as Value

    return datasets.Features(hf_features)
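The weights and transforms collected above are meant to be applied as a weighted random subset; a standalone sketch of that idea (this is an illustration, not the repo's RandomSubsetApply implementation):

import random
import torch
from torchvision.transforms import v2

def apply_random_subset(img, transforms, weights, max_num_transforms=None, random_order=False):
    # Pick up to `max_num_transforms` transforms, biased by their weights, and apply them in order.
    k = len(transforms) if max_num_transforms is None else min(max_num_transforms, len(transforms))
    chosen = random.choices(range(len(transforms)), weights=weights, k=k)
    if not random_order:
        chosen = sorted(set(chosen))
    for idx in chosen:
        img = transforms[idx](img)
    return img

tfs = [v2.ColorJitter(brightness=(0.8, 1.2)), v2.ColorJitter(contrast=(0.8, 1.2))]
out = apply_random_subset(torch.rand(3, 96, 96), tfs, weights=[1.0, 1.0], max_num_transforms=1)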
@@ -20,7 +20,7 @@ import gymnasium as gym
import numpy as np
import torch
from omegaconf import DictConfig
# from mani_skill.utils import common
from mani_skill.utils import common


def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
@@ -122,30 +122,28 @@ class PixelWrapper(gym.Wrapper):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self._get_obs(obs), reward, terminated, truncated, info


class ConvertToLeRobotEnv(gym.Wrapper):
    def __init__(self, env, num_envs):
        super().__init__(env)

    def reset(self, seed=None, options=None):
        obs, info = self.env.reset(seed=seed, options={})
        return self._get_obs(obs), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self._get_obs(obs), reward, terminated, truncated, info

    def _get_obs(self, observation):
        sensor_data = observation.pop("sensor_data")
        del observation["sensor_param"]
        images = []
        for cam_data in sensor_data.values():
            images.append(cam_data["rgb"])

        images = torch.concat(images, axis=-1)
        # flatten the rest of the data which should just be state data
        observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device)
        observation = common.flatten_state_dict(
            observation, use_torch=True, device=self.base_env.device
        )
        ret = dict()
        ret["state"] = observation
        ret["pixels"] = images
        return ret
@@ -31,25 +31,28 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
    # TODO: You have to merge all tensors from agent key and extra key
    # You don't keep sensor param key in the observation
    # And you keep sensor data rgb
    for key, img in observations.items():
        if "images" not in key:
            continue
    if "pixels" in observations:
        if isinstance(observations["pixels"], dict):
            imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
        else:
            imgs = {"observation.image": observations["pixels"]}

        if img.ndim == 3:
            img = img.unsqueeze(0)
        # sanity check that images are channel last
        _, h, w, c = img.shape
        assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
        for imgkey, img in imgs.items():
            img = torch.from_numpy(img)

            # sanity check that images are uint8
            assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
            # sanity check that images are channel last
            _, h, w, c = img.shape
            assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"

            # convert to channel first of type float32 in range [0,1]
            img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
            img = img.type(torch.float32)
            img /= 255
        # sanity check that images are uint8
        assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

        return_observations[key] = img
        # convert to channel first of type float32 in range [0,1]
        img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
        img = img.type(torch.float32)
        img /= 255

            return_observations[imgkey] = img
    # obs state agent qpos and qvel
    # image

@@ -60,8 +63,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten

    # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
    # requirement for "agent_pos"
    # return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
    return_observations["observation.state"] = observations["observation.state"].float()
    return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
    return return_observations
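The image conversion above (channel-last uint8 to channel-first float32 in [0, 1]) can be checked in isolation; a small self-contained sketch:

import einops
import numpy as np
import torch

frame = np.random.randint(0, 256, size=(2, 128, 128, 3), dtype=np.uint8)  # (b, h, w, c) as returned by the env

img = torch.from_numpy(frame)
assert img.dtype == torch.uint8
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32) / 255
assert img.shape == (2, 3, 128, 128) and 0.0 <= img.min() and img.max() <= 1.0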
@@ -127,8 +127,6 @@ class Logger:
                job_type="train_eval",
                resume="must" if cfg.resume else None,
            )
            # Handle custom step key for rl asynchronous training.
            self._wandb_custom_step_key: set[str] | None = None
            print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
            logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
            self._wandb = wandb
@@ -174,32 +172,18 @@ class Logger:
        self,
        save_dir: Path,
        train_step: int,
        optimizer: Optimizer | dict,
        optimizer: Optimizer,
        scheduler: LRScheduler | None,
        interaction_step: int | None = None,
    ):
        """Checkpoint the global training_step, optimizer state, scheduler state, and random state.

        All of these are saved as "training_state.pth" under the checkpoint directory.
        """
        # In Sac, for example, we have a dictionary of torch.optim.Optimizer
        if type(optimizer) is dict:
            optimizer_state_dict = {}
            for k in optimizer:
                optimizer_state_dict[k] = optimizer[k].state_dict()
        else:
            optimizer_state_dict = optimizer.state_dict()

        training_state = {
            "step": train_step,
            "optimizer": optimizer_state_dict,
            "optimizer": optimizer.state_dict(),
            **get_global_random_state(),
        }
        # Interaction step is related to the distributed training code
        # In that setup, we have two kinds of steps, the online step of the env and the optimization step
        # We need to save both in order to resume the optimization properly and not break the logs dependant on the interaction step
        if interaction_step is not None:
            training_state["interaction_step"] = interaction_step
        if scheduler is not None:
            training_state["scheduler"] = scheduler.state_dict()
        torch.save(training_state, save_dir / self.training_state_file_name)
@@ -211,7 +195,6 @@ class Logger:
        optimizer: Optimizer,
        scheduler: LRScheduler | None,
        identifier: str,
        interaction_step: int | None = None,
    ):
        """Checkpoint the model weights and the training state."""
        checkpoint_dir = self.checkpoints_dir / str(identifier)
@@ -223,24 +206,16 @@ class Logger:
        self.save_model(
            checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
        )
        self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler, interaction_step)
        self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
        os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)

    def load_last_training_state(self, optimizer: Optimizer | dict, scheduler: LRScheduler | None) -> int:
    def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
        """
        Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
        random state, and return the global training step.
        """
        training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
        # For the case where the optimizer is a dictionary of optimizers (e.g., sac)
        if type(training_state["optimizer"]) is dict:
            assert set(training_state["optimizer"].keys()) == set(optimizer.keys()), (
                "Optimizer dictionaries do not have the same keys during resume!"
            )
            for k, v in training_state["optimizer"].items():
                optimizer[k].load_state_dict(v)
        else:
            optimizer.load_state_dict(training_state["optimizer"])
        optimizer.load_state_dict(training_state["optimizer"])
        if scheduler is not None:
            scheduler.load_state_dict(training_state["scheduler"])
        elif "scheduler" in training_state:
@@ -251,44 +226,17 @@ class Logger:
        set_global_random_state({k: training_state[k] for k in get_global_random_state()})
        return training_state["step"]

    def log_dict(self, d, step: int | None = None, mode="train", custom_step_key: str | None = None):
        """Log a dictionary of metrics to WandB."""
    def log_dict(self, d, step, mode="train"):
        assert mode in {"train", "eval"}
        # TODO(alexander-soare): Add local text log.
        if step is None and custom_step_key is None:
            raise ValueError("Either step or custom_step_key must be provided.")

        if self._wandb is not None:
            # NOTE: This is not simple. Wandb step is it must always monotonically increase and it
            # increases with each wandb.log call, but in the case of asynchronous RL for example,
            # multiple time steps is possible for example, the interaction step with the environment,
            # the training step, the evaluation step, etc. So we need to define a custom step key
            # to log the correct step for each metric.
            if custom_step_key is not None:
                if self._wandb_custom_step_key is None:
                    self._wandb_custom_step_key = set()
                new_custom_key = f"{mode}/{custom_step_key}"
                if new_custom_key not in self._wandb_custom_step_key:
                    self._wandb_custom_step_key.add(new_custom_key)
                    self._wandb.define_metric(new_custom_key, hidden=True)

            for k, v in d.items():
                if not isinstance(v, (int, float, str, wandb.Table)):
                    logging.warning(
                        f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
                    )
                    continue

                # Do not log the custom step key itself.
                if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
                    continue

                if custom_step_key is not None:
                    value_custom_step = d[custom_step_key]
                    self._wandb.log({f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step})
                    continue

                self._wandb.log(data={f"{mode}/{k}": v}, step=step)
                self._wandb.log({f"{mode}/{k}": v}, step=step)

    def log_video(self, video_path: str, step: int, mode: str = "train"):
        assert mode in {"train", "eval"}
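The custom step key machinery above exists because wandb's global step must increase monotonically, while asynchronous RL has several independent step counters; a minimal standalone sketch of the same pattern (project and metric names are made up):

import wandb

wandb.init(project="sac-async-demo")

# Register the interaction step as its own hidden x-axis instead of wandb's global step.
wandb.define_metric("train/interaction_step", hidden=True)
wandb.define_metric("train/reward", step_metric="train/interaction_step")

for interaction_step in range(100):
    reward = interaction_step * 0.01
    # Log each metric together with its custom step so out-of-order producers never go "backwards".
    wandb.log({"train/reward": reward, "train/interaction_step": interaction_step})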
@@ -76,11 +76,7 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:


def make_policy(
    hydra_cfg: DictConfig,
    pretrained_policy_name_or_path: str | None = None,
    dataset_stats=None,
    *args,
    **kwargs,
    hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None
) -> Policy:
    """Make an instance of a policy class.

@@ -104,9 +100,7 @@ def make_policy(
    policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
    if pretrained_policy_name_or_path is None:
        # Make a fresh policy.
        # HACK: We pass *args and **kwargs to the policy constructor to allow for additional arguments
        # for example device for the sac policy.
        policy = policy_cls(config=policy_cfg, dataset_stats=dataset_stats)
        policy = policy_cls(policy_cfg, dataset_stats)
    else:
        # Load a pretrained policy and override the config if needed (for example, if there are inference-time
        # hyperparameters that we want to vary).
@@ -10,7 +10,7 @@ class ClassifierConfig:
    num_classes: int = 2
    hidden_dim: int = 256
    dropout_rate: float = 0.1
    model_name: str = "helper2424/resnet10"
    model_name: str = "microsoft/resnet-50"
    device: str = "cpu"
    model_type: str = "cnn"  # "transformer" or "cnn"
    num_cameras: int = 2
@@ -47,7 +47,7 @@ class Classifier(

        super().__init__()
        self.config = config
        # self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
        self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
        encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
        # Extract vision model if we're given a multimodal model
        if hasattr(encoder, "vision_model"):
@@ -108,12 +108,11 @@ class Classifier(
    def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
        """Extract the appropriate output from the encoder."""
        # Process images with the processor (handles resizing and normalization)
        # processed = self.processor(
        #     images=x,  # LeRobotDataset already provides proper tensor format
        #     return_tensors="pt",
        # )
        # processed = processed["pixel_values"].to(x.device)
        processed = x
        processed = self.processor(
            images=x,  # LeRobotDataset already provides proper tensor format
            return_tensors="pt",
        )
        processed = processed["pixel_values"].to(x.device)

        with torch.no_grad():
            if self.is_cnn:
@@ -145,10 +144,8 @@ class Classifier(

        return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)

    def predict_reward(self, x, threshold=0.6):
    def predict_reward(self, x):
        if self.config.num_classes == 2:
            probs = self.forward(x).probabilities
            logging.debug(f"Predicted reward images: {probs}")
            return (probs > threshold).float()
            return (self.forward(x).probabilities > 0.5).float()
        else:
            return torch.argmax(self.forward(x).probabilities, dim=1)
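The threshold argument turns the classifier's success probability into a binary reward; a tiny illustration of the decision rule (probabilities are made up):

import torch

probs = torch.tensor([0.42, 0.61, 0.97])   # classifier "success" probabilities for 3 frames

reward_default = (probs > 0.5).float()     # old fixed cutoff     -> tensor([0., 1., 1.])
reward_strict = (probs > 0.6).float()      # threshold=0.6        -> tensor([0., 1., 1.])
reward_stricter = (probs > 0.7).float()    # stricter cutoff      -> tensor([0., 0., 1.])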
@@ -39,32 +39,16 @@ class SACConfig:
            "observation.environment_state": "min_max",
        }
    )
    input_normalization_params: dict[str, dict[str, list[float]]] = field(
        default_factory=lambda: {
            "observation.image": {"mean": [[0.485, 0.456, 0.406]], "std": [[0.229, 0.224, 0.225]]},
            "observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
        }
    )
    output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
    output_normalization_params: dict[str, dict[str, list[float]]] = field(
        default_factory=lambda: {
            "action": {"min": [-1, -1], "max": [1, 1]},
        }
    )
    # TODO: Move it outside of the config
    actor_learner_config: dict[str, str | int] = field(
        default_factory=lambda: {
            "actor_ip": "127.0.0.1",
            "port": 50051,
            "learner_ip": "127.0.0.1",
        }
    )
    camera_number: int = 1
    # Add type annotations for these fields:
    vision_encoder_name: str | None = field(default="helper2424/resnet10")
    freeze_vision_encoder: bool = True
    image_encoder_hidden_dim: int = 32
    shared_encoder: bool = True
    shared_encoder: bool = False
    discount: float = 0.99
    temperature_init: float = 1.0
    num_critics: int = 2
@@ -95,6 +79,5 @@ class SACConfig:
            "use_tanh_squash": True,
            "log_std_min": -5,
            "log_std_max": 2,
            "init_final": 0.005,
        }
    )
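A short sketch of how these dataclass fields might be overridden when the actor and learner run on different machines (values are examples, and the import path is assumed):

from dataclasses import replace
# assuming: from lerobot.common.policies.sac.configuration_sac import SACConfig

cfg = SACConfig()
cfg.actor_learner_config["learner_ip"] = "192.168.1.20"   # learner on another machine
cfg.actor_learner_config["port"] = 50051

# replace() builds a new config with a couple of fields changed
cfg_two_cams = replace(cfg, camera_number=2, shared_encoder=True)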
@@ -17,7 +17,8 @@

# TODO: (1) better device management

from typing import Callable, Optional, Tuple
from collections import deque
from typing import Callable, Optional, Sequence, Tuple, Union

import einops
import numpy as np
@@ -29,7 +30,6 @@ from torch import Tensor

from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.utils import get_device_from_parameters


class SACPolicy(
@@ -45,26 +45,25 @@ class SACPolicy(
        self,
        config: SACConfig | None = None,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
        device: str = "cpu",
    ):
        super().__init__()

        if config is None:
            config = SACConfig()
        self.config = config

        if config.input_normalization_modes is not None:
            input_normalization_params = _convert_normalization_params_to_tensor(
                config.input_normalization_params
            )
            self.normalize_inputs = Normalize(
                config.input_shapes, config.input_normalization_modes, input_normalization_params
                config.input_shapes, config.input_normalization_modes, dataset_stats
            )
        else:
            self.normalize_inputs = nn.Identity()

        output_normalization_params = _convert_normalization_params_to_tensor(
            config.output_normalization_params
        )
        output_normalization_params = {}
        for outer_key, inner_dict in config.output_normalization_params.items():
            output_normalization_params[outer_key] = {}
            for key, value in inner_dict.items():
                output_normalization_params[outer_key][key] = torch.tensor(value)

        # HACK: This is hacky and should be removed
        dataset_stats = dataset_stats or output_normalization_params
@@ -75,80 +74,88 @@ class SACPolicy(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )

        # NOTE: For images the encoder should be shared between the actor and critic
        if config.shared_encoder:
            encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
            encoder_critic = SACObservationEncoder(config)
            encoder_actor: SACObservationEncoder = encoder_critic
        else:
            encoder_critic = SACObservationEncoder(config)
            encoder_actor = SACObservationEncoder(config)
        # Define networks
        critic_nets = []
        for _ in range(config.num_critics):
            critic_net = Critic(
                encoder=encoder_critic,
                network=MLP(
                    input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
                    **config.critic_network_kwargs,
                ),
                device=device,
            )
            critic_nets.append(critic_net)

        self.critic_ensemble = CriticEnsemble(
            encoder=encoder_critic,
            network_list=nn.ModuleList(
                [
                    MLP(
                        input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
                        **config.critic_network_kwargs,
                    )
                    for _ in range(config.num_critics)
                ]
            ),
            output_normalization=self.normalize_targets,
        target_critic_nets = []
        for _ in range(config.num_critics):
            target_critic_net = Critic(
                encoder=encoder_critic,
                network=MLP(
                    input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
                    **config.critic_network_kwargs,
                ),
                device=device,
            )
            target_critic_nets.append(target_critic_net)

        self.critic_ensemble = create_critic_ensemble(
            critics=critic_nets, num_critics=config.num_critics, device=device
        )

        self.critic_target = CriticEnsemble(
            encoder=encoder_critic,
            network_list=nn.ModuleList(
                [
                    MLP(
                        input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
                        **config.critic_network_kwargs,
                    )
                    for _ in range(config.num_critics)
                ]
            ),
            output_normalization=self.normalize_targets,
        self.critic_target = create_critic_ensemble(
            critics=target_critic_nets, num_critics=config.num_critics, device=device
        )

        self.critic_target.load_state_dict(self.critic_ensemble.state_dict())

        self.actor = Policy(
            encoder=encoder_actor,
            network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
            action_dim=config.output_shapes["action"][0],
            device=device,
            encoder_is_shared=config.shared_encoder,
            **config.policy_kwargs,
        )
        if config.target_entropy is None:
            config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2  # (-dim(A)/2)

        # TODO (azouitine): Handle the case where the temparameter is a fixed
        # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
        # it triggers "can't optimize a non-leaf Tensor"
        self.log_alpha = nn.Parameter(torch.tensor([0.0]))
        # TODO: Handle the case where the temparameter is a fixed
        self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
        self.temperature = self.log_alpha.exp().item()

    def reset(self):
        """Reset the policy"""
        pass

    def to(self, *args, **kwargs):
        """Override .to(device) method to involve moving the log_alpha fixed_std"""
        if self.actor.fixed_std is not None:
            self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
        # self.log_alpha = self.log_alpha.to(*args, **kwargs)
        super().to(*args, **kwargs)

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select action for inference/evaluation"""
        actions, _, _ = self.actor(batch)
        actions = self.unnormalize_outputs({"action": actions})["action"]
        return actions

    def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor:
        """Forward pass through a critic network ensemble

        Args:
            observations: Dictionary of observations
            actions: Action tensor
            use_target: If True, use target critics, otherwise use ensemble critics

        Returns:
            Tensor of Q-values from all critics
        """
        critics = self.critic_target if use_target else self.critic_ensemble
        q_values = torch.stack([critic(observations, actions) for critic in critics])
        return q_values

    def critic_forward(
        self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False, features: Optional[Tensor] = None
        self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
    ) -> Tensor:
        """Forward pass through a critic network ensemble
@@ -161,28 +168,27 @@ class SACPolicy(
            Tensor of Q-values from all critics
        """
        critics = self.critic_target if use_target else self.critic_ensemble
        q_values = critics(observations, actions, features)
        q_values = torch.stack([critic(observations, actions) for critic in critics])
        return q_values

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
    def update_target_networks(self):
        """Update target networks with exponential moving average"""
        for target_param, param in zip(
            self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=False
        ):
            target_param.data.copy_(
                param.data * self.config.critic_target_update_weight
                + target_param.data * (1.0 - self.config.critic_target_update_weight)
            )
        for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
            for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
                target_param.data.copy_(
                    param.data * self.config.critic_target_update_weight
                    + target_param.data * (1.0 - self.config.critic_target_update_weight)
                )

    def compute_loss_critic(self, observations, actions, rewards, next_observations, done, obs_features=None, next_obs_features=None) -> Tensor:
    def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
        temperature = self.log_alpha.exp().item()
        with torch.no_grad():
            next_action_preds, next_log_probs, _ = self.actor(next_observations, next_obs_features)
            next_action_preds, next_log_probs, _ = self.actor(next_observations)

            # 2- compute q targets
            q_targets = self.critic_forward(
                observations=next_observations, actions=next_action_preds, use_target=True, features=next_obs_features
                observations=next_observations, actions=next_action_preds, use_target=True
            )

            # subsample critics to prevent overfitting if use high UTD (update to date)
@@ -199,7 +205,7 @@ class SACPolicy(
            td_target = rewards + (1 - done) * self.config.discount * min_q

        # 3- compute predicted qs
        q_preds = self.critic_forward(observations, actions, use_target=False, features=obs_features)
        q_preds = self.critic_forward(observations, actions, use_target=False)

        # 4- Calculate loss
        # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
@@ -214,18 +220,18 @@ class SACPolicy(
        ).sum()
        return critics_loss

    def compute_loss_temperature(self, observations, obs_features=None) -> Tensor:
    def compute_loss_temperature(self, observations) -> Tensor:
        """Compute the temperature loss"""
        # calculate temperature loss
        with torch.no_grad():
            _, log_probs, _ = self.actor(observations, obs_features)
            _, log_probs, _ = self.actor(observations)
        temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
        return temperature_loss

    def compute_loss_actor(self, observations, obs_features=None) -> Tensor:
    def compute_loss_actor(self, observations) -> Tensor:
        temperature = self.log_alpha.exp().item()

        actions_pi, log_probs, _ = self.actor(observations, obs_features)
        actions_pi, log_probs, _ = self.actor(observations)

        q_preds = self.critic_forward(observations, actions_pi, use_target=False)
        min_q_preds = q_preds.min(dim=0)[0]
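The target update above is a standard Polyak (exponential moving average) soft update; a minimal standalone illustration with toy modules, where tau stands in for config.critic_target_update_weight:

import torch
from torch import nn

critic = nn.Linear(4, 1)
critic_target = nn.Linear(4, 1)
critic_target.load_state_dict(critic.state_dict())

tau = 0.005  # stand-in for config.critic_target_update_weight
with torch.no_grad():
    for target_param, param in zip(critic_target.parameters(), critic.parameters()):
        # new_target = tau * online + (1 - tau) * old_target
        target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)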
@@ -272,115 +278,54 @@ class MLP(nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class CriticEnsemble(nn.Module):
    """
    [Docstring diagram: the Critic Ensemble stacks Q1 ... Qn heads, one MLP per critic (MLP 1 ... MLP num_critics),
    on top of a single Embedding produced by a shared SACObservationEncoder applied to the Observation.]
    """
class Critic(nn.Module):
    def __init__(
        self,
        encoder: Optional[nn.Module],
        network_list: nn.ModuleList,
        output_normalization: nn.Module,
        network: nn.Module,
        init_final: Optional[float] = None,
        device: str = "cpu",
    ):
        super().__init__()
        self.device = torch.device(device)
        self.encoder = encoder
        self.network_list = network_list
        self.network = network
        self.init_final = init_final
        self.output_normalization = output_normalization

        self.parameters_to_optimize = []
        # Handle the case where a part of the encoder if frozen
        if self.encoder is not None:
            self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)

        self.parameters_to_optimize += list(self.network_list.parameters())

        # Find the last Linear layer's output dimension
        for layer in reversed(network_list[0].net):
        for layer in reversed(network.net):
            if isinstance(layer, nn.Linear):
                out_features = layer.out_features
                break

        # Output layer
        self.output_layers = []
        if init_final is not None:
            for _ in network_list:
                output_layer = nn.Linear(out_features, 1)
                nn.init.uniform_(output_layer.weight, -init_final, init_final)
                nn.init.uniform_(output_layer.bias, -init_final, init_final)
                self.output_layers.append(output_layer)
            self.output_layer = nn.Linear(out_features, 1)
            nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
            nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
        else:
            self.output_layers = []
            for _ in network_list:
                output_layer = nn.Linear(out_features, 1)
                orthogonal_init()(output_layer.weight)
                self.output_layers.append(output_layer)

        self.output_layers = nn.ModuleList(self.output_layers)
        self.parameters_to_optimize += list(self.output_layers.parameters())
            self.output_layer = nn.Linear(out_features, 1)
            orthogonal_init()(self.output_layer.weight)

        self.to(self.device)

    def forward(
        self,
        observations: dict[str, torch.Tensor],
        actions: torch.Tensor,
        features: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        device = get_device_from_parameters(self)
        # Move observations to the correct device
        observations = {k: v.to(device) for k, v in observations.items()}
        # Normalize actions for sample efficiency
        actions: dict[str, torch.Tensor] = {"action": actions}
        actions = self.output_normalization(actions)["action"]
        actions = actions.to(device)

        # Use precomputed features if provided; otherwise, encode observations.
        obs_enc = features if features is not None else (observations if self.encoder is None else self.encoder(observations))
        # Move each tensor in observations to device
        observations = {k: v.to(self.device) for k, v in observations.items()}
        actions = actions.to(self.device)

        obs_enc = observations if self.encoder is None else self.encoder(observations)

        inputs = torch.cat([obs_enc, actions], dim=-1)
        list_q_values = []
        for network, output_layer in zip(self.network_list, self.output_layers, strict=False):
            x = network(inputs)
            value = output_layer(x)
            list_q_values.append(value.squeeze(-1))
        return torch.stack(list_q_values)

        x = self.network(inputs)
        value = self.output_layer(x)
        return value.squeeze(-1)
class Policy(nn.Module):
    def __init__(
@@ -393,15 +338,17 @@ class Policy(nn.Module):
        fixed_std: Optional[torch.Tensor] = None,
        init_final: Optional[float] = None,
        use_tanh_squash: bool = False,
        device: str = "cpu",
        encoder_is_shared: bool = False,
    ):
        super().__init__()
        self.device = torch.device(device)
        self.encoder = encoder
        self.network = network
        self.action_dim = action_dim
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self.fixed_std = fixed_std
        self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
        self.use_tanh_squash = use_tanh_squash
        self.parameters_to_optimize = []

@@ -433,14 +380,16 @@ class Policy(nn.Module):
            orthogonal_init()(self.std_layer.weight)
            self.parameters_to_optimize += list(self.std_layer.parameters())

        self.to(self.device)

    def forward(
        self,
        observations: torch.Tensor,
        features: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Use precomputed features if provided; otherwise compute encoder representations.
        obs_enc = features if features is not None else (observations if self.encoder is None else self.encoder(observations))

    ) -> Tuple[torch.Tensor, torch.Tensor]:

        # Encode observations if encoder exists
        obs_enc = observations if self.encoder is None else self.encoder(observations)

        # Get network outputs
        outputs = self.network(obs_enc)
        means = self.mean_layer(outputs)
@@ -449,7 +398,7 @@ class Policy(nn.Module):
        if self.fixed_std is None:
            log_std = self.std_layer(outputs)
            assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"

            if self.use_tanh_squash:
                log_std = torch.tanh(log_std)
                log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
@@ -457,8 +406,8 @@ class Policy(nn.Module):
                log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        else:
            log_std = self.fixed_std.expand_as(means)

        # Get distribution and sample actions

        # uses tanh activation function to squash the action to be in the range of [-1, 1]
        normal = torch.distributions.Normal(means, torch.exp(log_std))
        x_t = normal.rsample()  # Reparameterization trick (mean + std * N(0,1))
        log_probs = normal.log_prob(x_t)  # Base log probability before Tanh
@@ -475,8 +424,7 @@ class Policy(nn.Module):

    def get_features(self, observations: torch.Tensor) -> torch.Tensor:
        """Get encoded features from observations"""
        device = get_device_from_parameters(self)
        observations = observations.to(device)
        observations = observations.to(self.device)
        if self.encoder is not None:
            with torch.inference_mode():
                return self.encoder(observations)
@@ -484,35 +432,65 @@


class SACObservationEncoder(nn.Module):
    """Encode image and/or state vector observations."""
    """Encode image and/or state vector observations.
    TODO(ke-wang): The original work allows for (1) stacking multiple history frames and (2) using pretrained resnet encoders.
    """

    def __init__(self, config: SACConfig, input_normalizer: nn.Module):
    def __init__(self, config: SACConfig):
        """
        Creates encoders for pixel and/or state modalities.
        """
        super().__init__()
        self.config = config
        self.input_normalization = input_normalizer
        self.has_pretrained_vision_encoder = False
        self.parameters_to_optimize = []

        self.aggregation_size: int = 0
        if any("observation.image" in key for key in config.input_shapes):
        if "observation.image" in config.input_shapes:
            self.image_enc_layers = nn.Sequential(
                nn.Conv2d(
                    in_channels=config.input_shapes["observation.image"][0],
                    out_channels=config.image_encoder_hidden_dim,
                    kernel_size=7,
                    stride=2,
                ),
                nn.ReLU(),
                nn.Conv2d(
                    in_channels=config.image_encoder_hidden_dim,
                    out_channels=config.image_encoder_hidden_dim,
                    kernel_size=5,
                    stride=2,
                ),
                nn.ReLU(),
                nn.Conv2d(
                    in_channels=config.image_encoder_hidden_dim,
                    out_channels=config.image_encoder_hidden_dim,
                    kernel_size=3,
                    stride=2,
                ),
                nn.ReLU(),
                nn.Conv2d(
                    in_channels=config.image_encoder_hidden_dim,
                    out_channels=config.image_encoder_hidden_dim,
                    kernel_size=3,
                    stride=2,
                ),
                nn.ReLU(),
            )
            self.camera_number = config.camera_number
            self.aggregation_size: int = 0

            if self.config.vision_encoder_name is not None:
                self.image_enc_layers = PretrainedImageEncoder(config)
                self.has_pretrained_vision_encoder = True
            else:
                self.image_enc_layers = DefaultImageEncoder(config)
            dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
            with torch.inference_mode():
                out_shape = self.image_enc_layers(dummy_batch).shape[1:]
            self.image_enc_layers.extend(
                sequential=nn.Sequential(
                    nn.Flatten(),
                    nn.Linear(
                        in_features=np.prod(out_shape) * self.camera_number, out_features=config.latent_dim
                    ),
                    nn.LayerNorm(normalized_shape=config.latent_dim),
                    nn.Tanh(),
                )
            )

            self.aggregation_size += config.latent_dim * self.camera_number

            if config.freeze_vision_encoder:
                freeze_image_encoder(self.image_enc_layers)
            else:
                self.parameters_to_optimize += list(self.image_enc_layers.parameters())

        if "observation.state" in config.input_shapes:
            self.state_enc_layers = nn.Sequential(
                nn.Linear(
@@ -523,8 +501,6 @@ class SACObservationEncoder(nn.Module):
            )
            self.aggregation_size += config.latent_dim

            self.parameters_to_optimize += list(self.state_enc_layers.parameters())

        if "observation.environment_state" in config.input_shapes:
            self.env_state_enc_layers = nn.Sequential(
                nn.Linear(
@@ -534,11 +510,9 @@ class SACObservationEncoder(nn.Module):
                nn.LayerNorm(normalized_shape=config.latent_dim),
                nn.Tanh(),
            )
            self.aggregation_size += config.latent_dim
            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())

            self.aggregation_size += config.latent_dim
        self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
        self.parameters_to_optimize += list(self.aggregation_layer.parameters())

    def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
        """Encode the image and/or state vector.
@@ -547,20 +521,16 @@ class SACObservationEncoder(nn.Module):
        over all features.
        """
        feat = []
        obs_dict = self.input_normalization(obs_dict)
        # Concatenate all images along the channel dimension.
        image_keys = [k for k in obs_dict if k.startswith("observation.image")]
        image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
        for image_key in image_keys:
            enc_feat = self.image_enc_layers(obs_dict[image_key])

            # if not self.has_pretrained_vision_encoder:
            #     enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
            feat.append(enc_feat)
            feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]))
        if "observation.environment_state" in self.config.input_shapes:
            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
        if "observation.state" in self.config.input_shapes:
            feat.append(self.state_enc_layers(obs_dict["observation.state"]))

        # TODO(ke-wang): currently average over all features, concatenate all features maybe a better way
        # return torch.stack(feat, dim=0).mean(0)
        features = torch.cat(tensors=feat, dim=-1)
        features = self.aggregation_layer(features)
@@ -572,109 +542,15 @@ class SACObservationEncoder(nn.Module):
        return self.config.latent_dim


class DefaultImageEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.image_enc_layers = nn.Sequential(
            nn.Conv2d(
                in_channels=config.input_shapes["observation.image"][0],
                out_channels=config.image_encoder_hidden_dim,
                kernel_size=7,
                stride=2,
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=config.image_encoder_hidden_dim,
                out_channels=config.image_encoder_hidden_dim,
                kernel_size=5,
                stride=2,
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=config.image_encoder_hidden_dim,
                out_channels=config.image_encoder_hidden_dim,
                kernel_size=3,
                stride=2,
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=config.image_encoder_hidden_dim,
                out_channels=config.image_encoder_hidden_dim,
                kernel_size=3,
                stride=2,
            ),
            nn.ReLU(),
        )
        dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
        with torch.inference_mode():
            self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
        self.image_enc_layers.extend(
            nn.Sequential(
                nn.Flatten(),
                nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
                nn.LayerNorm(config.latent_dim),
                nn.Tanh(),
            )
        )

    def forward(self, x):
        return self.image_enc_layers(x)
class PretrainedImageEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
        self.image_enc_proj = nn.Sequential(
            nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
            nn.LayerNorm(config.latent_dim),
            nn.Tanh(),
        )

    def _load_pretrained_vision_encoder(self, config):
        """Set up CNN encoder"""
        from transformers import AutoModel

        self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
        # self.image_enc_layers.pooler = Identity()

        if hasattr(self.image_enc_layers.config, "hidden_sizes"):
            self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1]  # Last channel dimension
        elif hasattr(self.image_enc_layers, "fc"):
            self.image_enc_out_shape = self.image_enc_layers.fc.in_features
        else:
            raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
        return self.image_enc_layers, self.image_enc_out_shape

    def forward(self, x):
        # TODO: (maractingi, azouitine) check the forward pass of the pretrained model
        # doesn't reach the classifier layer because we don't need it
        enc_feat = self.image_enc_layers(x).pooler_output
        enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
        return enc_feat


def freeze_image_encoder(image_encoder: nn.Module):
    """Freeze all parameters in the encoder"""
    for param in image_encoder.parameters():
        param.requires_grad = False


def orthogonal_init():
    return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)


class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cpu") -> nn.ModuleList:
    """Creates an ensemble of critic networks"""
    assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
    return nn.ModuleList(critics).to(device)

    def forward(self, x):
        return x


# TODO (azouitine): I think in our case this function is not usefull we should remove it
# after some investigation
# borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
    """Helper to temporarily flatten extra dims at the start of the image tensor.
@@ -682,7 +558,7 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
    Args:
        fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
            (B, *), where * is any number of dimensions.
        image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
            can be more than 1 dimensions, generally different from *.
    Returns:
        A return value from the callable reshaped to (**, *).
@@ -692,67 +568,4 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
    start_dims = image_tensor.shape[:-3]
    inp = torch.flatten(image_tensor, end_dim=-4)
    flat_out = fn(inp)
    return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))


def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
    converted_params = {}
    for outer_key, inner_dict in normalization_params.items():
        converted_params[outer_key] = {}
        for key, value in inner_dict.items():
            converted_params[outer_key][key] = torch.tensor(value)
            if "image" in outer_key:
                converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)

    return converted_params


if __name__ == "__main__":
    # Test the SACObservationEncoder
    import time

    config = SACConfig()
    config.num_critics = 10
    encoder = SACObservationEncoder(config)
    actor_encoder = SACObservationEncoder(config)
    encoder = torch.compile(encoder)
    critic_ensemble = CriticEnsemble(
        encoder=encoder,
        network_list=nn.ModuleList(
            [
                MLP(
                    input_dim=encoder.output_dim + config.output_shapes["action"][0],
                    **config.critic_network_kwargs,
                )
                for _ in range(config.num_critics)
            ]
        ),
    )
    actor = Policy(
        encoder=actor_encoder,
        network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
        action_dim=config.output_shapes["action"][0],
        encoder_is_shared=config.shared_encoder,
        **config.policy_kwargs,
    )
    encoder = encoder.to("cuda:0")
    critic_ensemble = torch.compile(critic_ensemble)
    critic_ensemble = critic_ensemble.to("cuda:0")
    actor = torch.compile(actor)
    actor = actor.to("cuda:0")
    obs_dict = {
        "observation.image": torch.randn(1, 3, 84, 84),
        "observation.state": torch.randn(1, 4),
    }
    actions = torch.randn(1, 2).to("cuda:0")
    obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
    print("compiling...")
    # q_value = critic_ensemble(obs_dict, actions)
    action = actor(obs_dict)
    print("compiled")
    start = time.perf_counter()
    for _ in range(1000):
        # features = encoder(obs_dict)
        action = actor(obs_dict)
        # q_value = critic_ensemble(obs_dict, actions)
    print("Time taken:", time.perf_counter() - start)
    return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
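A quick usage sketch of flatten_forward_unflatten with a toy callable, illustrating the shape contract from its docstring (the fake encoder is hypothetical):

import torch

def fake_encoder(x: torch.Tensor) -> torch.Tensor:
    # Accepts (B, C, H, W) and returns (B, 8), standing in for an image encoder.
    return torch.zeros(x.shape[0], 8)

imgs = torch.rand(5, 2, 3, 84, 84)            # (**, C, H, W) with ** = (5, 2)
out = flatten_forward_unflatten(fake_encoder, imgs)
print(out.shape)                              # torch.Size([5, 2, 8])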
@@ -36,7 +36,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f

    def log_dt(shortname, dt_val_s):
        nonlocal log_items, fps
        info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
        info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
        if fps is not None:
            actual_fps = 1 / dt_val_s
            if actual_fps < fps - 1:
@@ -225,7 +225,6 @@ def record_episode(
    device,
    use_amp,
    fps,
    record_delta_actions,
):
    control_loop(
        robot=robot,
@@ -237,7 +236,6 @@ def record_episode(
        device=device,
        use_amp=use_amp,
        fps=fps,
        record_delta_actions=record_delta_actions,
        teleoperate=policy is None,
    )

@@ -254,7 +252,6 @@ def control_loop(
    device=None,
    use_amp=None,
    fps=None,
    record_delta_actions=False,
):
    # TODO(rcadene): Add option to record logs
    if not robot.is_connected:
@@ -277,12 +274,8 @@ def control_loop(
    while timestamp < control_time_s:
        start_loop_t = time.perf_counter()

        current_joint_positions = robot.follower_arms["main"].read("Present_Position")

        if teleoperate:
            observation, action = robot.teleop_step(record_data=True)
            if record_delta_actions:
                action["action"] = action["action"] - current_joint_positions
        else:
            observation = robot.capture_observation()

@@ -297,12 +290,8 @@ def control_loop(
            frame = {**observation, **action}
            if "next.reward" in events:
                frame["next.reward"] = events["next.reward"]
                frame["next.done"] = (events["next.reward"] == 1) or (events["exit_early"])
            dataset.add_frame(frame)

            # if frame["next.done"]:
            #     break

        if display_cameras and not is_headless():
            image_keys = [key for key in observation if "image" in key]
            for key in image_keys:
@@ -346,9 +335,7 @@ def reset_environment(robot, events, reset_time_s):

def reset_follower_position(robot: Robot, target_position):
    current_position = robot.follower_arms["main"].read("Present_Position")
    trajectory = torch.from_numpy(
        np.linspace(current_position, target_position, 30)
    )  # NOTE: 30 is just an aribtrary number
    trajectory = torch.from_numpy(np.linspace(current_position, target_position, 30))  # NOTE: 30 is just an aribtrary number
    for pose in trajectory:
        robot.send_action(pose)
        busy_wait(0.015)
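record_delta_actions stores actions relative to the current joint positions; a toy illustration of recording and replaying such a delta (values are made up):

import torch

current_joint_positions = torch.tensor([10.0, 20.0, 30.0])
absolute_action = torch.tensor([12.0, 19.5, 30.0])

delta_action = absolute_action - current_joint_positions    # what gets stored in the dataset
replayed_action = current_joint_positions + delta_action    # how it would be applied at replay time
assert torch.allclose(replayed_action, absolute_action)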
@@ -32,7 +32,7 @@ def ensure_safe_goal_position(
    safe_goal_pos = present_pos + safe_diff

    if not torch.allclose(goal_pos, safe_goal_pos):
        logging.debug(
        logging.warning(
            "Relative goal position magnitude had to be clamped to be safe.\n"
            f"  requested relative goal position target: {diff}\n"
            f"    clamped relative goal position target: {safe_diff}"
@@ -67,8 +67,6 @@ class ManipulatorRobotConfig:
    # gripper is not put in torque mode.
    gripper_open_degree: float | None = None

    joint_position_relative_bounds: dict[np.ndarray] | None = None

    def __setattr__(self, prop: str, val):
        if prop == "max_relative_target" and val is not None and isinstance(val, Sequence):
            for name in self.follower_arms:
@@ -80,9 +78,6 @@ class ManipulatorRobotConfig:
                    "Note: This feature does not yet work with robots where different follower arms have "
                    "different numbers of motors."
                )
        if prop == "joint_position_relative_bounds" and val is not None:
            for key in val:
                val[key] = torch.tensor(val[key])
        super().__setattr__(prop, val)

    def __post_init__(self):
@@ -528,14 +523,6 @@ class ManipulatorRobot:
            before_fwrite_t = time.perf_counter()
            goal_pos = leader_pos[name]

            # If specified, clip the goal positions within predefined bounds specified in the config of the robot
            if self.config.joint_position_relative_bounds is not None:
                goal_pos = torch.clamp(
                    goal_pos,
                    self.config.joint_position_relative_bounds["min"],
                    self.config.joint_position_relative_bounds["max"],
                )

            # Cap goal position when too far away from present position.
            # Slower fps expected due to reading from the follower.
            if self.config.max_relative_target is not None:
@@ -657,14 +644,6 @@ class ManipulatorRobot:
            goal_pos = action[from_idx:to_idx]
            from_idx = to_idx

            # If specified, clip the goal positions within predefined bounds specified in the config of the robot
            if self.config.joint_position_relative_bounds is not None:
                goal_pos = torch.clamp(
                    goal_pos,
                    self.config.joint_position_relative_bounds["min"],
                    self.config.joint_position_relative_bounds["max"],
                )

            # Cap goal position when too far away from present position.
            # Slower fps expected due to reading from the follower.
            if self.config.max_relative_target is not None:
@@ -677,7 +656,6 @@ class ManipulatorRobot:

            # Send goal position to each follower
            goal_pos = goal_pos.numpy().astype(np.int32)

            self.follower_arms[name].write("Goal_Position", goal_pos)

        return torch.cat(action_sent)
@@ -18,7 +18,6 @@ import os
|
||||
import os.path as osp
|
||||
import platform
|
||||
import random
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
@@ -218,28 +217,3 @@ def log_say(text, play_sounds, blocking=False):
|
||||
|
||||
if play_sounds:
|
||||
say(text, blocking)
|
||||
|
||||
|
||||
class TimerManager:
|
||||
def __init__(self, elapsed_time_list: list[float] | None = None, label="Elapsed time", log=True):
|
||||
self.label = label
|
||||
self.elapsed_time_list = elapsed_time_list
|
||||
self.log = log
|
||||
self.elapsed = 0.0
|
||||
|
||||
def __enter__(self):
|
||||
self.start = time.perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.elapsed: float = time.perf_counter() - self.start
|
||||
|
||||
if self.elapsed_time_list is not None:
|
||||
self.elapsed_time_list.append(self.elapsed)
|
||||
|
||||
if self.log:
|
||||
print(f"{self.label}: {self.elapsed:.6f} seconds")
|
||||
|
||||
@property
|
||||
def elapsed_seconds(self):
|
||||
return self.elapsed
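A minimal usage sketch of the TimerManager context manager added above; `policy` and `obs` are placeholder names for whatever workload is being timed, not part of this diff:

```
inference_times: list[float] = []

with TimerManager(elapsed_time_list=inference_times, label="Policy inference time", log=False) as timer:
    action = policy.select_action(batch=obs)  # placeholder workload being timed

# The elapsed time is exposed both through the property and the optional list.
print(f"last call took {timer.elapsed_seconds:.6f} s ({len(inference_times)} timing(s) recorded)")
```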
|
||||
|
||||
@@ -2,7 +2,6 @@ defaults:
|
||||
- _self_
|
||||
- env: pusht
|
||||
- policy: diffusion
|
||||
- robot: so100
|
||||
|
||||
hydra:
|
||||
run:
|
||||
|
||||
lerobot/configs/env/maniskill_example.yaml (20 changed lines, vendored)
@@ -1,20 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 20
|
||||
|
||||
env:
|
||||
name: maniskill/pushcube
|
||||
task: PushCube-v1
|
||||
image_size: 128
|
||||
control_mode: pd_ee_delta_pose
|
||||
state_dim: 25
|
||||
action_dim: 7
|
||||
fps: ${fps}
|
||||
obs: rgb
|
||||
render_mode: rgb_array
|
||||
render_size: 128
|
||||
device: cuda
|
||||
|
||||
reward_classifier:
|
||||
pretrained_path: null
|
||||
config_path: null
|
||||
lerobot/configs/env/so100_real.yaml (23 changed lines, vendored)
@@ -1,6 +1,6 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 10
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: real_world
|
||||
@@ -8,24 +8,3 @@ env:
|
||||
state_dim: 6
|
||||
action_dim: 6
|
||||
fps: ${fps}
|
||||
device: mps
|
||||
|
||||
wrapper:
|
||||
crop_params_dict:
|
||||
observation.images.front: [102, 43, 358, 523]
|
||||
observation.images.side: [92, 123, 379, 349]
|
||||
# observation.images.front: [109, 37, 361, 557]
|
||||
# observation.images.side: [94, 161, 372, 315]
|
||||
resize_size: [128, 128]
|
||||
control_time_s: 20
|
||||
reset_follower_pos: true
|
||||
use_relative_joint_positions: true
|
||||
reset_time_s: 5
|
||||
display_cameras: false
|
||||
delta_action: 0.1
|
||||
joint_masking_action_space: [1, 1, 1, 1, 0, 0] # disable wrist and gripper
|
||||
|
||||
reward_classifier:
|
||||
pretrained_path: outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
|
||||
config_path: lerobot/configs/policy/hilserl_classifier.yaml
|
||||
|
||||
@@ -4,10 +4,7 @@ defaults:
|
||||
- _self_
|
||||
|
||||
seed: 13
|
||||
dataset_repo_id: aractingi/push_cube_square_light_reward_cropped_resized
|
||||
# aractingi/push_cube_square_reward_1_cropped_resized
|
||||
dataset_root: data/aractingi/push_cube_square_light_reward_cropped_resized
|
||||
local_files_only: true
|
||||
dataset_repo_id: aractingi/pick_place_lego_cube_1
|
||||
train_split_proportion: 0.8
|
||||
|
||||
# Required by logger
|
||||
@@ -17,7 +14,7 @@ env:
|
||||
|
||||
|
||||
training:
|
||||
num_epochs: 6
|
||||
num_epochs: 5
|
||||
batch_size: 16
|
||||
learning_rate: 1e-4
|
||||
num_workers: 4
|
||||
@@ -27,18 +24,16 @@ training:
|
||||
eval_freq: 1 # How often to run validation (in epochs)
|
||||
save_freq: 1 # How often to save checkpoints (in epochs)
|
||||
save_checkpoint: true
|
||||
image_keys: ["observation.images.front", "observation.images.side"]
|
||||
image_keys: ["observation.images.top", "observation.images.wrist"]
|
||||
label_key: "next.reward"
|
||||
profile_inference_time: false
|
||||
profile_inference_time_iters: 20
|
||||
|
||||
eval:
|
||||
batch_size: 16
|
||||
num_samples_to_log: 30 # Number of validation samples to log in the table
|
||||
|
||||
policy:
|
||||
name: "hilserl/classifier"
|
||||
model_name: "helper2424/resnet10" # "facebook/convnext-base-224
|
||||
name: "hilserl/classifier/pick_place_lego_cube_1"
|
||||
model_name: "facebook/convnext-base-224"
|
||||
model_type: "cnn"
|
||||
num_cameras: 2 # Has to be len(training.image_keys)
|
||||
|
||||
@@ -50,4 +45,4 @@ wandb:
|
||||
|
||||
device: "mps"
|
||||
resume: false
|
||||
output_dir: "outputs/classifier/old_trainer_resnet10_frozen"
|
||||
output_dir: "outputs/classifier"
|
||||
|
||||
@@ -8,7 +8,8 @@
|
||||
# env.gym.obs_type=environment_state_agent_pos \
|
||||
|
||||
seed: 1
|
||||
dataset_repo_id: null
|
||||
dataset_repo_id: null
|
||||
|
||||
|
||||
training:
|
||||
# Offline training dataloader
|
||||
@@ -20,8 +21,8 @@ training:
|
||||
lr: 3e-4
|
||||
|
||||
eval_freq: 2500
|
||||
log_freq: 10
|
||||
save_freq: 2000000
|
||||
log_freq: 500
|
||||
save_freq: 50000
|
||||
|
||||
online_steps: 1000000
|
||||
online_rollout_n_episodes: 10
|
||||
@@ -52,16 +53,12 @@ policy:
|
||||
n_action_steps: 1
|
||||
|
||||
shared_encoder: true
|
||||
# vision_encoder_name: null
|
||||
vision_encoder_name: "helper2424/resnet10"
|
||||
freeze_vision_encoder: true
|
||||
# freeze_vision_encoder: false
|
||||
input_shapes:
|
||||
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.state: ["${env.state_dim}"]
|
||||
observation.image: [3, 128, 128]
|
||||
observation.image: [3, 64, 64]
|
||||
output_shapes:
|
||||
action: [7]
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: null
|
||||
@@ -69,8 +66,8 @@ policy:
|
||||
action: min_max
|
||||
output_normalization_params:
|
||||
action:
|
||||
min: [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0]
|
||||
max: [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
|
||||
min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
|
||||
max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
@@ -78,15 +75,23 @@ policy:
|
||||
# discount: 0.99
|
||||
discount: 0.80
|
||||
temperature_init: 1.0
|
||||
num_critics: 2 #10
|
||||
num_critics: 2
|
||||
num_subsample_critics: null
|
||||
critic_lr: 3e-4
|
||||
actor_lr: 3e-4
|
||||
temperature_lr: 3e-4
|
||||
# critic_target_update_weight: 0.005
|
||||
critic_target_update_weight: 0.01
|
||||
utd_ratio: 2 # 10
|
||||
utd_ratio: 1
|
||||
|
||||
actor_learner_config:
|
||||
actor_ip: "127.0.0.1"
|
||||
port: 50051
|
||||
|
||||
# # Loss coefficients.
|
||||
# reward_coeff: 0.5
|
||||
# expectile_weight: 0.9
|
||||
# value_coeff: 0.1
|
||||
# consistency_coeff: 20.0
|
||||
# advantage_scaling: 3.0
|
||||
# pi_coeff: 0.5
|
||||
# temporal_decay_coeff: 0.5
|
||||
# # Target model.
|
||||
# target_model_momentum: 0.995
|
||||
@@ -1,127 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
# Train with:
|
||||
#
|
||||
# python lerobot/scripts/train.py \
|
||||
# +dataset=lerobot/pusht_keypoints
|
||||
# env=pusht \
|
||||
# env.gym.obs_type=environment_state_agent_pos \
|
||||
|
||||
seed: 1
|
||||
dataset_repo_id: aractingi/push_cube_overfit_cropped_resized
|
||||
#aractingi/push_cube_square_offline_demo_cropped_resized
|
||||
|
||||
training:
|
||||
# Offline training dataloader
|
||||
num_workers: 4
|
||||
|
||||
# batch_size: 256
|
||||
batch_size: 512
|
||||
grad_clip_norm: 10.0
|
||||
lr: 3e-4
|
||||
|
||||
eval_freq: 2500
|
||||
log_freq: 1
|
||||
save_freq: 2000000
|
||||
|
||||
online_steps: 1000000
|
||||
online_rollout_n_episodes: 10
|
||||
online_rollout_batch_size: 10
|
||||
online_steps_between_rollouts: 1000
|
||||
online_sampling_ratio: 1.0
|
||||
online_env_seed: 10000
|
||||
online_buffer_capacity: 1000000
|
||||
online_buffer_seed_size: 0
|
||||
online_step_before_learning: 100 #5000
|
||||
do_online_rollout_async: false
|
||||
policy_update_freq: 1
|
||||
|
||||
# delta_timestamps:
|
||||
# observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
# observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
# action: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
# next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
|
||||
policy:
|
||||
name: sac
|
||||
|
||||
pretrained_model_path:
|
||||
|
||||
# Input / output structure.
|
||||
n_action_repeats: 1
|
||||
horizon: 1
|
||||
n_action_steps: 1
|
||||
|
||||
shared_encoder: true
|
||||
vision_encoder_name: "helper2424/resnet10"
|
||||
freeze_vision_encoder: true
|
||||
input_shapes:
|
||||
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.state: ["${env.state_dim}"]
|
||||
observation.images.front: [3, 128, 128]
|
||||
observation.images.side: [3, 128, 128]
|
||||
# observation.image: [3, 128, 128]
|
||||
output_shapes:
|
||||
action: [4] # ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.front: mean_std
|
||||
observation.images.side: mean_std
|
||||
observation.state: min_max
|
||||
input_normalization_params:
|
||||
observation.images.front:
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
observation.images.side:
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
observation.state:
|
||||
min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
|
||||
max: [ 7.215820e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]
|
||||
|
||||
# min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375,0.53691274]
|
||||
# max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266, 0.60402685]
|
||||
# min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
|
||||
# max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]
|
||||
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
output_normalization_params:
|
||||
# action:
|
||||
# min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
|
||||
# max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
action:
|
||||
min: [-149.23828125, -97.734375, -100.1953125, -73.740234375]
|
||||
max: [149.23828125, 97.734375, 100.1953125, 73.740234375]
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
image_encoder_hidden_dim: 32
|
||||
# discount: 0.99
|
||||
discount: 0.97
|
||||
temperature_init: 1.0
|
||||
num_critics: 2 #10
|
||||
camera_number: 2
|
||||
num_subsample_critics: null
|
||||
critic_lr: 3e-4
|
||||
actor_lr: 3e-4
|
||||
temperature_lr: 3e-4
|
||||
# critic_target_update_weight: 0.005
|
||||
critic_target_update_weight: 0.01
|
||||
utd_ratio: 2 # 10
|
||||
|
||||
actor_learner_config:
|
||||
actor_ip: "127.0.0.1"
|
||||
port: 50051
|
||||
|
||||
# # Loss coefficients.
|
||||
# reward_coeff: 0.5
|
||||
# expectile_weight: 0.9
|
||||
# value_coeff: 0.1
|
||||
# consistency_coeff: 20.0
|
||||
# advantage_scaling: 3.0
|
||||
# pi_coeff: 0.5
|
||||
# temporal_decay_coeff: 0.5
|
||||
# # Target model.
|
||||
# target_model_momentum: 0.995
|
||||
@@ -14,9 +14,6 @@ calibration_dir: .cache/calibration/so100
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: null
|
||||
joint_position_relative_bounds:
|
||||
max: [ 7.2158203e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]
|
||||
min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
|
||||
|
||||
leader_arms:
|
||||
main:
|
||||
@@ -34,7 +31,7 @@ leader_arms:
|
||||
follower_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.feetech.FeetechMotorsBus
|
||||
port: /dev/tty.usbmodem58760431631
|
||||
port: /dev/tty.usbmodem585A0080971
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "sts3215"]
|
||||
@@ -45,13 +42,13 @@ follower_arms:
|
||||
gripper: [6, "sts3215"]
|
||||
|
||||
cameras:
|
||||
front:
|
||||
laptop:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 0
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
side:
|
||||
phone:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 1
|
||||
fps: 30
|
||||
|
||||
@@ -206,8 +206,7 @@ def record(
|
||||
num_image_writer_threads_per_camera: int = 4,
|
||||
display_cameras: bool = True,
|
||||
play_sounds: bool = True,
|
||||
reset_follower: bool = False,
|
||||
record_delta_actions: bool = False,
|
||||
reset_follower: bool = False,
|
||||
resume: bool = False,
|
||||
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
|
||||
local_files_only: bool = False,
|
||||
@@ -219,12 +218,7 @@ def record(
|
||||
device = None
|
||||
use_amp = None
|
||||
extra_features = (
|
||||
{
|
||||
"next.reward": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
|
||||
}
|
||||
if assign_rewards
|
||||
else None
|
||||
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
|
||||
)
|
||||
|
||||
if single_task:
|
||||
@@ -275,7 +269,7 @@ def record(
|
||||
|
||||
if reset_follower:
|
||||
initial_position = robot.follower_arms["main"].read("Present_Position")
|
||||
|
||||
|
||||
# Execute a few seconds without recording to:
|
||||
# 1. teleoperate the robot to move it in starting position if no policy provided,
|
||||
# 2. give times to the robot devices to connect and start synchronizing,
|
||||
@@ -308,7 +302,6 @@ def record(
|
||||
device=device,
|
||||
use_amp=use_amp,
|
||||
fps=fps,
|
||||
record_delta_actions=record_delta_actions,
|
||||
)
|
||||
|
||||
# Execute a few seconds without recording to give time to manually reset the environment
|
||||
@@ -360,24 +353,21 @@ def replay(
|
||||
fps: int | None = None,
|
||||
play_sounds: bool = True,
|
||||
local_files_only: bool = False,
|
||||
replay_delta_actions: bool = False,
|
||||
):
|
||||
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
|
||||
# TODO(rcadene): Add option to record logs
|
||||
|
||||
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
log_say("Replaying episode", play_sounds, blocking=True)
|
||||
for idx in range(dataset.num_frames):
|
||||
current_joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action = actions[idx]["action"]
|
||||
if replay_delta_actions:
|
||||
action = action + current_joint_positions
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
@@ -544,12 +534,6 @@ if __name__ == "__main__":
|
||||
default=0,
|
||||
help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--record-delta-actions",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Enables the recording of delta actions instead of absolute actions.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--reset-follower",
|
||||
type=int,
|
||||
@@ -579,12 +563,6 @@ if __name__ == "__main__":
|
||||
default=0,
|
||||
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
|
||||
)
|
||||
parser_replay.add_argument(
|
||||
"--replay-delta-actions",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Enables the replay of delta actions instead of absolute actions.",
|
||||
)
|
||||
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -15,25 +15,34 @@
|
||||
# limitations under the License.
|
||||
"""Evaluate a policy by running rollouts on the real robot and computing metrics.
|
||||
|
||||
Usage examples: evaluate a checkpoint from the LeRobot training script for 10 episodes.
|
||||
This script supports performing human interventions during rollouts.
|
||||
Human interventions allow the user to take control of the robot from the policy
|
||||
and correct its behavior. It is specifically designed for reinforcement learning
|
||||
experiments and HIL-SERL (human-in-the-loop reinforcement learning) methods.
|
||||
|
||||
```
|
||||
python lerobot/scripts/eval_on_robot.py \
|
||||
-p outputs/train/model/checkpoints/005000/pretrained_model \
|
||||
eval.n_episodes=10
|
||||
```
|
||||
### How to Use
|
||||
|
||||
Test reward classifier with teleoperation (you need to press space to take over)
|
||||
To rollout a policy on the robot:
|
||||
```
|
||||
python lerobot/scripts/eval_on_robot.py \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--pretrained-policy-path-or-name path/to/pretrained_model \
|
||||
--policy-config path/to/policy/config.yaml \
|
||||
--display-cameras 1
|
||||
```
|
||||
|
||||
If you trained a reward classifier on your task, you can also evaluate it using this script.
|
||||
You can annotate the collection with a pre-trained reward classifier by running:
|
||||
|
||||
```
|
||||
python lerobot/scripts/eval_on_robot.py \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--pretrained-policy-path-or-name path/to/pretrained_model \
|
||||
--policy-config path/to/policy/config.yaml \
|
||||
--reward-classifier-pretrained-path outputs/classifier/checkpoints/best/pretrained_model \
|
||||
--reward-classifier-config-file lerobot/configs/policy/hilserl_classifier.yaml \
|
||||
--display-cameras 1
|
||||
```
|
||||
|
||||
**NOTE** (michel-aractingi): This script is incomplete and it is being prepared
|
||||
for running training on the real robot.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -46,7 +55,8 @@ import torch
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position, predict_action
|
||||
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
|
||||
from lerobot.common.utils.utils import (
|
||||
init_hydra_config,
|
||||
@@ -69,7 +79,6 @@ def get_classifier(pretrained_path, config_path):
|
||||
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
|
||||
model = Classifier(classifier_config)
|
||||
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
|
||||
model = model.to("mps")
|
||||
return model
|
||||
|
||||
|
||||
@@ -81,48 +90,45 @@ def rollout(
|
||||
control_time_s: float = 20,
|
||||
use_amp: bool = True,
|
||||
display_cameras: bool = False,
|
||||
device: str = "cpu"
|
||||
) -> dict:
|
||||
"""Run a batched policy rollout on the real robot.
|
||||
|
||||
The return dictionary contains:
|
||||
"robot": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation
|
||||
keys. NOTE the that this has an extra sequence element relative to the other keys in the
|
||||
dictionary. This is because an extra observation is included for after the environment is
|
||||
terminated or truncated.
|
||||
"action": A (batch, sequence, action_dim) tensor of actions applied based on the observations (not
|
||||
including the last observations).
|
||||
"reward": A (batch, sequence) tensor of rewards received for applying the actions.
|
||||
"success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
|
||||
environment termination/truncation).
|
||||
"done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
|
||||
the first True is followed by True's all the way till the end. This can be used for masking
|
||||
extraneous elements from the sequences above.
|
||||
This function executes a rollout using the provided policy and robot interface,
|
||||
simulating batched interactions for a fixed control duration.
|
||||
|
||||
The returned dictionary contains rollout statistics, which can be used for analysis and debugging.
|
||||
|
||||
Args:
|
||||
robot: The robot class that defines the interface with the real robot.
|
||||
policy: The policy. Must be a PyTorch nn module.
|
||||
"robot": The robot interface for interacting with the real robot hardware.
|
||||
"policy": The policy to execute. Must be a PyTorch `nn.Module` object.
|
||||
"reward_classifier": A module to classify rewards during the rollout.
|
||||
"fps": The control frequency at which the policy is executed.
|
||||
"control_time_s": The total control duration of the rollout in seconds.
|
||||
"use_amp": Whether to use automatic mixed precision (AMP) for policy evaluation.
|
||||
"display_cameras": If True, displays camera streams during the rollout.
|
||||
"device": The device to use for computations (e.g., "cpu", "cuda" or "mps").
|
||||
|
||||
Returns:
|
||||
The dictionary described above.
|
||||
Dictionary of the statistics collected during rollouts.
|
||||
"""
|
||||
# TODO (michel-aractingi): Infer the device from policy parameters when policy is added
|
||||
# assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
|
||||
# device = get_device_from_parameters(policy)
|
||||
|
||||
# define keyboard listener
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
# Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
|
||||
# policy.reset()
|
||||
if policy is not None:
|
||||
policy.reset()
|
||||
|
||||
# NOTE: sorting to make sure the key sequence is the same during training and testing.
|
||||
observation = robot.capture_observation()
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
image_keys.sort()
|
||||
image_keys.sort()
|
||||
|
||||
|
||||
all_actions = []
|
||||
all_rewards = []
|
||||
all_successes = []
|
||||
|
||||
indices_from_policy = []
|
||||
|
||||
start_episode_t = time.perf_counter()
|
||||
init_pos = robot.follower_arms["main"].read("Present_Position")
|
||||
@@ -141,27 +147,32 @@ def rollout(
|
||||
else:
|
||||
# explore with policy
|
||||
with torch.inference_mode():
|
||||
# TODO (michel-aractingi) replace this part with policy (predict_action)
|
||||
action = robot.follower_arms["main"].read("Present_Position")
|
||||
action = torch.from_numpy(action)
|
||||
# TODO (michel-aractingi) in place temporarily for testing purposes
|
||||
if policy is None:
|
||||
action = robot.follower_arms["main"].read("Present_Position")
|
||||
action = torch.from_numpy(action)
|
||||
indices_from_policy.append(False)
|
||||
else:
|
||||
action = predict_action(observation, policy, device, use_amp)
|
||||
indices_from_policy.append(True)
|
||||
|
||||
robot.send_action(action)
|
||||
# action = predict_action(observation, policy, device, use_amp)
|
||||
observation = robot.capture_observation()
|
||||
|
||||
|
||||
observation = robot.capture_observation()
|
||||
images = []
|
||||
for key in image_keys:
|
||||
if display_cameras:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
images.append(observation[key].to("mps"))
|
||||
images.append(observation[key].to(device))
|
||||
|
||||
reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
|
||||
|
||||
# TODO send data through the server as soon as you have it
|
||||
|
||||
all_rewards.append(reward)
|
||||
|
||||
# print("REWARD : ", reward)
|
||||
|
||||
all_actions.append(action)
|
||||
all_successes.append(torch.tensor([False]))
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
@@ -180,7 +191,6 @@ def rollout(
|
||||
ret = {
|
||||
"action": torch.stack(all_actions, dim=1),
|
||||
"next.reward": torch.stack(all_rewards, dim=1),
|
||||
"next.success": torch.stack(all_successes, dim=1),
|
||||
"done": dones,
|
||||
}
|
||||
|
||||
@@ -199,14 +209,32 @@ def eval_policy(
|
||||
display_cameras: bool = False,
|
||||
reward_classifier_pretrained_path: str | None = None,
|
||||
reward_classifier_config_file: str | None = None,
|
||||
device: str | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Evaluate a policy on a real robot by running multiple episodes and collecting metrics.
|
||||
|
||||
This function executes rollouts of the specified policy on the robot, computes metrics
|
||||
for the rollouts, and optionally evaluates a reward classifier if provided.
|
||||
|
||||
Args:
|
||||
env: The batch of environments.
|
||||
policy: The policy.
|
||||
n_episodes: The number of episodes to evaluate.
|
||||
"robot": The robot interface used to interact with the real robot hardware.
|
||||
"policy": The policy to be evaluated. Must be a PyTorch neural network module.
|
||||
"fps": Frames per second (control frequency) for running the policy.
|
||||
"n_episodes": The number of episodes to evaluate the policy.
|
||||
"control_time_s": The max duration for each episode in seconds.
|
||||
"use_amp": Whether to use automatic mixed precision (AMP) for policy evaluation.
|
||||
"display_cameras": Whether to display camera streams during rollouts.
|
||||
"reward_classifier_pretrained_path": Path to the pretrained reward classifier.
|
||||
If provided, the reward classifier will be evaluated during rollouts.
|
||||
"reward_classifier_config_file": Path to the configuration file for the reward classifier.
|
||||
Required if `reward_classifier_pretrained_path` is provided.
|
||||
"device": The device for computations (e.g., "cpu", "cuda" or "mps").
|
||||
|
||||
Returns:
|
||||
Dictionary with metrics and data regarding the rollouts.
|
||||
"dict": A dictionary containing the following rollout metrics and data:
|
||||
- "metrics": Evaluation metrics such as cumulative rewards, success rates, etc.
|
||||
- "rollout_data": Detailed data from the rollouts, including observations, actions, rewards, and done flags.
|
||||
"""
|
||||
# TODO (michel-aractingi) comment this out for testing with a fixed policy
|
||||
# assert isinstance(policy, Policy)
|
||||
@@ -214,22 +242,22 @@ def eval_policy(
|
||||
|
||||
sum_rewards = []
|
||||
max_rewards = []
|
||||
successes = []
|
||||
rollouts = []
|
||||
|
||||
start_eval = time.perf_counter()
|
||||
progbar = trange(n_episodes, desc="Evaluating policy on real robot")
|
||||
reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file)
|
||||
reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file).to(device)
|
||||
|
||||
device = get_device_from_parameters(policy) if device is None else device
|
||||
|
||||
for _ in progbar:
|
||||
rollout_data = rollout(
|
||||
robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras
|
||||
robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras, device
|
||||
)
|
||||
|
||||
rollouts.append(rollout_data)
|
||||
sum_rewards.append(sum(rollout_data["next.reward"]))
|
||||
max_rewards.append(max(rollout_data["next.reward"]))
|
||||
successes.append(rollout_data["next.success"][-1])
|
||||
|
||||
info = {
|
||||
"per_episode": [
|
||||
@@ -237,21 +265,18 @@ def eval_policy(
|
||||
"episode_ix": i,
|
||||
"sum_reward": sum_reward,
|
||||
"max_reward": max_reward,
|
||||
"pc_success": success * 100,
|
||||
}
|
||||
for i, (sum_reward, max_reward, success) in enumerate(
|
||||
for i, (sum_reward, max_reward) in enumerate(
|
||||
zip(
|
||||
sum_rewards[:n_episodes],
|
||||
max_rewards[:n_episodes],
|
||||
successes[:n_episodes],
|
||||
strict=False,
|
||||
)
|
||||
)
|
||||
],
|
||||
"aggregated": {
|
||||
"avg_sum_reward": float(np.nanmean(torch.cat(sum_rewards[:n_episodes]))),
|
||||
"avg_sum_reward": float(np.nanmean(torch.cat(sum_rewards[:n_episodes]))),
|
||||
"avg_max_reward": float(np.nanmean(torch.cat(max_rewards[:n_episodes]))),
|
||||
"pc_success": float(np.nanmean(torch.cat(successes[:n_episodes])) * 100),
|
||||
"eval_s": time.time() - start_eval,
|
||||
"eval_ep_s": (time.time() - start_eval) / n_episodes,
|
||||
},
|
||||
@@ -264,9 +289,18 @@ def eval_policy(
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
"""
|
||||
Initialize a keyboard listener for controlling the recording and human intervention process.
|
||||
|
||||
Keyboard controls: (Note that this might require sudo permissions to monitor keyboard events)
|
||||
- Right Arrow Key ('->'): Stops the current recording and exits early, useful for ending an episode
|
||||
and moving on to the next episode recording.
|
||||
- Left Arrow Key ('<-'): Re-records the current episode, allowing the user to start over.
|
||||
- Space Bar: Controls the human intervention process in three steps:
|
||||
1. First press pauses the policy and prompts the user to position the leader similarly to the follower.
|
||||
2. Second press initiates human interventions, allowing teleop control of the robot.
|
||||
3. Third press resumes the policy rollout.
|
||||
"""
|
||||
events = {}
|
||||
events["exit_early"] = False
|
||||
events["rerecord_episode"] = False
|
||||
@@ -302,10 +336,15 @@ def init_keyboard_listener():
|
||||
)
|
||||
events["pause_policy"] = True
|
||||
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
|
||||
else:
|
||||
elif events["pause_policy"] and not events["human_intervention_step"]:
|
||||
events["human_intervention_step"] = True
|
||||
print("Space key pressed. Human intervention starting.")
|
||||
log_say("Starting human intervention.", play_sounds=True)
|
||||
elif events["human_intervention_step"]:
|
||||
events["human_intervention_step"] = False
|
||||
events["pause_policy"] = False
|
||||
print("Space key pressed. Human intervention ending, policy resumes control.")
|
||||
log_say("Policy resuming.", play_sounds=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
@@ -1,375 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
import logging
|
||||
import pickle
|
||||
import queue
|
||||
import time
|
||||
from concurrent import futures
|
||||
from statistics import mean, quantiles
|
||||
|
||||
# from lerobot.scripts.eval import eval_policy
|
||||
from threading import Thread
|
||||
|
||||
import grpc
|
||||
import hydra
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
from torch import nn
|
||||
|
||||
# TODO: Remove the import of maniskill
|
||||
# from lerobot.common.envs.factory import make_maniskill_env
|
||||
# from lerobot.common.envs.utils import preprocess_maniskill_observation
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.robot_devices.control_utils import busy_wait
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.utils.utils import (
|
||||
TimerManager,
|
||||
get_safe_torch_device,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
|
||||
from lerobot.scripts.server.buffer import Transition, move_state_dict_to_device, move_transition_to_device
|
||||
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
parameters_queue = queue.Queue(maxsize=1)
|
||||
message_queue = queue.Queue(maxsize=1_000_000)
|
||||
|
||||
|
||||
class ActorInformation:
|
||||
"""
|
||||
This helper class is used to differentiate between two types of messages that are placed in the same queue during streaming:
|
||||
|
||||
- **Transition Data:** Contains experience tuples (observation, action, reward, next observation) collected during interaction.
|
||||
- **Interaction Messages:** Encapsulates statistics related to the interaction process.
|
||||
|
||||
Attributes:
|
||||
transition (Optional): Transition data to be sent to the learner.
|
||||
interaction_message (Optional): Interaction message providing additional statistics for logging.
|
||||
"""
|
||||
|
||||
def __init__(self, transition=None, interaction_message=None):
|
||||
self.transition = transition
|
||||
self.interaction_message = interaction_message
|
||||
|
||||
|
||||
class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
|
||||
"""
|
||||
gRPC service for actor-learner communication in reinforcement learning.
|
||||
|
||||
This service is responsible for:
|
||||
1. Streaming batches of transition data and statistical metrics from the actor to the learner.
|
||||
2. Receiving updated network parameters from the learner.
|
||||
"""
|
||||
|
||||
def StreamTransition(self, request, context): # noqa: N802
|
||||
"""
|
||||
Streams data from the actor to the learner.
|
||||
|
||||
This function continuously retrieves messages from the queue and processes them based on their type:
|
||||
|
||||
- **Transition Data:**
|
||||
- A batch of transitions (observation, action, reward, next observation) is collected.
|
||||
- Transitions are moved to the CPU and serialized using PyTorch.
|
||||
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
|
||||
|
||||
- **Interaction Messages:**
|
||||
- Contains useful statistics about episodic rewards and policy timings.
|
||||
- The message is serialized using `pickle` and sent to the learner.
|
||||
|
||||
Yields:
|
||||
hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
|
||||
"""
|
||||
while True:
|
||||
message = message_queue.get(block=True)
|
||||
|
||||
if message.transition is not None:
|
||||
transition_to_send_to_learner: list[Transition] = [
|
||||
move_transition_to_device(transition=T, device="cpu") for T in message.transition
|
||||
]
|
||||
# Check for NaNs in transitions before sending to learner
|
||||
for transition in transition_to_send_to_learner:
|
||||
for key, value in transition["state"].items():
|
||||
if torch.isnan(value).any():
|
||||
logging.warning(f"Found NaN values in transition {key}")
|
||||
buf = io.BytesIO()
|
||||
torch.save(transition_to_send_to_learner, buf)
|
||||
transition_bytes = buf.getvalue()
|
||||
|
||||
transition_message = hilserl_pb2.Transition(transition_bytes=transition_bytes)
|
||||
|
||||
response = hilserl_pb2.ActorInformation(transition=transition_message)
|
||||
|
||||
elif message.interaction_message is not None:
|
||||
content = hilserl_pb2.InteractionMessage(
|
||||
interaction_message_bytes=pickle.dumps(message.interaction_message)
|
||||
)
|
||||
response = hilserl_pb2.ActorInformation(interaction_message=content)
|
||||
|
||||
yield response
|
||||
|
||||
def SendParameters(self, request, context): # noqa: N802
|
||||
"""
|
||||
Receives updated parameters from the learner and updates the actor.
|
||||
|
||||
The learner calls this method to send new model parameters. The received parameters are deserialized
|
||||
and placed in a queue to be consumed by the actor.
|
||||
|
||||
Args:
|
||||
request (hilserl_pb2.ParameterUpdate): The request containing serialized network parameters.
|
||||
context (grpc.ServicerContext): The gRPC context.
|
||||
|
||||
Returns:
|
||||
hilserl_pb2.Empty: An empty response to acknowledge receipt.
|
||||
"""
|
||||
buffer = io.BytesIO(request.parameter_bytes)
|
||||
params = torch.load(buffer)
|
||||
parameters_queue.put(params)
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
|
||||
def serve_actor_service(port=50052):
|
||||
"""
|
||||
Runs a gRPC server to start streaming the data from the actor to the learner.
|
||||
Through this server, the learner can also push parameters to the actor.
|
||||
"""
|
||||
server = grpc.server(
|
||||
futures.ThreadPoolExecutor(max_workers=20),
|
||||
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
|
||||
)
|
||||
hilserl_pb2_grpc.add_ActorServiceServicer_to_server(ActorServiceServicer(), server)
|
||||
server.add_insecure_port(f"[::]:{port}")
|
||||
server.start()
|
||||
logging.info(f"[ACTOR] gRPC server listening on port {port}")
|
||||
server.wait_for_termination()
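A hedged usage sketch showing how this server is meant to be launched, mirroring the actor_cli entry point further below; the port value is illustrative:

```
from threading import Thread

# Run the gRPC actor service in the background so the acting loop can keep executing.
server_thread = Thread(target=serve_actor_service, args=(50051,), daemon=True)
server_thread.start()
```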
|
||||
|
||||
|
||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device):
|
||||
if not parameters_queue.empty():
|
||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
state_dict = parameters_queue.get()
|
||||
state_dict = move_state_dict_to_device(state_dict, device=device)
|
||||
policy.load_state_dict(state_dict, strict=False)
|
||||
|
||||
|
||||
def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module):
|
||||
"""
|
||||
Executes policy interaction within the environment.
|
||||
|
||||
This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
|
||||
Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.
|
||||
|
||||
Args:
|
||||
cfg (DictConfig): Configuration settings for the interaction process.
|
||||
"""
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg.env)
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("make_policy")
|
||||
|
||||
# HACK: This is an ugly hack to pass the normalization parameters to the policy
|
||||
# Because the action space is dynamic so we override the output normalization parameters
|
||||
# it's ugly, we know ... and we will fix it
|
||||
min_action_space: list = online_env.action_space.spaces[0].low.tolist()
|
||||
max_action_space: list = online_env.action_space.spaces[0].high.tolist()
|
||||
output_normalization_params: dict[dict[str, list]] = {
|
||||
"action": {"min": min_action_space, "max": max_action_space}
|
||||
}
|
||||
cfg.policy.output_normalization_params = output_normalization_params
|
||||
cfg.policy.output_shapes["action"] = online_env.action_space.spaces[0].shape
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy instance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
# TODO: At some point we should just need to make the SAC policy
|
||||
policy: SACPolicy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
# Hack: But if we do online training, we do not need dataset_stats
|
||||
dataset_stats=None,
|
||||
# TODO: Handle resume training
|
||||
device=device,
|
||||
)
|
||||
policy = torch.compile(policy)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
list_transition_to_send_to_learner = []
|
||||
list_policy_time = []
|
||||
episode_intervention = False
|
||||
|
||||
for interaction_step in range(cfg.training.online_steps):
|
||||
if interaction_step >= cfg.training.online_step_before_learning:
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with TimerManager(
|
||||
elapsed_time_list=list_policy_time, label="Policy inference time", log=False
|
||||
) as timer: # noqa: F841
|
||||
action = policy.select_action(batch=obs)
|
||||
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
|
||||
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
|
||||
else:
|
||||
# TODO (azouitine): Make a custom space for torch tensor
|
||||
action = online_env.action_space.sample()
|
||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||
|
||||
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
|
||||
action = (
|
||||
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
|
||||
)
|
||||
|
||||
sum_reward_episode += float(reward)
|
||||
|
||||
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
|
||||
if "is_intervention" in info and info["is_intervention"]:
|
||||
# TODO: Check the shape
|
||||
# NOTE: The action space used for the demonstrations recorded beforehand is the full action space,
|
||||
# but sometimes for example we want to deactivate the gripper
|
||||
action = info["action_intervention"]
|
||||
episode_intervention = True
|
||||
|
||||
# Check for NaN values in observations
|
||||
for key, tensor in obs.items():
|
||||
if torch.isnan(tensor).any():
|
||||
logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
|
||||
|
||||
list_transition_to_send_to_learner.append(
|
||||
Transition(
|
||||
state=obs,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_obs,
|
||||
done=done,
|
||||
complementary_info=info, # TODO Handle information for the transition, is_demonstration: bool
|
||||
)
|
||||
)
|
||||
|
||||
# assign obs to the next obs and continue the rollout
|
||||
obs = next_obs
|
||||
|
||||
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
|
||||
# Because we are using a single environment we can index at zero
|
||||
if done or truncated:
|
||||
# TODO: Handle logging for episode information
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
|
||||
update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
|
||||
|
||||
if len(list_transition_to_send_to_learner) > 0:
|
||||
send_transitions_in_chunks(
|
||||
transitions=list_transition_to_send_to_learner, message_queue=message_queue, chunk_size=4
|
||||
)
|
||||
list_transition_to_send_to_learner = []
|
||||
|
||||
stats = get_frequency_stats(list_policy_time)
|
||||
list_policy_time.clear()
|
||||
|
||||
# Send episodic reward to the learner
|
||||
message_queue.put(
|
||||
ActorInformation(
|
||||
interaction_message={
|
||||
"Episodic reward": sum_reward_episode,
|
||||
"Interaction step": interaction_step,
|
||||
"Episode intervention": int(episode_intervention),
|
||||
**stats,
|
||||
}
|
||||
)
|
||||
)
|
||||
sum_reward_episode = 0.0
|
||||
episode_intervention = False
|
||||
obs, info = online_env.reset()
|
||||
|
||||
|
||||
def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int = 100):
|
||||
"""Send transitions to learner in smaller chunks to avoid network issues.
|
||||
|
||||
Args:
|
||||
transitions: List of transitions to send
|
||||
message_queue: Queue to send messages to learner
|
||||
chunk_size: Size of each chunk to send
|
||||
"""
|
||||
for i in range(0, len(transitions), chunk_size):
|
||||
chunk = transitions[i : i + chunk_size]
|
||||
logging.debug(f"[ACTOR] Sending chunk of {len(chunk)} transitions to Learner.")
|
||||
message_queue.put(ActorInformation(transition=chunk))
|
||||
|
||||
|
||||
def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
|
||||
stats = {}
|
||||
list_policy_fps = [1.0 / t for t in list_policy_time]
|
||||
if len(list_policy_fps) > 1:
|
||||
policy_fps = mean(list_policy_fps)
|
||||
quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
|
||||
logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
|
||||
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
|
||||
stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90}
|
||||
return stats
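For illustration, a small sanity check of get_frequency_stats; the timing values are invented for the example:

```
sample_times = [0.020, 0.025, 0.040]  # seconds per policy call
stats = get_frequency_stats(sample_times)
# stats holds the mean frequency and the 90th-percentile frequency in Hz,
# e.g. {"Policy frequency [Hz]": ..., "Policy frequency 90th-p [Hz]": ...}
print(stats)
```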
|
||||
|
||||
|
||||
def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
|
||||
if policy_fps < cfg.fps:
|
||||
logging.warning(
|
||||
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
|
||||
)
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
|
||||
def actor_cli(cfg: dict):
|
||||
robot = make_robot(cfg=cfg.robot)
|
||||
|
||||
server_thread = Thread(target=serve_actor_service, args=(cfg.actor_learner_config.port,), daemon=True)
|
||||
|
||||
# HACK: FOR MANISKILL we do not have a reward classifier
|
||||
# TODO: Remove this once we merge into main
|
||||
reward_classifier = None
|
||||
if (
|
||||
cfg.env.reward_classifier.pretrained_path is not None
|
||||
and cfg.env.reward_classifier.config_path is not None
|
||||
):
|
||||
reward_classifier = get_classifier(
|
||||
pretrained_path=cfg.env.reward_classifier.pretrained_path,
|
||||
config_path=cfg.env.reward_classifier.config_path,
|
||||
)
|
||||
policy_thread = Thread(
|
||||
target=act_with_policy,
|
||||
daemon=True,
|
||||
args=(cfg, robot, reward_classifier),
|
||||
)
|
||||
server_thread.start()
|
||||
policy_thread.start()
|
||||
policy_thread.join()
|
||||
server_thread.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
actor_cli()
|
||||
@@ -1,593 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import functools
|
||||
import random
|
||||
from typing import Any, Callable, Optional, Sequence, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
class Transition(TypedDict):
|
||||
state: dict[str, torch.Tensor]
|
||||
action: torch.Tensor
|
||||
reward: float
|
||||
next_state: dict[str, torch.Tensor]
|
||||
done: bool
|
||||
complementary_info: dict[str, Any] = None
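A minimal sketch of how a Transition is built; the shapes and observation keys below are placeholders:

```
t = Transition(
    state={"observation.state": torch.zeros(6)},
    action=torch.zeros(4),
    reward=0.0,
    next_state={"observation.state": torch.zeros(6)},
    done=False,
    complementary_info=None,
)
```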
|
||||
|
||||
|
||||
class BatchTransition(TypedDict):
|
||||
state: dict[str, torch.Tensor]
|
||||
action: torch.Tensor
|
||||
reward: torch.Tensor
|
||||
next_state: dict[str, torch.Tensor]
|
||||
done: torch.Tensor
|
||||
|
||||
|
||||
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
|
||||
# Move state tensors to the target device
|
||||
device = torch.device(device)
|
||||
transition["state"] = {
|
||||
key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
|
||||
}
|
||||
|
||||
# Move action to the target device
|
||||
transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")
|
||||
|
||||
# No need to move reward or done, as they are float and bool
|
||||
|
||||
# No need to move reward or done, as they are float and bool
|
||||
if isinstance(transition["reward"], torch.Tensor):
|
||||
transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")
|
||||
|
||||
if isinstance(transition["done"], torch.Tensor):
|
||||
transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")
|
||||
|
||||
# Move next_state tensors to the target device
|
||||
transition["next_state"] = {
|
||||
key: val.to(device, non_blocking=device.type == "cuda")
|
||||
for key, val in transition["next_state"].items()
|
||||
}
|
||||
|
||||
# If complementary_info is present, move its tensors to CPU
|
||||
# if transition["complementary_info"] is not None:
|
||||
# transition["complementary_info"] = {
|
||||
# key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
|
||||
# }
|
||||
return transition
|
||||
|
||||
|
||||
def move_state_dict_to_device(state_dict, device):
|
||||
"""
|
||||
Recursively move all tensors in a (potentially) nested
|
||||
dict/list/tuple structure to the given device.
|
||||
"""
|
||||
if isinstance(state_dict, torch.Tensor):
|
||||
return state_dict.to(device)
|
||||
elif isinstance(state_dict, dict):
|
||||
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
|
||||
elif isinstance(state_dict, list):
|
||||
return [move_state_dict_to_device(v, device=device) for v in state_dict]
|
||||
elif isinstance(state_dict, tuple):
|
||||
return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
|
||||
else:
|
||||
return state_dict
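Illustrative only: nested containers are traversed recursively, so a mixed structure like the one below is moved in a single call (the keys are made up):

```
nested = {"actor": {"weight": torch.zeros(2, 2)}, "history": [torch.ones(3)]}
nested_cpu = move_state_dict_to_device(nested, device="cpu")
```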
|
||||
|
||||
|
||||
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
|
||||
"""
|
||||
Perform a per-image random crop over a batch of images in a vectorized way.
|
||||
(Same as shown previously.)
|
||||
"""
|
||||
B, C, H, W = images.shape # noqa: N806
|
||||
crop_h, crop_w = output_size
|
||||
|
||||
if crop_h > H or crop_w > W:
|
||||
raise ValueError(
|
||||
f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
|
||||
)
|
||||
|
||||
tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
|
||||
lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)
|
||||
|
||||
rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
|
||||
cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)
|
||||
|
||||
rows = rows.unsqueeze(2).expand(-1, -1, crop_w) # (B, crop_h, crop_w)
|
||||
cols = cols.unsqueeze(1).expand(-1, crop_h, -1) # (B, crop_h, crop_w)
|
||||
|
||||
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
|
||||
|
||||
# Gather pixels
|
||||
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
|
||||
# cropped_hwcn => (B, crop_h, crop_w, C)
|
||||
|
||||
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
|
||||
return cropped
|
||||
|
||||
|
||||
def random_shift(images: torch.Tensor, pad: int = 4):
|
||||
"""Vectorized random shift, imgs: (B,C,H,W), pad: #pixels"""
|
||||
_, _, h, w = images.shape
|
||||
images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
|
||||
return random_crop_vectorized(images=images, output_size=(h, w))
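A quick sketch of the DrQ-style augmentation above applied to a dummy batch; the sizes are arbitrary:

```
images = torch.rand(8, 3, 64, 64)      # (B, C, H, W)
shifted = random_shift(images, pad=4)  # replicate-pad by 4 px, then randomly re-crop to the original size
assert shifted.shape == images.shape
```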
|
||||
|
||||
|
||||
class ReplayBuffer:
|
||||
def __init__(
|
||||
self,
|
||||
capacity: int,
|
||||
device: str = "cuda:0",
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
image_augmentation_function: Optional[Callable] = None,
|
||||
use_drq: bool = True,
|
||||
storage_device: str = "cpu",
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
capacity (int): Maximum number of transitions to store in the buffer.
|
||||
device (str): The device where the tensors will be moved ("cuda:0" or "cpu").
|
||||
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
|
||||
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
|
||||
and returns a batch of augmented images. If None, a default augmentation function is used.
|
||||
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
|
||||
storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored when adding transitions to the buffer.
|
||||
Using "cpu" can help save GPU memory.
|
||||
"""
|
||||
self.capacity = capacity
|
||||
self.device = device
|
||||
self.storage_device = storage_device
|
||||
self.memory: list[Transition] = []
|
||||
self.position = 0
|
||||
|
||||
# If no state_keys provided, default to an empty list
|
||||
# (you can handle this differently if needed)
|
||||
self.state_keys = state_keys if state_keys is not None else []
|
||||
if image_augmentation_function is None:
|
||||
self.image_augmentation_function = functools.partial(random_shift, pad=4)
|
||||
self.use_drq = use_drq
|
||||
|
||||
def __len__(self):
|
||||
return len(self.memory)
|
||||
|
||||
def add(
|
||||
self,
|
||||
state: dict[str, torch.Tensor],
|
||||
action: torch.Tensor,
|
||||
reward: float,
|
||||
next_state: dict[str, torch.Tensor],
|
||||
done: bool,
|
||||
complementary_info: Optional[dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
|
||||
# Move tensors to the storage device
|
||||
state = {key: tensor.to(self.storage_device) for key, tensor in state.items()}
|
||||
next_state = {key: tensor.to(self.storage_device) for key, tensor in next_state.items()}
|
||||
action = action.to(self.storage_device)
|
||||
# if complementary_info is not None:
|
||||
# complementary_info = {
|
||||
# key: tensor.to(self.storage_device) for key, tensor in complementary_info.items()
|
||||
# }
|
||||
|
||||
if len(self.memory) < self.capacity:
|
||||
self.memory.append(None)
|
||||
|
||||
# Create and store the Transition
|
||||
self.memory[self.position] = Transition(
|
||||
state=state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=done,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
self.position = (self.position + 1) % self.capacity
|
||||
|
||||
# TODO: ADD image_augmentation and use_drq arguments in this function in order to instantiate the class with them
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
cls,
|
||||
lerobot_dataset: LeRobotDataset,
|
||||
device: str = "cuda:0",
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
capacity: Optional[int] = None,
|
||||
action_mask: Optional[Sequence[int]] = None,
|
||||
action_delta: Optional[float] = None,
|
||||
) -> "ReplayBuffer":
|
||||
"""
|
||||
Convert a LeRobotDataset into a ReplayBuffer.
|
||||
|
||||
Args:
|
||||
lerobot_dataset (LeRobotDataset): The dataset to convert.
|
||||
device (str): The device . Defaults to "cuda:0".
|
||||
state_keys (Optional[Sequence[str]], optional): The list of keys that appear in `state` and `next_state`.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
ReplayBuffer: The replay buffer with offline dataset transitions.
|
||||
"""
|
||||
# We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
|
||||
# a replay buffer than from a lerobot dataset.
|
||||
if capacity is None:
|
||||
capacity = len(lerobot_dataset)
|
||||
|
||||
if capacity < len(lerobot_dataset):
|
||||
raise ValueError(
|
||||
"The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
|
||||
)
|
||||
|
||||
replay_buffer = cls(capacity=capacity, device=device, state_keys=state_keys)
|
||||
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
|
||||
# Fill the replay buffer with the lerobot dataset transitions
|
||||
for data in list_transition:
|
||||
for k, v in data.items():
|
||||
if isinstance(v, dict):
|
||||
for key, tensor in v.items():
|
||||
v[key] = tensor.to(device)
|
||||
elif isinstance(v, torch.Tensor):
|
||||
data[k] = v.to(device)
|
||||
|
||||
if action_mask is not None:
|
||||
if data["action"].dim() == 1:
|
||||
data["action"] = data["action"][action_mask]
|
||||
else:
|
||||
data["action"] = data["action"][:, action_mask]
|
||||
|
||||
if action_delta is not None:
|
||||
data["action"] = data["action"] / action_delta
|
||||
|
||||
replay_buffer.add(
|
||||
state=data["state"],
|
||||
action=data["action"],
|
||||
reward=data["reward"],
|
||||
next_state=data["next_state"],
|
||||
done=data["done"],
|
||||
)
|
||||
return replay_buffer
|
||||
|
||||
@staticmethod
|
||||
def _lerobotdataset_to_transitions(
|
||||
dataset: LeRobotDataset,
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
) -> list[Transition]:
|
||||
"""
|
||||
Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.
|
||||
|
||||
Args:
|
||||
dataset (LeRobotDataset):
|
||||
The dataset to convert. Each item in the dataset is expected to have
|
||||
at least the following keys:
|
||||
{
|
||||
"action": ...
|
||||
"next.reward": ...
|
||||
"next.done": ...
|
||||
"episode_index": ...
|
||||
}
|
||||
plus whatever your 'state_keys' specify.
|
||||
|
||||
state_keys (Optional[Sequence[str]]):
|
||||
The dataset keys to include in 'state' and 'next_state'. Their names
|
||||
will be kept as-is in the output transitions. E.g.
|
||||
["observation.state", "observation.environment_state"].
|
||||
If None, you must handle or define default keys.
|
||||
|
||||
Returns:
|
||||
transitions (List[Transition]):
|
||||
A list of Transition dictionaries with the same length as `dataset`.
|
||||
"""
|
||||
|
||||
# If not provided, you can either raise an error or define a default:
|
||||
if state_keys is None:
|
||||
raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
|
||||
|
||||
transitions: list[Transition] = []
|
||||
num_frames = len(dataset)
|
||||
|
||||
for i in tqdm(range(num_frames)):
|
||||
current_sample = dataset[i]
|
||||
|
||||
# ----- 1) Current state -----
|
||||
current_state: dict[str, torch.Tensor] = {}
|
||||
for key in state_keys:
|
||||
val = current_sample[key]
|
||||
current_state[key] = val.unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 2) Action -----
|
||||
action = current_sample["action"].unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 3) Reward and done -----
|
||||
reward = float(current_sample["next.reward"].item()) # ensure float
|
||||
done = bool(current_sample["next.done"].item()) # ensure bool
|
||||
|
||||
# ----- 4) Next state -----
|
||||
# If not done and the next sample is in the same episode, we pull the next sample's state.
|
||||
# Otherwise (done=True or next sample crosses to a new episode), next_state = current_state.
|
||||
next_state = current_state # default
|
||||
if not done and (i < num_frames - 1):
|
||||
next_sample = dataset[i + 1]
|
||||
if next_sample["episode_index"] == current_sample["episode_index"]:
|
||||
# Build next_state from the same keys
|
||||
next_state_data: dict[str, torch.Tensor] = {}
|
||||
for key in state_keys:
|
||||
val = next_sample[key]
|
||||
next_state_data[key] = val.unsqueeze(0) # Add batch dimension
|
||||
next_state = next_state_data
|
||||
|
||||
# ----- Construct the Transition -----
|
||||
transition = Transition(
|
||||
state=current_state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=done,
|
||||
)
|
||||
transitions.append(transition)
|
||||
|
||||
return transitions
|
||||
|
||||
def sample(self, batch_size: int) -> BatchTransition:
|
||||
"""Sample a random batch of transitions and collate them into batched tensors."""
|
||||
batch_size = min(batch_size, len(self.memory))
|
||||
list_of_transitions = random.sample(self.memory, batch_size)
|
||||
|
||||
# -- Build batched states --
|
||||
batch_state = {}
|
||||
for key in self.state_keys:
|
||||
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
|
||||
self.device
|
||||
)
|
||||
if key.startswith("observation.image") and self.use_drq:
|
||||
batch_state[key] = self.image_augmentation_function(batch_state[key])
|
||||
|
||||
# -- Build batched actions --
|
||||
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
|
||||
|
||||
# -- Build batched rewards --
|
||||
batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
|
||||
# -- Build batched next states --
|
||||
batch_next_state = {}
|
||||
for key in self.state_keys:
|
||||
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
|
||||
self.device
|
||||
)
|
||||
if key.startswith("observation.image") and self.use_drq:
|
||||
batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
|
||||
|
||||
# -- Build batched dones --
|
||||
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
|
||||
# Return a BatchTransition typed dict
|
||||
return BatchTransition(
|
||||
state=batch_state,
|
||||
action=batch_actions,
|
||||
reward=batch_rewards,
|
||||
next_state=batch_next_state,
|
||||
done=batch_dones,
|
||||
)
|
||||
|
||||
def to_lerobot_dataset(
|
||||
self,
|
||||
repo_id: str,
|
||||
fps=1, # If you have real timestamps, adjust this
|
||||
root=None,
|
||||
task_name="from_replay_buffer",
|
||||
) -> LeRobotDataset:
|
||||
"""
|
||||
Converts all transitions in this ReplayBuffer into a single LeRobotDataset object,
|
||||
splitting episodes by transitions where 'done=True'.
|
||||
|
||||
Returns:
|
||||
LeRobotDataset: The resulting offline dataset.
|
||||
"""
|
||||
if len(self.memory) == 0:
|
||||
raise ValueError("The replay buffer is empty. Cannot convert to a dataset.")
|
||||
|
||||
# Infer the shapes and dtypes of your features
|
||||
# We'll create a features dict that is suitable for LeRobotDataset
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# First, grab one transition to inspect shapes
|
||||
first_transition = self.memory[0]
|
||||
|
||||
# We'll store default metadata for every episode: indexes, timestamps, etc.
|
||||
features = {
|
||||
"index": {"dtype": "int64", "shape": [1]}, # global index across episodes
|
||||
"episode_index": {"dtype": "int64", "shape": [1]}, # which episode
|
||||
"frame_index": {"dtype": "int64", "shape": [1]}, # index inside an episode
|
||||
"timestamp": {"dtype": "float32", "shape": [1]}, # for now we store dummy
|
||||
"task_index": {"dtype": "int64", "shape": [1]},
|
||||
}
|
||||
|
||||
# Add "action"
|
||||
act_info = guess_feature_info(
|
||||
first_transition["action"].squeeze(dim=0), "action"
|
||||
) # Remove batch dimension
|
||||
features["action"] = act_info
|
||||
|
||||
# Add "reward" (scalars)
|
||||
features["next.reward"] = {"dtype": "float32", "shape": (1,)}
|
||||
|
||||
# Add "done" (boolean scalars)
|
||||
features["next.done"] = {"dtype": "bool", "shape": (1,)}
|
||||
|
||||
# Add state keys
|
||||
for key in self.state_keys:
|
||||
sample_val = first_transition["state"][key].squeeze(dim=0) # Remove batch dimension
|
||||
if not isinstance(sample_val, torch.Tensor):
|
||||
raise ValueError(
|
||||
f"State key '{key}' is not a torch.Tensor. Please ensure your states are stored as torch.Tensors."
|
||||
)
|
||||
f_info = guess_feature_info(sample_val, key)
|
||||
features[key] = f_info
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Create an empty LeRobotDataset
|
||||
# We'll store all frames as separate images only if we detect shape = (3, H, W) or (1, H, W).
|
||||
# By default we won't do videos, but feel free to adapt if you have them.
|
||||
# --------------------------------------------------------------------------------------------
|
||||
lerobot_dataset = LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=fps, # If you have real timestamps, adjust this
|
||||
root=root, # Or some local path where you'd like the dataset files to go
|
||||
robot=None,
|
||||
robot_type=None,
|
||||
features=features,
|
||||
use_videos=True, # We won't do actual video encoding for a replay buffer
|
||||
)
|
||||
|
||||
# Start writing images if needed. If you have no image features, this is harmless.
|
||||
# Set num_processes or num_threads if you want concurrency.
|
||||
lerobot_dataset.start_image_writer(num_processes=0, num_threads=2)
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Convert transitions into episodes and frames
|
||||
# We detect episode boundaries by `done == True`.
|
||||
# --------------------------------------------------------------------------------------------
|
||||
episode_index = 0
|
||||
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index)
|
||||
|
||||
frame_idx_in_episode = 0
|
||||
for global_frame_idx, transition in enumerate(self.memory):
|
||||
frame_dict = {}
|
||||
|
||||
# Fill the data for state keys
|
||||
for key in self.state_keys:
|
||||
# Expand dimension to match what the dataset expects (the dataset wants the raw shape)
|
||||
# We assume your buffer has shape [C, H, W] (if image) or [D] if vector
|
||||
# This is typically already correct, but if needed you can reshape below.
|
||||
frame_dict[key] = transition["state"][key].cpu().squeeze(dim=0) # Remove batch dimension
|
||||
|
||||
# Fill action, reward, done
|
||||
# Make sure they are shape (X,) or (X,Y,...) as needed.
|
||||
frame_dict["action"] = transition["action"].cpu().squeeze(dim=0) # Remove batch dimension
|
||||
frame_dict["next.reward"] = (
|
||||
torch.tensor([transition["reward"]], dtype=torch.float32).cpu().squeeze(dim=0)
|
||||
)
|
||||
frame_dict["next.done"] = (
|
||||
torch.tensor([transition["done"]], dtype=torch.bool).cpu().squeeze(dim=0)
|
||||
)
|
||||
# Add to the dataset's buffer
|
||||
lerobot_dataset.add_frame(frame_dict)
|
||||
|
||||
# Move to next frame
|
||||
frame_idx_in_episode += 1
|
||||
# If we reached an episode boundary, call save_episode, reset counters
|
||||
if transition["done"]:
|
||||
# Use some placeholder name for the task
|
||||
lerobot_dataset.save_episode(task="from_replay_buffer")
|
||||
episode_index += 1
|
||||
frame_idx_in_episode = 0
|
||||
# Start a new buffer for the next episode
|
||||
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index)
|
||||
|
||||
# We are done adding frames
|
||||
# If the last transition wasn't done=True, we still have an open buffer with frames.
|
||||
# We'll consider that an incomplete episode and still save it:
|
||||
if lerobot_dataset.episode_buffer["size"] > 0:
|
||||
lerobot_dataset.save_episode(task=task_name)
|
||||
|
||||
lerobot_dataset.stop_image_writer()
|
||||
|
||||
lerobot_dataset.consolidate(run_compute_stats=False, keep_image_files=False)
|
||||
|
||||
return lerobot_dataset
|
||||
|
||||
|
||||
# Utility function to guess shapes/dtypes from a tensor
|
||||
def guess_feature_info(t: torch.Tensor, name: str):
|
||||
"""
|
||||
Return a dictionary with the 'dtype' and 'shape' for a given tensor or array.
|
||||
If it looks like a 3D (C,H,W) shape, we might consider it an 'image'.
|
||||
Otherwise default to 'float32' for numeric. You can customize as needed.
|
||||
"""
|
||||
shape = tuple(t.shape)
|
||||
# Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image'
|
||||
if len(shape) == 3 and shape[0] in [1, 3]:
|
||||
return {
|
||||
"dtype": "image",
|
||||
"shape": shape,
|
||||
}
|
||||
else:
|
||||
# Otherwise treat as numeric
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": shape,
|
||||
}
|
||||
|
||||
|
||||
def concatenate_batch_transitions(
|
||||
left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
|
||||
) -> BatchTransition:
|
||||
"""NOTE: Be careful it change the left_batch_transitions in place"""
|
||||
left_batch_transitions["state"] = {
|
||||
key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
|
||||
for key in left_batch_transitions["state"]
|
||||
}
|
||||
left_batch_transitions["action"] = torch.cat(
|
||||
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0
|
||||
)
|
||||
left_batch_transitions["reward"] = torch.cat(
|
||||
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
|
||||
)
|
||||
left_batch_transitions["next_state"] = {
|
||||
key: torch.cat(
|
||||
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
|
||||
)
|
||||
for key in left_batch_transitions["next_state"]
|
||||
}
|
||||
left_batch_transitions["done"] = torch.cat(
|
||||
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0
|
||||
)
|
||||
return left_batch_transitions
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# dataset_name = "aractingi/push_green_cube_hf_cropped_resized"
|
||||
# dataset = LeRobotDataset(repo_id=dataset_name)
|
||||
|
||||
# replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
# lerobot_dataset=dataset, state_keys=["observation.image", "observation.state"]
|
||||
# )
|
||||
# replay_buffer_converted = replay_buffer.to_lerobot_dataset(repo_id="AdilZtn/pusht_image_converted")
|
||||
# for i in range(len(replay_buffer_converted)):
|
||||
# replay_convert = replay_buffer_converted[i]
|
||||
# dataset_convert = dataset[i]
|
||||
# for key in replay_convert.keys():
|
||||
# if key in {"index", "episode_index", "frame_index", "timestamp", "task_index"}:
|
||||
# continue
|
||||
# if key in dataset_convert.keys():
|
||||
# assert torch.equal(replay_convert[key], dataset_convert[key])
|
||||
# print(f"Key {key} is equal : {replay_convert[key].size()}, {dataset_convert[key].size()}")
|
||||
# re_reconverted_dataset = ReplayBuffer.from_lerobot_dataset(
|
||||
# replay_buffer_converted, state_keys=["observation.image", "observation.state"], device="cpu"
|
||||
# )
|
||||
# for _ in range(20):
|
||||
# batch = re_reconverted_dataset.sample(32)
|
||||
|
||||
# for key in batch.keys():
|
||||
# if key in {"state", "next_state"}:
|
||||
# for key_state in batch[key].keys():
|
||||
# print(key_state, batch[key][key_state].size())
|
||||
# continue
|
||||
# print(key, batch[key].size())
|
||||
@@ -1,287 +0,0 @@
|
||||
import argparse # noqa: I001
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Tuple
|
||||
from pathlib import Path
|
||||
import cv2
|
||||
|
||||
# import torch.nn.functional as F # noqa: N812
|
||||
import torchvision.transforms.functional as F # type: ignore # noqa: N812
|
||||
from tqdm import tqdm # type: ignore
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def select_rect_roi(img):
|
||||
"""
|
||||
Allows the user to draw a rectangular ROI on the image.
|
||||
|
||||
The user must click and drag to draw the rectangle.
|
||||
- While dragging, the rectangle is dynamically drawn.
|
||||
- On mouse button release, the rectangle is fixed.
|
||||
- Press 'c' to confirm the selection.
|
||||
- Press 'r' to reset the selection.
|
||||
- Press ESC to cancel.
|
||||
|
||||
Returns:
|
||||
A tuple (top, left, height, width) representing the rectangular ROI,
|
||||
or None if no valid ROI is selected.
|
||||
"""
|
||||
# Create a working copy of the image
|
||||
clone = img.copy()
|
||||
working_img = clone.copy()
|
||||
|
||||
roi = None # Will store the final ROI as (top, left, height, width)
|
||||
drawing = False
|
||||
ix, iy = -1, -1 # Initial click coordinates
|
||||
|
||||
def mouse_callback(event, x, y, flags, param):
|
||||
nonlocal ix, iy, drawing, roi, working_img
|
||||
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
# Start drawing: record starting coordinates
|
||||
drawing = True
|
||||
ix, iy = x, y
|
||||
|
||||
elif event == cv2.EVENT_MOUSEMOVE:
|
||||
if drawing:
|
||||
# Compute the top-left and bottom-right corners regardless of drag direction
|
||||
top = min(iy, y)
|
||||
left = min(ix, x)
|
||||
bottom = max(iy, y)
|
||||
right = max(ix, x)
|
||||
# Show a temporary image with the current rectangle drawn
|
||||
temp = working_img.copy()
|
||||
cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2)
|
||||
cv2.imshow("Select ROI", temp)
|
||||
|
||||
elif event == cv2.EVENT_LBUTTONUP:
|
||||
# Finish drawing
|
||||
drawing = False
|
||||
top = min(iy, y)
|
||||
left = min(ix, x)
|
||||
bottom = max(iy, y)
|
||||
right = max(ix, x)
|
||||
height = bottom - top
|
||||
width = right - left
|
||||
roi = (top, left, height, width) # (top, left, height, width)
|
||||
# Draw the final rectangle on the working image and display it
|
||||
working_img = clone.copy()
|
||||
cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2)
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
|
||||
# Create the window and set the callback
|
||||
cv2.namedWindow("Select ROI")
|
||||
cv2.setMouseCallback("Select ROI", mouse_callback)
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
|
||||
print("Instructions for ROI selection:")
|
||||
print(" - Click and drag to draw a rectangular ROI.")
|
||||
print(" - Press 'c' to confirm the selection.")
|
||||
print(" - Press 'r' to reset and draw again.")
|
||||
print(" - Press ESC to cancel the selection.")
|
||||
|
||||
# Wait until the user confirms with 'c', resets with 'r', or cancels with ESC
|
||||
while True:
|
||||
key = cv2.waitKey(1) & 0xFF
|
||||
# Confirm ROI if one has been drawn
|
||||
if key == ord("c") and roi is not None:
|
||||
break
|
||||
# Reset: clear the ROI and restore the original image
|
||||
elif key == ord("r"):
|
||||
working_img = clone.copy()
|
||||
roi = None
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
# Cancel selection for this image
|
||||
elif key == 27: # ESC key
|
||||
roi = None
|
||||
break
|
||||
|
||||
cv2.destroyWindow("Select ROI")
|
||||
return roi
|
||||
|
||||
|
||||
def select_square_roi_for_images(images: dict) -> dict:
|
||||
"""
|
||||
For each image in the provided dictionary, open a window to allow the user
|
||||
to select a rectangular ROI. Returns a dictionary mapping each key to a tuple
|
||||
(top, left, height, width) representing the ROI.
|
||||
|
||||
Parameters:
|
||||
images (dict): Dictionary where keys are identifiers and values are OpenCV images.
|
||||
|
||||
Returns:
|
||||
dict: Mapping of image keys to the selected rectangular ROI.
|
||||
"""
|
||||
selected_rois = {}
|
||||
|
||||
for key, img in images.items():
|
||||
if img is None:
|
||||
print(f"Image for key '{key}' is None, skipping.")
|
||||
continue
|
||||
|
||||
print(f"\nSelect rectangular ROI for image with key: '{key}'")
|
||||
roi = select_rect_roi(img)
|
||||
|
||||
if roi is None:
|
||||
print(f"No valid ROI selected for '{key}'.")
|
||||
else:
|
||||
selected_rois[key] = roi
|
||||
print(f"ROI for '{key}': {roi}")
|
||||
|
||||
return selected_rois
|
||||
|
||||
|
||||
def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
|
||||
"""
|
||||
Find the first row in the dataset and extract the image in order to be used for the crop.
|
||||
"""
|
||||
row = dataset[0]
|
||||
image_dict = {}
|
||||
for k in row:
|
||||
if "image" in k:
|
||||
image_dict[k] = deepcopy(row[k])
|
||||
return image_dict
|
||||
|
||||
|
||||
def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
original_dataset: LeRobotDataset,
|
||||
crop_params_dict: Dict[str, Tuple[int, int, int, int]],
|
||||
new_repo_id: str,
|
||||
new_dataset_root: str,
|
||||
resize_size: Tuple[int, int] = (128, 128),
|
||||
) -> LeRobotDataset:
|
||||
"""
|
||||
Converts an existing LeRobotDataset by iterating over its episodes and frames,
|
||||
applying cropping and resizing to image observations, and saving a new dataset
|
||||
with the transformed data.
|
||||
|
||||
Args:
|
||||
original_dataset (LeRobotDataset): The source dataset.
|
||||
crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
|
||||
A dictionary mapping observation keys to crop parameters (top, left, height, width).
|
||||
new_repo_id (str): Repository id for the new dataset.
|
||||
new_dataset_root (str): The root directory where the new dataset will be written.
|
||||
resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
|
||||
Defaults to (128, 128).
|
||||
|
||||
Returns:
|
||||
LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped
|
||||
and resized.
|
||||
"""
|
||||
# 1. Create a new (empty) LeRobotDataset for writing.
|
||||
new_dataset = LeRobotDataset.create(
|
||||
repo_id=new_repo_id,
|
||||
fps=original_dataset.fps,
|
||||
root=new_dataset_root,
|
||||
robot_type=original_dataset.meta.robot_type,
|
||||
features=original_dataset.meta.info["features"],
|
||||
use_videos=len(original_dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
# Update the metadata for every image key that will be cropped:
|
||||
# (Here we simply set the shape to be the final resize_size.)
|
||||
for key in crop_params_dict:
|
||||
if key in new_dataset.meta.info["features"]:
|
||||
new_dataset.meta.info["features"][key]["shape"] = list(resize_size)
|
||||
|
||||
# 2. Process each episode in the original dataset.
|
||||
episodes_info = original_dataset.meta.episodes
|
||||
# (Sort episodes by episode_index for consistency.)
|
||||
|
||||
episodes_info = sorted(episodes_info, key=lambda x: x["episode_index"])
|
||||
# Use the first task from the episode metadata (or "unknown" if not provided)
|
||||
task = episodes_info[0]["tasks"][0] if episodes_info[0].get("tasks") else "unknown"
|
||||
|
||||
last_episode_index = 0
|
||||
for sample in tqdm(original_dataset):
|
||||
episode_index = sample.pop("episode_index")
|
||||
if episode_index != last_episode_index:
|
||||
new_dataset.save_episode(task, encode_videos=True)
|
||||
last_episode_index = episode_index
|
||||
sample.pop("frame_index")
|
||||
# Make a shallow copy of the sample (the values—e.g. torch tensors—are assumed immutable)
|
||||
new_sample = sample.copy()
|
||||
# Loop over each observation key that should be cropped/resized.
|
||||
for key, params in crop_params_dict.items():
|
||||
if key in new_sample:
|
||||
top, left, height, width = params
|
||||
# Apply crop then resize.
|
||||
cropped = F.crop(new_sample[key], top, left, height, width)
|
||||
resized = F.resize(cropped, resize_size)
|
||||
new_sample[key] = resized
|
||||
# Add the transformed frame to the new dataset.
|
||||
new_dataset.add_frame(new_sample)
|
||||
|
||||
# save last episode
|
||||
new_dataset.save_episode(task, encode_videos=True)
|
||||
|
||||
# Optionally, consolidate the new dataset to compute statistics and update video info.
|
||||
new_dataset.consolidate(run_compute_stats=True, keep_image_files=True)
|
||||
|
||||
new_dataset.push_to_hub(tags=None)
|
||||
|
||||
return new_dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="lerobot",
|
||||
help="The repository id of the LeRobot dataset to process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The root directory of the LeRobot dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crop-params-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path to the JSON file containing the ROIs.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
local_files_only = args.root is not None
|
||||
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)
|
||||
|
||||
images = get_image_from_lerobot_dataset(dataset)
|
||||
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
|
||||
images = {k: (v * 255).astype("uint8") for k, v in images.items()}
|
||||
|
||||
if args.crop_params_path is None:
|
||||
rois = select_square_roi_for_images(images)
|
||||
else:
|
||||
with open(args.crop_params_path, "r") as f:
|
||||
rois = json.load(f)
|
||||
|
||||
# rois = {
|
||||
# "observation.images.front": [102, 43, 358, 523],
|
||||
# "observation.images.side": [92, 123, 379, 349],
|
||||
# }
|
||||
|
||||
# Print the selected rectangular ROIs
|
||||
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
|
||||
for key, roi in rois.items():
|
||||
print(f"{key}: {roi}")
|
||||
|
||||
new_repo_id = args.repo_id + "_cropped_resized"
|
||||
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
|
||||
|
||||
croped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
original_dataset=dataset,
|
||||
crop_params_dict=rois,
|
||||
new_repo_id=new_repo_id,
|
||||
new_dataset_root=new_dataset_root,
|
||||
resize_size=(128, 128),
|
||||
)
|
||||
|
||||
meta_dir = new_dataset_root / "meta"
|
||||
meta_dir.mkdir(exist_ok=True)
|
||||
|
||||
with open(meta_dir / "crop_params.json", "w") as f:
|
||||
json.dump(rois, f, indent=4)
|
||||
@@ -1,65 +0,0 @@
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.robot_devices.control_utils import is_headless
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
|
||||
|
||||
def find_joint_bounds(
|
||||
robot,
|
||||
control_time_s=20,
|
||||
display_cameras=False,
|
||||
):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
control_time_s = float("inf")
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
pos_list = []
|
||||
while timestamp < control_time_s:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
|
||||
pos_list.append(robot.follower_arms["main"].read("Present_Position"))
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
if timestamp > 60:
|
||||
max = np.max(np.stack(pos_list), 0)
|
||||
min = np.min(np.stack(pos_list), 0)
|
||||
print(f"Max angle position per joint {max}")
|
||||
print(f"Min angle position per joint {min}")
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--robot-path",
|
||||
type=str,
|
||||
default="lerobot/configs/robot/koch.yaml",
|
||||
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-overrides",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
|
||||
args = parser.parse_args()
|
||||
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
|
||||
|
||||
robot = make_robot(robot_cfg)
|
||||
find_joint_bounds(robot, control_time_s=args.control_time_s)
|
||||
@@ -1,861 +0,0 @@
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Annotated, Any, Callable, Dict, Optional, Tuple
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F # noqa: N812
|
||||
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.utils.utils import init_hydra_config, log_say
|
||||
from lerobot.scripts.server.maniskill_manipulator import make_maniskill
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class HILSerlRobotEnv(gym.Env):
|
||||
"""
|
||||
Gym-compatible environment for evaluating robotic control policies with integrated human intervention.
|
||||
|
||||
This environment wraps a robot interface to provide a consistent API for policy evaluation. It supports both relative (delta)
|
||||
and absolute joint position commands and automatically configures its observation and action spaces based on the robot's
|
||||
sensors and configuration.
|
||||
|
||||
The environment can switch between executing actions from a policy or using teleoperated actions (human intervention) during
|
||||
each step. When teleoperation is used, the override action is captured and returned in the `info` dict along with a flag
|
||||
`is_intervention`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
robot,
|
||||
use_delta_action_space: bool = True,
|
||||
delta: float | None = None,
|
||||
display_cameras: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the HILSerlRobotEnv environment.
|
||||
|
||||
The environment is set up with a robot interface, which is used to capture observations and send joint commands. The setup
|
||||
supports both relative (delta) adjustments and absolute joint positions for controlling the robot.
|
||||
|
||||
Args:
|
||||
robot: The robot interface object used to connect and interact with the physical robot.
|
||||
use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute
|
||||
joint positions are used.
|
||||
delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between
|
||||
0 and 1 when using a delta action space.
|
||||
display_cameras (bool): If True, the robot's camera feeds will be displayed during execution.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.robot = robot
|
||||
self.display_cameras = display_cameras
|
||||
|
||||
# Connect to the robot if not already connected.
|
||||
if not self.robot.is_connected:
|
||||
self.robot.connect()
|
||||
|
||||
self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")
|
||||
|
||||
# Episode tracking.
|
||||
self.current_step = 0
|
||||
self.episode_data = None
|
||||
|
||||
self.delta = delta
|
||||
self.use_delta_action_space = use_delta_action_space
|
||||
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
|
||||
|
||||
# Retrieve the size of the joint position interval bound.
|
||||
self.relative_bounds_size = (
|
||||
self.robot.config.joint_position_relative_bounds["max"]
|
||||
- self.robot.config.joint_position_relative_bounds["min"]
|
||||
)
|
||||
|
||||
self.delta_relative_bounds_size = self.relative_bounds_size * self.delta
|
||||
|
||||
self.robot.config.max_relative_target = self.delta_relative_bounds_size.float()
|
||||
|
||||
# Dynamically configure the observation and action spaces.
|
||||
self._setup_spaces()
|
||||
|
||||
def _setup_spaces(self):
|
||||
"""
|
||||
Dynamically configure the observation and action spaces based on the robot's capabilities.
|
||||
|
||||
Observation Space:
|
||||
- For keys with "image": A Box space with pixel values ranging from 0 to 255.
|
||||
- For non-image keys: A nested Dict space is created under 'observation.state' with a suitable range.
|
||||
|
||||
Action Space:
|
||||
- The action space is defined as a Tuple where:
|
||||
• The first element is a Box space representing joint position commands. It is defined as relative (delta)
|
||||
or absolute, based on the configuration.
|
||||
• The second element is a Discrete space (with 2 values) serving as a flag for intervention (teleoperation).
|
||||
"""
|
||||
example_obs = self.robot.capture_observation()
|
||||
|
||||
# Define observation spaces for images and other states.
|
||||
image_keys = [key for key in example_obs if "image" in key]
|
||||
state_keys = [key for key in example_obs if "image" not in key]
|
||||
observation_spaces = {
|
||||
key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8)
|
||||
for key in image_keys
|
||||
}
|
||||
observation_spaces["observation.state"] = gym.spaces.Dict(
|
||||
{
|
||||
key: gym.spaces.Box(low=0, high=10, shape=example_obs[key].shape, dtype=np.float32)
|
||||
for key in state_keys
|
||||
}
|
||||
)
|
||||
|
||||
self.observation_space = gym.spaces.Dict(observation_spaces)
|
||||
|
||||
# Define the action space for joint positions along with setting an intervention flag.
|
||||
action_dim = len(self.robot.follower_arms["main"].read("Present_Position"))
|
||||
if self.use_delta_action_space:
|
||||
action_space_robot = gym.spaces.Box(
|
||||
low=-self.relative_bounds_size.cpu().numpy(),
|
||||
high=self.relative_bounds_size.cpu().numpy(),
|
||||
shape=(action_dim,),
|
||||
dtype=np.float32,
|
||||
)
|
||||
else:
|
||||
action_space_robot = gym.spaces.Box(
|
||||
low=self.robot.config.joint_position_relative_bounds["min"].cpu().numpy(),
|
||||
high=self.robot.config.joint_position_relative_bounds["max"].cpu().numpy(),
|
||||
shape=(action_dim,),
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
self.action_space = gym.spaces.Tuple(
|
||||
(
|
||||
action_space_robot,
|
||||
gym.spaces.Discrete(2),
|
||||
),
|
||||
)
|
||||
|
||||
def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
|
||||
"""
|
||||
Reset the environment to its initial state.
|
||||
This method resets the step counter and clears any episodic data.
|
||||
|
||||
Args:
|
||||
seed (Optional[int]): A seed for random number generation to ensure reproducibility.
|
||||
options (Optional[dict]): Additional options to influence the reset behavior.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- observation (dict): The initial sensor observation.
|
||||
- info (dict): A dictionary with supplementary information, including the key "initial_position".
|
||||
"""
|
||||
super().reset(seed=seed, options=options)
|
||||
|
||||
# Capture the initial observation.
|
||||
observation = self.robot.capture_observation()
|
||||
|
||||
# Reset episode tracking variables.
|
||||
self.current_step = 0
|
||||
self.episode_data = None
|
||||
|
||||
return observation, {"initial_position": self.initial_follower_position}
|
||||
|
||||
def step(
|
||||
self, action: Tuple[np.ndarray, bool]
|
||||
) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]:
|
||||
"""
|
||||
Execute a single step within the environment using the specified action.
|
||||
|
||||
The provided action is a tuple comprised of:
|
||||
• A policy action (joint position commands) that may be either in absolute values or as a delta.
|
||||
• A boolean flag indicating whether teleoperation (human intervention) should be used for this step.
|
||||
|
||||
Behavior:
|
||||
- When the intervention flag is False, the environment processes and sends the policy action to the robot.
|
||||
- When True, a teleoperation step is executed. If using a delta action space, an absolute teleop action is converted
|
||||
to relative change based on the current joint positions.
|
||||
|
||||
Args:
|
||||
action (tuple): A tuple with two elements:
|
||||
- policy_action (np.ndarray or torch.Tensor): The commanded joint positions.
|
||||
- intervention_bool (bool): True if the human operator intervenes by providing a teleoperation input.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing:
|
||||
- observation (dict): The new sensor observation after taking the step.
|
||||
- reward (float): The step reward (default is 0.0 within this wrapper).
|
||||
- terminated (bool): True if the episode has reached a terminal state.
|
||||
- truncated (bool): True if the episode was truncated (e.g., time constraints).
|
||||
- info (dict): Additional debugging information including:
|
||||
◦ "action_intervention": The teleop action if intervention was used.
|
||||
◦ "is_intervention": Flag indicating whether teleoperation was employed.
|
||||
"""
|
||||
policy_action, intervention_bool = action
|
||||
teleop_action = None
|
||||
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
|
||||
if isinstance(policy_action, torch.Tensor):
|
||||
policy_action = policy_action.cpu().numpy()
|
||||
policy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high)
|
||||
if not intervention_bool:
|
||||
if self.use_delta_action_space:
|
||||
target_joint_positions = self.current_joint_positions + self.delta * policy_action
|
||||
else:
|
||||
target_joint_positions = policy_action
|
||||
self.robot.send_action(torch.from_numpy(target_joint_positions))
|
||||
observation = self.robot.capture_observation()
|
||||
else:
|
||||
observation, teleop_action = self.robot.teleop_step(record_data=True)
|
||||
teleop_action = teleop_action["action"] # Convert tensor to appropriate format
|
||||
|
||||
# When applying the delta action space, convert teleop absolute values to relative differences.
|
||||
if self.use_delta_action_space:
|
||||
teleop_action = (teleop_action - self.current_joint_positions) / self.delta
|
||||
if torch.any(teleop_action < -self.relative_bounds_size) and torch.any(
|
||||
teleop_action > self.relative_bounds_size
|
||||
):
|
||||
logging.debug(
|
||||
f"Relative teleop delta exceeded bounds {self.relative_bounds_size}, teleop_action {teleop_action}\n"
|
||||
f"lower bounds condition {teleop_action < -self.relative_bounds_size}\n"
|
||||
f"upper bounds condition {teleop_action > self.relative_bounds_size}"
|
||||
)
|
||||
|
||||
teleop_action = torch.clamp(
|
||||
teleop_action, -self.relative_bounds_size, self.relative_bounds_size
|
||||
)
|
||||
# NOTE: To mimic the shape of a neural network output, we add a batch dimension to the teleop action.
|
||||
if teleop_action.dim() == 1:
|
||||
teleop_action = teleop_action.unsqueeze(0)
|
||||
|
||||
# self.render()
|
||||
|
||||
self.current_step += 1
|
||||
|
||||
reward = 0.0
|
||||
terminated = False
|
||||
truncated = False
|
||||
|
||||
return (
|
||||
observation,
|
||||
reward,
|
||||
terminated,
|
||||
truncated,
|
||||
{"action_intervention": teleop_action, "is_intervention": teleop_action is not None},
|
||||
)
|
||||
|
||||
def render(self):
|
||||
"""
|
||||
Render the current state of the environment by displaying the robot's camera feeds.
|
||||
"""
|
||||
import cv2
|
||||
|
||||
observation = self.robot.capture_observation()
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close the environment and clean up resources by disconnecting the robot.
|
||||
|
||||
If the robot is currently connected, this method properly terminates the connection to ensure that all
|
||||
associated resources are released.
|
||||
"""
|
||||
if self.robot.is_connected:
|
||||
self.robot.disconnect()
|
||||
|
||||
|
||||
class ActionRepeatWrapper(gym.Wrapper):
|
||||
def __init__(self, env, nb_repeat: int = 1):
|
||||
super().__init__(env)
|
||||
self.nb_repeat = nb_repeat
|
||||
|
||||
def step(self, action):
|
||||
for _ in range(self.nb_repeat):
|
||||
obs, reward, done, truncated, info = self.env.step(action)
|
||||
if done or truncated:
|
||||
break
|
||||
return obs, reward, done, truncated, info
|
||||
|
||||
|
||||
class RewardWrapper(gym.Wrapper):
|
||||
def __init__(self, env, reward_classifier, device: torch.device = "cuda"):
|
||||
"""
|
||||
Wrapper to add reward prediction to the environment, it use a trained classifer.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
reward_classifier: The reward classifier model
|
||||
device: The device to run the model on
|
||||
"""
|
||||
self.env = env
|
||||
|
||||
# NOTE: We got 15% speedup by compiling the model
|
||||
self.reward_classifier = torch.compile(reward_classifier)
|
||||
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
self.device = device
|
||||
|
||||
def step(self, action):
|
||||
observation, _, terminated, truncated, info = self.env.step(action)
|
||||
images = [
|
||||
observation[key].to(self.device, non_blocking=self.device.type == "cuda")
|
||||
for key in observation
|
||||
if "image" in key
|
||||
]
|
||||
start_time = time.perf_counter()
|
||||
with torch.inference_mode():
|
||||
reward = (
|
||||
self.reward_classifier.predict_reward(images, threshold=0.8)
|
||||
if self.reward_classifier is not None
|
||||
else 0.0
|
||||
)
|
||||
info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
|
||||
|
||||
# logging.info(f"Reward: {reward}")
|
||||
|
||||
if reward == 1.0:
|
||||
terminated = True
|
||||
return observation, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
return self.env.reset(seed=seed, options=options)
|
||||
|
||||
|
||||
class JointMaskingActionSpace(gym.Wrapper):
|
||||
def __init__(self, env, mask):
|
||||
"""
|
||||
Wrapper to mask out dimensions of the action space.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
mask: Binary mask array where 0 indicates dimensions to remove
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
# Validate mask matches action space
|
||||
|
||||
# Keep only dimensions where mask is 1
|
||||
self.active_dims = np.where(mask)[0]
|
||||
|
||||
if isinstance(env.action_space, gym.spaces.Box):
|
||||
if len(mask) != env.action_space.shape[0]:
|
||||
raise ValueError("Mask length must match action space dimensions")
|
||||
low = env.action_space.low[self.active_dims]
|
||||
high = env.action_space.high[self.active_dims]
|
||||
self.action_space = gym.spaces.Box(low=low, high=high, dtype=env.action_space.dtype)
|
||||
|
||||
if isinstance(env.action_space, gym.spaces.Tuple):
|
||||
if len(mask) != env.action_space[0].shape[0]:
|
||||
raise ValueError("Mask length must match action space 0 dimensions")
|
||||
|
||||
low = env.action_space[0].low[self.active_dims]
|
||||
high = env.action_space[0].high[self.active_dims]
|
||||
action_space_masked = gym.spaces.Box(low=low, high=high, dtype=env.action_space[0].dtype)
|
||||
self.action_space = gym.spaces.Tuple((action_space_masked, env.action_space[1]))
|
||||
# Create new action space with masked dimensions
|
||||
|
||||
def action(self, action):
|
||||
"""
|
||||
Convert masked action back to full action space.
|
||||
|
||||
Args:
|
||||
action: Action in masked space. For Tuple spaces, the first element is masked.
|
||||
|
||||
Returns:
|
||||
Action in original space with masked dims set to 0.
|
||||
"""
|
||||
|
||||
# Determine whether we are handling a Tuple space or a Box.
|
||||
if isinstance(self.env.action_space, gym.spaces.Tuple):
|
||||
# Extract the masked component from the tuple.
|
||||
masked_action = action[0] if isinstance(action, tuple) else action
|
||||
# Create a full action for the Box element.
|
||||
full_box_action = np.zeros(self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype)
|
||||
full_box_action[self.active_dims] = masked_action
|
||||
# Return a tuple with the reconstructed Box action and the unchanged remainder.
|
||||
return (full_box_action, action[1])
|
||||
else:
|
||||
# For Box action spaces.
|
||||
masked_action = action if not isinstance(action, tuple) else action[0]
|
||||
full_action = np.zeros(self.env.action_space.shape, dtype=self.env.action_space.dtype)
|
||||
full_action[self.active_dims] = masked_action
|
||||
return full_action
|
||||
|
||||
def step(self, action):
|
||||
action = self.action(action)
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
if "action_intervention" in info and info["action_intervention"] is not None:
|
||||
if info["action_intervention"].dim() == 1:
|
||||
info["action_intervention"] = info["action_intervention"][self.active_dims]
|
||||
else:
|
||||
info["action_intervention"] = info["action_intervention"][:, self.active_dims]
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
|
||||
class TimeLimitWrapper(gym.Wrapper):
|
||||
def __init__(self, env, control_time_s, fps):
|
||||
self.env = env
|
||||
self.control_time_s = control_time_s
|
||||
self.fps = fps
|
||||
|
||||
self.last_timestamp = 0.0
|
||||
self.episode_time_in_s = 0.0
|
||||
|
||||
self.max_episode_steps = int(self.control_time_s * self.fps)
|
||||
|
||||
self.current_step = 0
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
time_since_last_step = time.perf_counter() - self.last_timestamp
|
||||
self.episode_time_in_s += time_since_last_step
|
||||
self.last_timestamp = time.perf_counter()
|
||||
self.current_step += 1
|
||||
# check if last timestep took more time than the expected fps
|
||||
if 1.0 / time_since_last_step < self.fps:
|
||||
logging.debug(f"Current timestep exceeded expected fps {self.fps}")
|
||||
|
||||
if self.episode_time_in_s > self.control_time_s:
|
||||
# if self.current_step >= self.max_episode_steps:
|
||||
# Terminated = True
|
||||
terminated = True
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
self.episode_time_in_s = 0.0
|
||||
self.last_timestamp = time.perf_counter()
|
||||
self.current_step = 0
|
||||
return self.env.reset(seed=seed, options=options)
|
||||
|
||||
|
||||
class ImageCropResizeWrapper(gym.Wrapper):
|
||||
def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None):
|
||||
super().__init__(env)
|
||||
self.env = env
|
||||
self.crop_params_dict = crop_params_dict
|
||||
print(f"obs_keys , {self.env.observation_space}")
|
||||
print(f"crop params dict {crop_params_dict.keys()}")
|
||||
for key_crop in crop_params_dict:
|
||||
if key_crop not in self.env.observation_space.keys(): # noqa: SIM118
|
||||
raise ValueError(f"Key {key_crop} not in observation space")
|
||||
for key in crop_params_dict:
|
||||
top, left, height, width = crop_params_dict[key]
|
||||
new_shape = (top + height, left + width)
|
||||
self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)
|
||||
|
||||
self.resize_size = resize_size
|
||||
if self.resize_size is None:
|
||||
self.resize_size = (128, 128)
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
for k in self.crop_params_dict:
|
||||
device = obs[k].device
|
||||
|
||||
# Check for NaNs before processing
|
||||
if torch.isnan(obs[k]).any():
|
||||
logging.error(f"NaN values detected in observation {k} before crop and resize")
|
||||
|
||||
if device == torch.device("mps:0"):
|
||||
obs[k] = obs[k].cpu()
|
||||
|
||||
obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
|
||||
obs[k] = F.resize(obs[k], self.resize_size)
|
||||
|
||||
# Check for NaNs after processing
|
||||
if torch.isnan(obs[k]).any():
|
||||
logging.error(f"NaN values detected in observation {k} after crop and resize")
|
||||
|
||||
obs[k] = obs[k].to(device)
|
||||
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
obs, info = self.env.reset(seed=seed, options=options)
|
||||
for k in self.crop_params_dict:
|
||||
device = obs[k].device
|
||||
if device == torch.device("mps:0"):
|
||||
obs[k] = obs[k].cpu()
|
||||
obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
|
||||
obs[k] = F.resize(obs[k], self.resize_size)
|
||||
obs[k] = obs[k].to(device)
|
||||
return obs, info
|
||||
|
||||
|
||||
class ConvertToLeRobotObservation(gym.ObservationWrapper):
|
||||
def __init__(self, env, device):
|
||||
super().__init__(env)
|
||||
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
self.device = device
|
||||
|
||||
def observation(self, observation):
|
||||
observation = preprocess_observation(observation)
|
||||
|
||||
observation = {
|
||||
key: observation[key].to(self.device, non_blocking=self.device.type == "cuda")
|
||||
for key in observation
|
||||
}
|
||||
observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()}
|
||||
return observation
|
||||
|
||||
|
||||
class KeyboardInterfaceWrapper(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
self.listener = None
|
||||
self.events = {
|
||||
"exit_early": False,
|
||||
"pause_policy": False,
|
||||
"reset_env": False,
|
||||
"human_intervention_step": False,
|
||||
"episode_success": False,
|
||||
}
|
||||
self.event_lock = Lock() # Thread-safe access to events
|
||||
self._init_keyboard_listener()
|
||||
|
||||
def _init_keyboard_listener(self):
|
||||
"""Initialize keyboard listener if not in headless mode"""
|
||||
|
||||
if is_headless():
|
||||
logging.warning(
|
||||
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
||||
)
|
||||
return
|
||||
try:
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
with self.event_lock:
|
||||
try:
|
||||
if key == keyboard.Key.right or key == keyboard.Key.esc:
|
||||
print("Right arrow key pressed. Exiting loop...")
|
||||
self.events["exit_early"] = True
|
||||
return
|
||||
if hasattr(key, "char") and key.char == "s":
|
||||
print("Key 's' pressed. Episode success triggered.")
|
||||
self.events["episode_success"] = True
|
||||
return
|
||||
if key == keyboard.Key.space and not self.events["exit_early"]:
|
||||
if not self.events["pause_policy"]:
|
||||
print(
|
||||
"Space key pressed. Human intervention required.\n"
|
||||
"Place the leader in similar pose to the follower and press space again."
|
||||
)
|
||||
self.events["pause_policy"] = True
|
||||
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
|
||||
return
|
||||
if self.events["pause_policy"] and not self.events["human_intervention_step"]:
|
||||
self.events["human_intervention_step"] = True
|
||||
print("Space key pressed. Human intervention starting.")
|
||||
log_say("Starting human intervention.", play_sounds=True)
|
||||
return
|
||||
if self.events["pause_policy"] and self.events["human_intervention_step"]:
|
||||
self.events["pause_policy"] = False
|
||||
self.events["human_intervention_step"] = False
|
||||
print("Space key pressed for a third time.")
|
||||
log_say("Continuing with policy actions.", play_sounds=True)
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
self.listener = keyboard.Listener(on_press=on_press)
|
||||
self.listener.start()
|
||||
except ImportError:
|
||||
logging.warning("Could not import pynput. Keyboard interface will not be available.")
|
||||
self.listener = None
|
||||
|
||||
def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
|
||||
is_intervention = False
|
||||
terminated_by_keyboard = False
|
||||
|
||||
# Extract policy_action if needed
|
||||
if isinstance(self.env.action_space, gym.spaces.Tuple):
|
||||
policy_action = action[0]
|
||||
|
||||
# Check the event flags without holding the lock for too long.
|
||||
with self.event_lock:
|
||||
if self.events["exit_early"]:
|
||||
terminated_by_keyboard = True
|
||||
pause_policy = self.events["pause_policy"]
|
||||
|
||||
if pause_policy:
|
||||
# Now, wait for human_intervention_step without holding the lock
|
||||
while True:
|
||||
with self.event_lock:
|
||||
if self.events["human_intervention_step"]:
|
||||
is_intervention = True
|
||||
break
|
||||
time.sleep(0.1) # Check more frequently if desired
|
||||
|
||||
# Execute the step in the underlying environment
|
||||
obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))
|
||||
|
||||
# Override reward and termination if episode success event triggered
|
||||
with self.event_lock:
|
||||
if self.events["episode_success"]:
|
||||
reward = 1
|
||||
terminated_by_keyboard = True
|
||||
|
||||
return obs, reward, terminated or terminated_by_keyboard, truncated, info
|
||||
|
||||
def reset(self, **kwargs) -> Tuple[Any, Dict]:
|
||||
"""
|
||||
Reset the environment and clear any pending events
|
||||
"""
|
||||
with self.event_lock:
|
||||
self.events = {k: False for k in self.events}
|
||||
return self.env.reset(**kwargs)
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Properly clean up the keyboard listener when the environment is closed
|
||||
"""
|
||||
if self.listener is not None:
|
||||
self.listener.stop()
|
||||
super().close()
|
||||
|
||||
|
||||
class ResetWrapper(gym.Wrapper):
|
||||
def __init__(
|
||||
self, env: HILSerlRobotEnv, reset_fn: Optional[Callable[[], None]] = None, reset_time_s: float = 5
|
||||
):
|
||||
super().__init__(env)
|
||||
self.reset_fn = reset_fn
|
||||
self.reset_time_s = reset_time_s
|
||||
|
||||
self.robot = self.unwrapped.robot
|
||||
self.init_pos = self.unwrapped.initial_follower_position
|
||||
|
||||
def reset(self, *, seed=None, options=None):
|
||||
if self.reset_fn is not None:
|
||||
self.reset_fn(self.env)
|
||||
else:
|
||||
log_say(f"Manually reset the environment for {self.reset_time_s} seconds.", play_sounds=True)
|
||||
start_time = time.perf_counter()
|
||||
while time.perf_counter() - start_time < self.reset_time_s:
|
||||
self.robot.teleop_step()
|
||||
|
||||
log_say("Manual reseting of the environment done.", play_sounds=True)
|
||||
return super().reset(seed=seed, options=options)
|
||||
|
||||
|
||||
class BatchCompitableWrapper(gym.ObservationWrapper):
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
|
||||
def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
for key in observation:
|
||||
if "image" in key and observation[key].dim() == 3:
|
||||
observation[key] = observation[key].unsqueeze(0)
|
||||
if "state" in key and observation[key].dim() == 1:
|
||||
observation[key] = observation[key].unsqueeze(0)
|
||||
return observation
|
||||
|
||||
|
||||
# TODO: REMOVE TH


def make_robot_env(
    robot,
    reward_classifier,
    cfg,
    n_envs: int = 1,
) -> gym.vector.VectorEnv:
    """
    Factory function to create a vectorized robot environment.

    Args:
        robot: Robot instance to control
        reward_classifier: Classifier model for computing rewards
        cfg: Configuration object containing environment parameters
        n_envs: Number of environments to create in parallel. Defaults to 1.

    Returns:
        A vectorized gym environment with all the necessary wrappers applied.
    """
    if "maniskill" in cfg.name:
        logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
        env = make_maniskill(
            task=cfg.task,
            obs_mode=cfg.obs,
            control_mode=cfg.control_mode,
            render_mode=cfg.render_mode,
            sensor_configs={"width": cfg.render_size, "height": cfg.render_size},
            device=cfg.device,
        )
        return env
    # Create base environment
    env = HILSerlRobotEnv(
        robot=robot,
        display_cameras=cfg.wrapper.display_cameras,
        delta=cfg.wrapper.delta_action,
        use_delta_action_space=cfg.wrapper.use_relative_joint_positions,
    )

    # Add observation and image processing
    env = ConvertToLeRobotObservation(env=env, device=cfg.device)
    if cfg.wrapper.crop_params_dict is not None:
        env = ImageCropResizeWrapper(
            env=env, crop_params_dict=cfg.wrapper.crop_params_dict, resize_size=cfg.wrapper.resize_size
        )

    # Add reward computation and control wrappers
    env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
    env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
    env = KeyboardInterfaceWrapper(env=env)
    env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.wrapper.reset_time_s)
    env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
    env = BatchCompitableWrapper(env=env)

    return env


# batched version of the env that returns an observation of shape (b, c)

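# A minimal sketch (not in the original file) of the config fields that make_robot_env reads,
# assuming an OmegaConf DictConfig like the env yaml loaded in the __main__ block below.
# The field names mirror the attribute accesses above; all values are illustrative only.
from omegaconf import OmegaConf

example_env_cfg = OmegaConf.create(
    {
        "name": "real_robot",
        "device": "cuda",
        "fps": 10,
        "wrapper": {
            "display_cameras": False,
            "delta_action": 0.1,
            "use_relative_joint_positions": True,
            "crop_params_dict": None,
            "resize_size": None,
            "control_time_s": 20,
            "reset_time_s": 5,
            "joint_masking_action_space": [1, 1, 1, 1, 0, 0],
        },
    }
)
# Usage sketch: env = make_robot_env(robot, reward_classifier, example_env_cfg)
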
def get_classifier(pretrained_path, config_path, device="mps"):
|
||||
if pretrained_path is None or config_path is None:
|
||||
return None
|
||||
|
||||
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
|
||||
cfg = init_hydra_config(config_path)
|
||||
|
||||
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
|
||||
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
|
||||
model = Classifier(classifier_config)
|
||||
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
|
||||
model = model.to(device)
|
||||
return model
|
||||
|
||||
|
||||
def replay_episode(env, repo_id, root=None, episode=0):
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
local_files_only = root is not None
|
||||
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action = actions[idx]["action"][:4]
|
||||
print(action)
|
||||
env.step((action / env.unwrapped.delta, False))
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
busy_wait(1 / 10 - dt_s)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--fps", type=int, default=30, help="control frequency")
|
||||
parser.add_argument(
|
||||
"--robot-path",
|
||||
type=str,
|
||||
default="lerobot/configs/robot/koch.yaml",
|
||||
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
|
||||
)
|
||||
    parser.add_argument(
        "--robot-overrides",
        type=str,
        nargs="*",
        help="Any key=value arguments to override config values (use dots for nested.overrides).",
    )
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
|
||||
"(useful for debugging). This argument is mutually exclusive with `--config`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
help=(
|
||||
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
|
||||
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-classifier-pretrained-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the pretrained classifier weights.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-classifier-config-file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a yaml config file that is necessary to build the reward classifier model.",
|
||||
)
|
||||
parser.add_argument("--env-path", type=str, default=None, help="Path to the env yaml file")
|
||||
parser.add_argument("--env-overrides", type=str, default=None, help="Overrides for the env yaml file")
|
||||
parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
|
||||
parser.add_argument("--reset-follower-pos", type=int, default=1, help="Reset follower between episodes")
|
||||
parser.add_argument("--replay-repo-id", type=str, default=None, help="Repo ID of the episode to replay")
|
||||
parser.add_argument("--replay-root", type=str, default=None, help="Root of the dataset to replay")
|
||||
parser.add_argument("--replay-episode", type=int, default=0, help="Episode to replay")
|
||||
args = parser.parse_args()
|
||||
|
||||
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
|
||||
robot = make_robot(robot_cfg)
|
||||
|
||||
reward_classifier = get_classifier(
|
||||
args.reward_classifier_pretrained_path, args.reward_classifier_config_file
|
||||
)
|
||||
user_relative_joint_positions = True
|
||||
|
||||
cfg = init_hydra_config(args.env_path, args.env_overrides)
|
||||
env = make_robot_env(
|
||||
robot,
|
||||
reward_classifier,
|
||||
cfg.env, # .wrapper,
|
||||
)
|
||||
|
||||
env.reset()
|
||||
|
||||
if args.replay_repo_id is not None:
|
||||
replay_episode(env, args.replay_repo_id, root=args.replay_root, episode=args.replay_episode)
|
||||
exit()
|
||||
|
||||
# Retrieve the robot's action space for joint commands.
|
||||
action_space_robot = env.action_space.spaces[0]
|
||||
|
||||
# Initialize the smoothed action as a random sample.
|
||||
smoothed_action = action_space_robot.sample()
|
||||
|
||||
# Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
|
||||
# A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
|
||||
alpha = 0.4
|
||||
|
||||
while True:
|
||||
start_loop_s = time.perf_counter()
|
||||
# Sample a new random action from the robot's action space.
|
||||
new_random_action = action_space_robot.sample()
|
||||
# Update the smoothed action using an exponential moving average.
|
||||
smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
|
||||
|
||||
# Execute the step: wrap the NumPy action in a torch tensor.
|
||||
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
|
||||
if terminated or truncated:
|
||||
env.reset()
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_s
|
||||
busy_wait(1 / args.fps - dt_s)
|
||||
@@ -1,58 +0,0 @@
// !/usr/bin/env python

// Copyright 2024 The HuggingFace Inc. team.
// All rights reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

//     http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";

package hil_serl;

// LearnerService: the Actor calls this to push transitions.
// The Learner implements this service.
service LearnerService {
  // Actor -> Learner to store transitions
  rpc SendTransition(Transition) returns (Empty);
  rpc SendInteractionMessage(InteractionMessage) returns (Empty);
}

// ActorService: the Learner calls this to push parameters.
// The Actor implements this service.
service ActorService {
  // Learner -> Actor to send new parameters
  rpc StreamTransition(Empty) returns (stream ActorInformation) {};
  rpc SendParameters(Parameters) returns (Empty);
}


message ActorInformation {
  oneof data {
    Transition transition = 1;
    InteractionMessage interaction_message = 2;
  }
}

// Messages
message Transition {
  bytes transition_bytes = 1;
}

message Parameters {
  bytes parameter_bytes = 1;
}

message InteractionMessage {
  bytes interaction_message_bytes = 1;
}

message Empty {}

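# A minimal sketch (not in the original file) of how the generated modules below
# (hilserl_pb2.py and hilserl_pb2_grpc.py) can be regenerated from hilserl.proto with
# grpcio-tools; the include path and output directory are assumptions about the repo layout.
from grpc_tools import protoc

protoc.main(
    [
        "grpc_tools.protoc",
        "-I.",
        "--python_out=.",
        "--grpc_python_out=.",
        "hilserl.proto",
    ]
)
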
@@ -1,48 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: hilserl.proto
|
||||
# Protobuf Python Version: 5.29.0
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||
_runtime_version.Domain.PUBLIC,
|
||||
5,
|
||||
29,
|
||||
0,
|
||||
'',
|
||||
'hilserl.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"\x83\x01\n\x10\x41\x63torInformation\x12*\n\ntransition\x18\x01 \x01(\x0b\x32\x14.hil_serl.TransitionH\x00\x12;\n\x13interaction_message\x18\x02 \x01(\x0b\x32\x1c.hil_serl.InteractionMessageH\x00\x42\x06\n\x04\x64\x61ta\"&\n\nTransition\x12\x18\n\x10transition_bytes\x18\x01 \x01(\x0c\"%\n\nParameters\x12\x17\n\x0fparameter_bytes\x18\x01 \x01(\x0c\"7\n\x12InteractionMessage\x12!\n\x19interaction_message_bytes\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty2\x92\x01\n\x0eLearnerService\x12\x37\n\x0eSendTransition\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty2\x8c\x01\n\x0c\x41\x63torService\x12\x43\n\x10StreamTransition\x12\x0f.hil_serl.Empty\x1a\x1a.hil_serl.ActorInformation\"\x00\x30\x01\x12\x37\n\x0eSendParameters\x12\x14.hil_serl.Parameters\x1a\x0f.hil_serl.Emptyb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_ACTORINFORMATION']._serialized_start=28
|
||||
_globals['_ACTORINFORMATION']._serialized_end=159
|
||||
_globals['_TRANSITION']._serialized_start=161
|
||||
_globals['_TRANSITION']._serialized_end=199
|
||||
_globals['_PARAMETERS']._serialized_start=201
|
||||
_globals['_PARAMETERS']._serialized_end=238
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_start=240
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_end=295
|
||||
_globals['_EMPTY']._serialized_start=297
|
||||
_globals['_EMPTY']._serialized_end=304
|
||||
_globals['_LEARNERSERVICE']._serialized_start=307
|
||||
_globals['_LEARNERSERVICE']._serialized_end=453
|
||||
_globals['_ACTORSERVICE']._serialized_start=456
|
||||
_globals['_ACTORSERVICE']._serialized_end=596
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -1,269 +0,0 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
import hilserl_pb2 as hilserl__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.70.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
_version_not_supported = False
|
||||
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
except ImportError:
|
||||
_version_not_supported = True
|
||||
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ f' but the generated code in hilserl_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
)
|
||||
|
||||
|
||||
class LearnerServiceStub(object):
|
||||
"""LearnerService: the Actor calls this to push transitions.
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.SendTransition = channel.unary_unary(
|
||||
'/hil_serl.LearnerService/SendTransition',
|
||||
request_serializer=hilserl__pb2.Transition.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.SendInteractionMessage = channel.unary_unary(
|
||||
'/hil_serl.LearnerService/SendInteractionMessage',
|
||||
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class LearnerServiceServicer(object):
|
||||
"""LearnerService: the Actor calls this to push transitions.
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
def SendTransition(self, request, context):
|
||||
"""Actor -> Learner to store transitions
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendInteractionMessage(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_LearnerServiceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'SendTransition': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendTransition,
|
||||
request_deserializer=hilserl__pb2.Transition.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'SendInteractionMessage': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendInteractionMessage,
|
||||
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'hil_serl.LearnerService', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('hil_serl.LearnerService', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class LearnerService(object):
|
||||
"""LearnerService: the Actor calls this to push transitions.
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def SendTransition(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/hil_serl.LearnerService/SendTransition',
|
||||
hilserl__pb2.Transition.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendInteractionMessage(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/hil_serl.LearnerService/SendInteractionMessage',
|
||||
hilserl__pb2.InteractionMessage.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class ActorServiceStub(object):
|
||||
"""ActorService: the Learner calls this to push parameters.
|
||||
The Actor implements this service.
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.StreamTransition = channel.unary_stream(
|
||||
'/hil_serl.ActorService/StreamTransition',
|
||||
request_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.ActorInformation.FromString,
|
||||
_registered_method=True)
|
||||
self.SendParameters = channel.unary_unary(
|
||||
'/hil_serl.ActorService/SendParameters',
|
||||
request_serializer=hilserl__pb2.Parameters.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class ActorServiceServicer(object):
|
||||
"""ActorService: the Learner calls this to push parameters.
|
||||
The Actor implements this service.
|
||||
"""
|
||||
|
||||
def StreamTransition(self, request, context):
|
||||
"""Learner -> Actor to send new parameters
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendParameters(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_ActorServiceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'StreamTransition': grpc.unary_stream_rpc_method_handler(
|
||||
servicer.StreamTransition,
|
||||
request_deserializer=hilserl__pb2.Empty.FromString,
|
||||
response_serializer=hilserl__pb2.ActorInformation.SerializeToString,
|
||||
),
|
||||
'SendParameters': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendParameters,
|
||||
request_deserializer=hilserl__pb2.Parameters.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'hil_serl.ActorService', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('hil_serl.ActorService', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class ActorService(object):
|
||||
"""ActorService: the Learner calls this to push parameters.
|
||||
The Actor implements this service.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def StreamTransition(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_stream(
|
||||
request,
|
||||
target,
|
||||
'/hil_serl.ActorService/StreamTransition',
|
||||
hilserl__pb2.Empty.SerializeToString,
|
||||
hilserl__pb2.ActorInformation.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendParameters(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/hil_serl.ActorService/SendParameters',
|
||||
hilserl__pb2.Parameters.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
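# A minimal sketch (not in the original file) of wiring up the generated stubs above:
# the Learner hosts LearnerService while the Actor connects through a LearnerServiceStub.
# The host, port, and thread-pool size are illustrative only.
from concurrent import futures

import grpc
import hilserl_pb2
import hilserl_pb2_grpc


class NoOpLearnerServicer(hilserl_pb2_grpc.LearnerServiceServicer):
    def SendTransition(self, request, context):
        # request.transition_bytes would normally be deserialized here (e.g. with torch.load).
        return hilserl_pb2.Empty()

    def SendInteractionMessage(self, request, context):
        return hilserl_pb2.Empty()


def serve_learner(port: int = 50051) -> grpc.Server:
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(NoOpLearnerServicer(), server)
    server.add_insecure_port(f"[::]:{port}")
    server.start()
    return server


def send_dummy_transition(host: str = "127.0.0.1", port: int = 50051) -> None:
    channel = grpc.insecure_channel(f"{host}:{port}")
    stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
    stub.SendTransition(hilserl_pb2.Transition(transition_bytes=b""))
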
@@ -1,676 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
import logging
|
||||
import pickle
|
||||
import queue
|
||||
import shutil
|
||||
import time
|
||||
from pprint import pformat
|
||||
from threading import Lock, Thread
|
||||
|
||||
import grpc
|
||||
|
||||
# Import generated stubs
|
||||
import hilserl_pb2 # type: ignore
|
||||
import hilserl_pb2_grpc # type: ignore
|
||||
import hydra
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch import nn
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
|
||||
# TODO: Remove the import of maniskill
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_global_random_state,
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
set_global_random_state,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.scripts.server.buffer import (
|
||||
ReplayBuffer,
|
||||
concatenate_batch_transitions,
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
transition_queue = queue.Queue()
|
||||
interaction_message_queue = queue.Queue()
|
||||
|
||||
|
||||
def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
|
||||
if not cfg.resume:
|
||||
if Logger.get_last_checkpoint_dir(out_dir).exists():
|
||||
raise RuntimeError(
|
||||
f"Output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. "
|
||||
"Use `resume=true` to resume training."
|
||||
)
|
||||
return cfg
|
||||
|
||||
# if resume == True
|
||||
checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
|
||||
if not checkpoint_dir.exists():
|
||||
raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
|
||||
|
||||
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
|
||||
logging.info(
|
||||
colored(
|
||||
"Resume=True detected, resuming previous run",
|
||||
color="yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
|
||||
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
|
||||
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
|
||||
|
||||
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
|
||||
del diff["values_changed"]["root['resume']"]
|
||||
|
||||
if len(diff) > 0:
|
||||
logging.warning(
|
||||
f"Differences between the checkpoint config and the provided config detected: \n{pformat(diff)}\n"
|
||||
"Checkpoint configuration takes precedence."
|
||||
)
|
||||
|
||||
checkpoint_cfg.resume = True
|
||||
return checkpoint_cfg
|
||||
|
||||
|
||||
def load_training_state(
|
||||
cfg: DictConfig,
|
||||
logger: Logger,
|
||||
optimizers: Optimizer | dict,
|
||||
):
|
||||
if not cfg.resume:
|
||||
return None, None
|
||||
|
||||
training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
|
||||
|
||||
if isinstance(training_state["optimizer"], dict):
|
||||
assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
|
||||
for k, v in training_state["optimizer"].items():
|
||||
optimizers[k].load_state_dict(v)
|
||||
else:
|
||||
optimizers.load_state_dict(training_state["optimizer"])
|
||||
|
||||
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
|
||||
return training_state["step"], training_state["interaction_step"]
|
||||
|
||||
|
||||
def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
log_output_dir(out_dir)
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.training.online_steps=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
|
||||
def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> ReplayBuffer:
|
||||
if not cfg.resume:
|
||||
return ReplayBuffer(
|
||||
capacity=cfg.training.online_buffer_capacity,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
)
|
||||
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
|
||||
)
|
||||
return ReplayBuffer.from_lerobot_dataset(
|
||||
lerobot_dataset=dataset,
|
||||
capacity=cfg.training.online_buffer_capacity,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
)
|
||||
|
||||
|
||||
def start_learner_threads(
|
||||
cfg: DictConfig,
|
||||
device: str,
|
||||
replay_buffer: ReplayBuffer,
|
||||
offline_replay_buffer: ReplayBuffer,
|
||||
batch_size: int,
|
||||
optimizers: dict,
|
||||
policy: SACPolicy,
|
||||
policy_lock: Lock,
|
||||
logger: Logger,
|
||||
resume_optimization_step: int | None = None,
|
||||
resume_interaction_step: int | None = None,
|
||||
) -> None:
|
||||
actor_ip = cfg.actor_learner_config.actor_ip
|
||||
port = cfg.actor_learner_config.port
|
||||
|
||||
server_thread = Thread(
|
||||
target=stream_transitions_from_actor,
|
||||
args=(
|
||||
actor_ip,
|
||||
port,
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
transition_thread = Thread(
|
||||
target=add_actor_information_and_train,
|
||||
daemon=True,
|
||||
args=(
|
||||
cfg,
|
||||
device,
|
||||
replay_buffer,
|
||||
offline_replay_buffer,
|
||||
batch_size,
|
||||
optimizers,
|
||||
policy,
|
||||
policy_lock,
|
||||
logger,
|
||||
resume_optimization_step,
|
||||
resume_interaction_step,
|
||||
),
|
||||
)
|
||||
|
||||
param_push_thread = Thread(
|
||||
target=learner_push_parameters,
|
||||
args=(policy, policy_lock, actor_ip, port, 15),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
server_thread.start()
|
||||
transition_thread.start()
|
||||
param_push_thread.start()
|
||||
param_push_thread.join()
|
||||
transition_thread.join()
|
||||
server_thread.join()
|
||||
|
||||
|
||||
def stream_transitions_from_actor(host="127.0.0.1", port=50051):
|
||||
"""
|
||||
Runs a gRPC client that listens for transition and interaction messages from an Actor service.
|
||||
|
||||
This function establishes a gRPC connection with the given `host` and `port`, then continuously
|
||||
streams transition data from the `ActorServiceStub`. The received transition data is deserialized
|
||||
and stored in a queue (`transition_queue`). Similarly, interaction messages are also deserialized
|
||||
and stored in a separate queue (`interaction_message_queue`).
|
||||
|
||||
Args:
|
||||
host (str, optional): The IP address or hostname of the gRPC server. Defaults to `"127.0.0.1"`.
|
||||
port (int, optional): The port number on which the gRPC server is running. Defaults to `50051`.
|
||||
|
||||
"""
|
||||
# NOTE: This is waiting for the handshake to be done
|
||||
# In the future we will do it in a canonical way with a proper handshake
|
||||
time.sleep(10)
|
||||
channel = grpc.insecure_channel(
|
||||
f"{host}:{port}",
|
||||
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
|
||||
)
|
||||
stub = hilserl_pb2_grpc.ActorServiceStub(channel)
|
||||
for response in stub.StreamTransition(hilserl_pb2.Empty()):
|
||||
if response.HasField("transition"):
|
||||
buffer = io.BytesIO(response.transition.transition_bytes)
|
||||
transition = torch.load(buffer)
|
||||
transition_queue.put(transition)
|
||||
if response.HasField("interaction_message"):
|
||||
content = pickle.loads(response.interaction_message.interaction_message_bytes)
|
||||
interaction_message_queue.put(content)
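# A minimal sketch (not in the original file) of the Actor-side counterpart: the
# `transition_bytes` consumed above are presumably produced by torch.save into an
# in-memory buffer before being wrapped in a Transition message.
def serialize_transitions_example(transition_list: list) -> bytes:
    buf = io.BytesIO()
    torch.save(transition_list, buf)
    return buf.getvalue()
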
|
||||
|
||||
|
||||
def learner_push_parameters(
|
||||
policy: nn.Module, policy_lock: Lock, actor_host="127.0.0.1", actor_port=50052, seconds_between_pushes=5
|
||||
):
|
||||
"""
|
||||
As a client, connect to the Actor's gRPC server (ActorService)
|
||||
and periodically push new parameters.
|
||||
"""
|
||||
time.sleep(10)
|
||||
channel = grpc.insecure_channel(
|
||||
f"{actor_host}:{actor_port}",
|
||||
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
|
||||
)
|
||||
actor_stub = hilserl_pb2_grpc.ActorServiceStub(channel)
|
||||
|
||||
while True:
|
||||
with policy_lock:
|
||||
params_dict = policy.actor.state_dict()
|
||||
if policy.config.vision_encoder_name is not None:
|
||||
if policy.config.freeze_vision_encoder:
|
||||
params_dict: dict[str, torch.Tensor] = {
|
||||
k: v for k, v in params_dict.items() if not k.startswith("encoder.")
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
|
||||
)
|
||||
|
||||
params_dict = move_state_dict_to_device(params_dict, device="cpu")
|
||||
# Serialize
|
||||
buf = io.BytesIO()
|
||||
torch.save(params_dict, buf)
|
||||
params_bytes = buf.getvalue()
|
||||
|
||||
# Push them to the Actor's "SendParameters" method
|
||||
logging.info("[LEARNER] Publishing parameters to the Actor")
|
||||
response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes)) # noqa: F841
|
||||
time.sleep(seconds_between_pushes)
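# A minimal sketch (not in the original file) of the receiving end: the Actor's
# SendParameters handler presumably reverses the serialization above with torch.load
# on a BytesIO buffer before loading the weights into its local actor network.
def load_pushed_parameters_example(policy: nn.Module, params_bytes: bytes, device: str = "cpu") -> None:
    state_dict = torch.load(io.BytesIO(params_bytes), map_location=device)
    # strict=False because frozen vision-encoder weights are filtered out before sending.
    policy.actor.load_state_dict(state_dict, strict=False)
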
|
||||
|
||||
|
||||
def check_nan_in_transition(observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor):
|
||||
for k in observations:
|
||||
if torch.isnan(observations[k]).any():
|
||||
logging.error(f"observations[{k}] contains NaN values")
|
||||
for k in next_state:
|
||||
if torch.isnan(next_state[k]).any():
|
||||
logging.error(f"next_state[{k}] contains NaN values")
|
||||
if torch.isnan(actions).any():
|
||||
logging.error("actions contains NaN values")
|
||||
|
||||
|
||||
def add_actor_information_and_train(
|
||||
cfg,
|
||||
device: str,
|
||||
replay_buffer: ReplayBuffer,
|
||||
offline_replay_buffer: ReplayBuffer,
|
||||
batch_size: int,
|
||||
optimizers: dict[str, torch.optim.Optimizer],
|
||||
policy: nn.Module,
|
||||
policy_lock: Lock,
|
||||
logger: Logger,
|
||||
resume_optimization_step: int | None = None,
|
||||
resume_interaction_step: int | None = None,
|
||||
):
|
||||
"""
|
||||
Handles data transfer from the actor to the learner, manages training updates,
|
||||
and logs training progress in an online reinforcement learning setup.
|
||||
|
||||
This function continuously:
|
||||
- Transfers transitions from the actor to the replay buffer.
|
||||
- Logs received interaction messages.
|
||||
- Ensures training begins only when the replay buffer has a sufficient number of transitions.
|
||||
- Samples batches from the replay buffer and performs multiple critic updates.
|
||||
- Periodically updates the actor, critic, and temperature optimizers.
|
||||
- Logs training statistics, including loss values and optimization frequency.
|
||||
|
||||
**NOTE:**
|
||||
- This function performs multiple responsibilities (data transfer, training, and logging).
|
||||
It should ideally be split into smaller functions in the future.
|
||||
- Due to Python's **Global Interpreter Lock (GIL)**, running separate threads for different tasks
|
||||
significantly reduces performance. Instead, this function executes all operations in a single thread.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object containing hyperparameters.
|
||||
device (str): The computing device (`"cpu"` or `"cuda"`).
|
||||
replay_buffer (ReplayBuffer): The primary replay buffer storing online transitions.
|
||||
offline_replay_buffer (ReplayBuffer): An additional buffer for offline transitions.
|
||||
batch_size (int): The number of transitions to sample per training step.
|
||||
optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`).
|
||||
policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters.
|
||||
policy_lock (Lock): A threading lock to ensure safe policy updates.
|
||||
logger (Logger): Logger instance for tracking training progress.
|
||||
resume_optimization_step (int | None): In the case of resume training, start from the last optimization step reached.
|
||||
resume_interaction_step (int | None): In the case of resume training, shift the interaction step with the last saved step in order to not break logging.
|
||||
"""
|
||||
    # NOTE: This function doesn't have a single responsibility; it should be split into
    # multiple functions in the future. We keep everything in a single thread because of
    # Python's GIL: splitting the work across threads slowed performance by roughly 200x.
|
||||
time.time()
|
||||
logging.info("Starting learner thread")
|
||||
interaction_message, transition = None, None
|
||||
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
|
||||
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
|
||||
while True:
|
||||
while not transition_queue.empty():
|
||||
transition_list = transition_queue.get()
|
||||
for transition in transition_list:
|
||||
transition = move_transition_to_device(transition, device=device)
|
||||
replay_buffer.add(**transition)
|
||||
|
||||
if transition.get("complementary_info", {}).get("is_intervention"):
|
||||
offline_replay_buffer.add(**transition)
|
||||
|
||||
while not interaction_message_queue.empty():
|
||||
interaction_message = interaction_message_queue.get()
|
||||
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
|
||||
interaction_message["Interaction step"] += interaction_step_shift
|
||||
logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
|
||||
# logging.info(f"Interaction message: {interaction_message}")
|
||||
|
||||
if len(replay_buffer) < cfg.training.online_step_before_learning:
|
||||
continue
|
||||
|
||||
# logging.info(f"Size of replay buffer: {len(replay_buffer)}")
|
||||
# logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
|
||||
|
||||
time_for_one_optimization_step = time.time()
|
||||
for _ in range(cfg.policy.utd_ratio - 1):
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
|
||||
if cfg.dataset_repo_id is not None:
|
||||
batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
batch = concatenate_batch_transitions(batch, batch_offline)
|
||||
|
||||
actions = batch["action"]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
# Precompute encoder features from the frozen vision encoder if enabled
|
||||
obs_features, next_obs_features = None, None
|
||||
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
|
||||
with torch.no_grad():
|
||||
obs_features = (
|
||||
policy.actor.encoder(observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
next_obs_features = (
|
||||
policy.actor.encoder(next_observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
|
||||
with policy_lock:
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
obs_features=obs_features, # pass precomputed features
|
||||
next_obs_features=next_obs_features, # for target computation
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
|
||||
if cfg.dataset_repo_id is not None:
|
||||
batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch["action"]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
# Precompute encoder features from the frozen vision encoder if enabled
|
||||
obs_features, next_obs_features = None, None
|
||||
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
|
||||
with torch.no_grad():
|
||||
obs_features = (
|
||||
policy.actor.encoder(observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
next_obs_features = (
|
||||
policy.actor.encoder(next_observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
|
||||
with policy_lock:
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
obs_features=obs_features, # pass precomputed features
|
||||
next_obs_features=next_obs_features, # for target computation
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
training_infos = {}
|
||||
training_infos["loss_critic"] = loss_critic.item()
|
||||
|
||||
if optimization_step % cfg.training.policy_update_freq == 0:
|
||||
for _ in range(cfg.training.policy_update_freq):
|
||||
with policy_lock:
|
||||
loss_actor = policy.compute_loss_actor(
|
||||
observations=observations,
|
||||
obs_features=obs_features, # reuse precomputed features here
|
||||
)
|
||||
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
training_infos["loss_actor"] = loss_actor.item()
|
||||
|
||||
loss_temperature = policy.compute_loss_temperature(
|
||||
observations=observations,
|
||||
obs_features=obs_features, # and for temperature loss as well
|
||||
)
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
training_infos["loss_temperature"] = loss_temperature.item()
|
||||
|
||||
policy.update_target_networks()
|
||||
if optimization_step % cfg.training.log_freq == 0:
|
||||
training_infos["Optimization step"] = optimization_step
|
||||
logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
|
||||
# logging.info(f"Training infos: {training_infos}")
|
||||
|
||||
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
|
||||
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
|
||||
|
||||
logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
|
||||
|
||||
logger.log_dict(
|
||||
{
|
||||
"Optimization frequency loop [Hz]": frequency_for_one_optimization_step,
|
||||
"Optimization step": optimization_step,
|
||||
},
|
||||
mode="train",
|
||||
custom_step_key="Optimization step",
|
||||
)
|
||||
|
||||
optimization_step += 1
|
||||
if optimization_step % cfg.training.log_freq == 0:
|
||||
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
|
||||
|
||||
if cfg.training.save_checkpoint and (
|
||||
optimization_step % cfg.training.save_freq == 0 or optimization_step == cfg.training.online_steps
|
||||
):
|
||||
logging.info(f"Checkpoint policy after step {optimization_step}")
|
||||
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
|
||||
# needed (choose 6 as a minimum for consistency without being overkill).
|
||||
_num_digits = max(6, len(str(cfg.training.online_steps)))
|
||||
step_identifier = f"{optimization_step:0{_num_digits}d}"
|
||||
interaction_step = (
|
||||
interaction_message["Interaction step"] if interaction_message is not None else 0
|
||||
)
|
||||
logger.save_checkpoint(
|
||||
optimization_step,
|
||||
policy,
|
||||
optimizers,
|
||||
scheduler=None,
|
||||
identifier=step_identifier,
|
||||
interaction_step=interaction_step,
|
||||
)
|
||||
|
||||
            # TODO: temporarily save the replay buffer here; remove later when running on the robot.
            # We want to control this with the keyboard inputs.
            dataset_dir = logger.log_dir / "dataset"
            if dataset_dir.exists() and dataset_dir.is_dir():
                shutil.rmtree(dataset_dir)
            replay_buffer.to_lerobot_dataset(
                cfg.dataset_repo_id, fps=cfg.fps, root=logger.log_dir / "dataset"
            )
|
||||
|
||||
logging.info("Resume training")
|
||||
|
||||
|
||||
def make_optimizers_and_scheduler(cfg, policy: nn.Module):
|
||||
"""
|
||||
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
|
||||
|
||||
This function sets up Adam optimizers for:
|
||||
- The **actor network**, ensuring that only relevant parameters are optimized.
|
||||
- The **critic ensemble**, which evaluates the value function.
|
||||
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
|
||||
|
||||
It also initializes a learning rate scheduler, though currently, it is set to `None`.
|
||||
|
||||
**NOTE:**
|
||||
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
|
||||
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object containing hyperparameters.
|
||||
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
|
||||
A tuple containing:
|
||||
- `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
|
||||
- `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
|
||||
|
||||
"""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
|
||||
params=policy.actor.parameters_to_optimize,
|
||||
lr=policy.config.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
|
||||
lr_scheduler = None
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
return optimizers, lr_scheduler
|
||||
|
||||
|
||||
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
if job_name is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
init_logging()
|
||||
logging.info(pformat(OmegaConf.to_container(cfg)))
|
||||
|
||||
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
|
||||
cfg = handle_resume_logic(cfg, out_dir)
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("make_policy")
|
||||
|
||||
    # Instantiate the policy in both the actor and learner processes.
    # To avoid sending a SACPolicy object over the network, we create a policy instance
    # on both sides; the learner then sends the updated parameters every n steps to the actor.
    # TODO: At some point we should only need to make a SAC policy here.

    policy_lock = Lock()
    policy: SACPolicy = make_policy(
        hydra_cfg=cfg,
        # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
        # Hack: for online training we do not need dataset_stats.
        dataset_stats=None,
        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
    )
|
||||
# compile policy
|
||||
# policy = torch.compile(policy)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
|
||||
resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)
|
||||
|
||||
log_training_info(cfg, out_dir, policy)
|
||||
|
||||
replay_buffer = initialize_replay_buffer(cfg, logger, device)
|
||||
batch_size = cfg.training.batch_size
|
||||
offline_replay_buffer = None
|
||||
|
||||
if cfg.dataset_repo_id is not None:
|
||||
logging.info("make_dataset offline buffer")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
logging.info("Convertion to a offline replay buffer")
|
||||
active_action_dims = [i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask]
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
offline_dataset,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
action_mask=active_action_dims,
|
||||
action_delta=cfg.env.wrapper.delta_action,
|
||||
)
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
|
||||
start_learner_threads(
|
||||
cfg,
|
||||
device,
|
||||
replay_buffer,
|
||||
offline_replay_buffer,
|
||||
batch_size,
|
||||
optimizers,
|
||||
policy,
|
||||
policy_lock,
|
||||
logger,
|
||||
resume_optimization_step,
|
||||
resume_interaction_step,
|
||||
)
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
|
||||
def train_cli(cfg: dict):
|
||||
train(
|
||||
cfg,
|
||||
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
|
||||
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_cli()
|
||||
@@ -1,176 +0,0 @@
|
||||
import einops
|
||||
import numpy as np
|
||||
import gymnasium as gym
|
||||
import torch
|
||||
|
||||
"""Make ManiSkill3 gym environment"""
|
||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||
|
||||
|
||||
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
Returns:
|
||||
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
||||
"""
|
||||
# map to expected inputs for the policy
|
||||
return_observations = {}
|
||||
# TODO: You have to merge all tensors from agent key and extra key
|
||||
# You don't keep sensor param key in the observation
|
||||
# And you keep sensor data rgb
|
||||
q_pos = observations["agent"]["qpos"]
|
||||
q_vel = observations["agent"]["qvel"]
|
||||
tcp_pos = observations["extra"]["tcp_pose"]
|
||||
img = observations["sensor_data"]["base_camera"]["rgb"]
|
||||
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
|
||||
|
||||
return_observations["observation.image"] = img
|
||||
return_observations["observation.state"] = state
|
||||
return return_observations
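# A quick shape check (not in the original file, values illustrative): two 64x64 RGB
# frames plus 9-dim qpos/qvel and a 7-dim TCP pose give a channel-first float image
# and a flat 25-dim state per sample.
def _preprocess_example():
    observations = {
        "agent": {"qpos": torch.zeros(2, 9), "qvel": torch.zeros(2, 9)},
        "extra": {"tcp_pose": torch.zeros(2, 7)},
        "sensor_data": {"base_camera": {"rgb": torch.zeros(2, 64, 64, 3, dtype=torch.uint8)}},
    }
    out = preprocess_maniskill_observation(observations)
    assert out["observation.image"].shape == (2, 3, 64, 64)
    assert out["observation.state"].shape == (2, 25)
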
|
||||
|
||||
|
||||
class ManiSkillObservationWrapper(gym.ObservationWrapper):
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
|
||||
def observation(self, observation):
|
||||
return preprocess_maniskill_observation(observation)
|
||||
|
||||
|
||||
class ManiSkillToDeviceWrapper(gym.Wrapper):
|
||||
def __init__(self, env, device: torch.device = "cuda"):
|
||||
super().__init__(env)
|
||||
self.device = device
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
obs, info = self.env.reset(seed=seed, options=options)
|
||||
obs = {k: v.to(self.device) for k, v in obs.items()}
|
||||
return obs, info
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
obs = {k: v.to(self.device) for k, v in obs.items()}
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
|
||||
class ManiSkillCompat(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
reward = reward.item()
|
||||
terminated = terminated.item()
|
||||
truncated = truncated.item()
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
|
||||
class ManiSkillActionWrapper(gym.ActionWrapper):
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2)))
|
||||
|
||||
def action(self, action):
|
||||
action, telop = action
|
||||
return action
|
||||
|
||||
|
||||
class ManiSkillMultiplyActionWrapper(gym.Wrapper):
|
||||
def __init__(self, env, multiply_factor: float = 10):
|
||||
super().__init__(env)
|
||||
self.multiply_factor = multiply_factor
|
||||
action_space_agent: gym.spaces.Box = env.action_space[0]
|
||||
action_space_agent.low = action_space_agent.low * multiply_factor
|
||||
action_space_agent.high = action_space_agent.high * multiply_factor
|
||||
self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2)))
|
||||
|
||||
def step(self, action):
|
||||
if isinstance(action, tuple):
|
||||
action, telop = action
|
||||
else:
|
||||
telop = 0
|
||||
action = action / self.multiply_factor
|
||||
obs, reward, terminated, truncated, info = self.env.step((action, telop))
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
|
||||
def make_maniskill(
|
||||
task: str = "PushCube-v1",
|
||||
obs_mode: str = "rgb",
|
||||
control_mode: str = "pd_ee_delta_pose",
|
||||
render_mode: str = "rgb_array",
|
||||
sensor_configs: dict[str, int] | None = None,
|
||||
n_envs: int = 1,
|
||||
device: torch.device = "cuda",
|
||||
) -> gym.Env:
|
||||
"""
|
||||
Factory function to create a ManiSkill environment with standard wrappers.
|
||||
|
||||
Args:
|
||||
task: Name of the ManiSkill task
|
||||
obs_mode: Observation mode (rgb, rgbd, etc)
|
||||
control_mode: Control mode for the robot
|
||||
render_mode: Rendering mode
|
||||
sensor_configs: Camera sensor configurations
|
||||
n_envs: Number of parallel environments
|
||||
|
||||
Returns:
|
||||
A wrapped ManiSkill environment
|
||||
"""
|
||||
if sensor_configs is None:
|
||||
sensor_configs = {"width": 64, "height": 64}
|
||||
|
||||
env = gym.make(
|
||||
task,
|
||||
obs_mode=obs_mode,
|
||||
control_mode=control_mode,
|
||||
render_mode=render_mode,
|
||||
sensor_configs=sensor_configs,
|
||||
num_envs=n_envs,
|
||||
)
|
||||
env = ManiSkillCompat(env)
|
||||
env = ManiSkillObservationWrapper(env)
|
||||
env = ManiSkillActionWrapper(env)
|
||||
env = ManiSkillMultiplyActionWrapper(env)
|
||||
env = ManiSkillToDeviceWrapper(env, device=device)
|
||||
return env
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import hydra
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize config
|
||||
with hydra.initialize(version_base=None, config_path="../../configs"):
|
||||
cfg = hydra.compose(config_name="env/maniskill_example.yaml")
|
||||
|
||||
env = make_maniskill(
|
||||
task=cfg.env.task,
|
||||
obs_mode=cfg.env.obs,
|
||||
control_mode=cfg.env.control_mode,
|
||||
render_mode=cfg.env.render_mode,
|
||||
sensor_configs={"width": cfg.env.render_size, "height": cfg.env.render_size},
|
||||
)
|
||||
|
||||
print("env done")
|
||||
obs, info = env.reset()
|
||||
random_action = env.action_space.sample()
|
||||
obs, reward, terminated, truncated, info = env.step(random_action)
|
||||
@@ -1,3 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -18,7 +20,6 @@ from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import wandb
|
||||
@@ -26,9 +27,8 @@ from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch import optim
|
||||
from torch.autograd import profiler
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split
|
||||
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
||||
@@ -43,7 +43,6 @@ from lerobot.common.utils.utils import (
|
||||
init_hydra_config,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.scripts.server.buffer import random_shift
|
||||
|
||||
|
||||
def get_model(cfg, logger): # noqa I001
|
||||
@@ -81,7 +80,6 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
    for batch_idx, batch in enumerate(pbar):
        start_time = time.perf_counter()
        images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
        images = [random_shift(img, 4) for img in images]
        labels = batch[cfg.training.label_key].float().to(device)

        # Forward pass with optional AMP
@@ -118,7 +116,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
        pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{current_acc:.2f}%"})


def validate(model, val_loader, criterion, device, logger, cfg):
def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_log=8):
    # Validation loop with metric tracking and sample logging
    model.eval()
    correct = 0
@@ -126,7 +124,6 @@ def validate(model, val_loader, criterion, device, logger, cfg):
    batch_start_time = time.perf_counter()
    samples = []
    running_loss = 0
    inference_times = []

    with (
        torch.no_grad(),
@@ -136,18 +133,7 @@ def validate(model, val_loader, criterion, device, logger, cfg):
            images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
            labels = batch[cfg.training.label_key].float().to(device)

            if cfg.training.profile_inference_time and logger._cfg.wandb.enable:
                with (
                    profiler.profile(record_shapes=True) as prof,
                    profiler.record_function("model_inference"),
                ):
                    outputs = model(images)
                inference_times.append(
                    next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
                )
            else:
                outputs = model(images)

            outputs = model(images)
            loss = criterion(outputs.logits, labels)

            # Track metrics
@@ -160,18 +146,15 @@ def validate(model, val_loader, criterion, device, logger, cfg):
            running_loss += loss.item()

            # Log sample predictions for visualization
            if len(samples) < cfg.eval.num_samples_to_log:
                for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
            if len(samples) < num_samples_to_log:
                for i in range(min(num_samples_to_log - len(samples), len(images))):
                    if model.config.num_classes == 2:
                        confidence = round(outputs.probabilities[i].item(), 3)
                    else:
                        confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
                    samples.append(
                        {
                            **{
                                f"image_{img_key}": wandb.Image(images[img_idx][i].cpu())
                                for img_idx, img_key in enumerate(cfg.training.image_keys)
                            },
                            "image": wandb.Image(images[i].cpu()),
                            "true_label": labels[i].item(),
                            "predicted": predictions[i].item(),
                            "confidence": confidence,
@@ -187,83 +170,16 @@ def validate(model, val_loader, criterion, device, logger, cfg):
        "accuracy": accuracy,
        "eval_s": time.perf_counter() - batch_start_time,
        "eval/prediction_samples": wandb.Table(
            data=[list(s.values()) for s in samples],
            columns=list(samples[0].keys()),
            data=[[s["image"], s["true_label"], s["predicted"], f"{s['confidence']}"] for s in samples],
            columns=["Image", "True Label", "Predicted", "Confidence"],
        )
        if logger._cfg.wandb.enable
        else None,
    }

    if len(inference_times) > 0:
        eval_info["inference_time_avg"] = np.mean(inference_times)
        eval_info["inference_time_median"] = np.median(inference_times)
        eval_info["inference_time_std"] = np.std(inference_times)
        eval_info["inference_time_batch_size"] = val_loader.batch_size

        print(
            f"Inference mean time: {eval_info['inference_time_avg']:.2f} us, median: {eval_info['inference_time_median']:.2f} us, std: {eval_info['inference_time_std']:.2f} us, with {len(inference_times)} iterations on {device.type} device, batch size: {eval_info['inference_time_batch_size']}"
        )

    return accuracy, eval_info

def benchmark_inference_time(model, dataset, logger, cfg, device, step):
    if not cfg.training.profile_inference_time:
        return

    iters = cfg.training.profile_inference_time_iters
    inference_times = []

    loader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=cfg.training.num_workers,
        sampler=RandomSampler(dataset),
        pin_memory=True,
    )

    model.eval()
    with torch.no_grad():
        for _ in tqdm(range(iters), desc="Benchmarking inference time"):
            x = next(iter(loader))
            x = [x[img_key].to(device) for img_key in cfg.training.image_keys]

            # Warm up
            for _ in range(10):
                _ = model(x)

            # sync the device
            if device.type == "cuda":
                torch.cuda.synchronize()
            elif device.type == "mps":
                torch.mps.synchronize()

            with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
                _ = model(x)

            inference_times.append(
                next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
            )

    inference_times = np.array(inference_times)
    avg, median, std = inference_times.mean(), np.median(inference_times), inference_times.std()
    print(
        f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device"
    )
    if logger._cfg.wandb.enable:
        logger.log_dict(
            {
                "inference_time_benchmark_avg": avg,
                "inference_time_benchmark_median": median,
                "inference_time_benchmark_std": std,
            },
            step + 1,
            mode="eval",
        )

    return avg, median, std


@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
def train(cfg: DictConfig) -> None:
    # Main training pipeline with support for resuming training
@@ -273,19 +189,17 @@ def train(cfg: DictConfig) -> None:
    device = get_safe_torch_device(cfg.device, log=True)
    set_global_seed(cfg.seed)

    out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "frozen_resnet10_2"
    out_dir = Path(cfg.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)

    # Setup dataset and dataloaders
    dataset = LeRobotDataset(
        cfg.dataset_repo_id, root=cfg.dataset_root, local_files_only=cfg.local_files_only
    )
    dataset = LeRobotDataset(cfg.dataset_repo_id)
    logging.info(f"Dataset size: {len(dataset)}")

    n_total = len(dataset)
    n_train = int(cfg.train_split_proportion * len(dataset))
    train_dataset = torch.utils.data.Subset(dataset, range(0, n_train))
    val_dataset = torch.utils.data.Subset(dataset, range(n_train, n_total))
    train_size = int(cfg.train_split_proportion * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    sampler = create_balanced_sampler(train_dataset, cfg)
    train_loader = DataLoader(
@@ -293,7 +207,7 @@ def train(cfg: DictConfig) -> None:
        batch_size=cfg.training.batch_size,
        num_workers=cfg.training.num_workers,
        sampler=sampler,
        pin_memory=device.type == "cuda",
        pin_memory=True,
    )

    val_loader = DataLoader(
@@ -301,7 +215,7 @@ def train(cfg: DictConfig) -> None:
        batch_size=cfg.eval.batch_size,
        shuffle=False,
        num_workers=cfg.training.num_workers,
        pin_memory=device.type == "cuda",
        pin_memory=True,
    )

    # Resume training if requested
@@ -399,8 +313,6 @@ def train(cfg: DictConfig) -> None:

        step += len(train_loader)

    benchmark_inference_time(model, dataset, logger, cfg, device, step)

    logging.info("Training completed")
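A side note on the split change above: random_split draws a shuffled permutation from torch's global RNG, whereas the previous contiguous Subset split did not depend on RNG state at all. Because set_global_seed(cfg.seed) runs earlier, the split is still deterministic for a given seed, but it will shift if other code consumes the global RNG before the split. A minimal sketch that pins the split to its own generator; reusing cfg.seed here is an assumption, not something this diff does:

# Sketch: give the train/val split its own generator so it no longer depends
# on how much of the global RNG state earlier setup code consumed.
split_generator = torch.Generator().manual_seed(cfg.seed)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)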
@@ -13,25 +13,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
import random
import functools
from pprint import pformat
from typing import Callable, Optional, Sequence, TypedDict
import random
from typing import Optional, Sequence, TypedDict, Callable

import hydra
import torch
import torch.nn.functional as F
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from torch import nn
from tqdm import tqdm
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# TODO: Remove the import of maniskill
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.envs.factory import make_env, make_maniskill_env
from lerobot.common.envs.utils import preprocess_maniskill_observation, preprocess_observation
from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
@@ -176,7 +177,6 @@ class ReplayBuffer:
        )
        self.position: int = (self.position + 1) % self.capacity

    # TODO: ADD image_augmentation and use_drq arguments in this function in order to instantiate the class with them
    @classmethod
    def from_lerobot_dataset(
        cls,