Compare commits
28 Commits
user/miche
...
user/miche
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
36714a14a7 | ||
|
|
68b8e274dd | ||
|
|
1a7b4ec890 | ||
|
|
1c9eccd279 | ||
|
|
7551260104 | ||
|
|
95758cb867 | ||
|
|
2ecc34ceb9 | ||
|
|
8598e80718 | ||
|
|
6fa3e5f9ad | ||
|
|
b7bd13570f | ||
|
|
f899edb57f | ||
|
|
17ec837a7a | ||
|
|
9e3c8461ca | ||
|
|
1f23ef7889 | ||
|
|
41219fe81e | ||
|
|
5081c145dc | ||
|
|
25b88f3b86 | ||
|
|
d711e20b5f | ||
|
|
700f00c014 | ||
|
|
584cad808e | ||
|
|
d8a1758122 | ||
|
|
1df9ee4f2d | ||
|
|
5b4a7aa81d | ||
|
|
ef8d943e54 | ||
|
|
42a038173f | ||
|
|
546719137a | ||
|
|
3ffe0cf0f4 | ||
|
|
ff82367c62 |
@@ -17,6 +17,7 @@ repos:
|
||||
rev: v3.19.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.8.2
|
||||
hooks:
|
||||
|
||||
@@ -32,7 +32,11 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import PIL
|
||||
import torch
|
||||
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
|
||||
from skimage.metrics import (
|
||||
mean_squared_error,
|
||||
peak_signal_noise_ratio,
|
||||
structural_similarity,
|
||||
)
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
@@ -81,7 +85,9 @@ def get_directory_size(directory: Path) -> int:
|
||||
return total_size
|
||||
|
||||
|
||||
def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
|
||||
def load_original_frames(
|
||||
imgs_dir: Path, timestamps: list[float], fps: int
|
||||
) -> torch.Tensor:
|
||||
frames = []
|
||||
for ts in timestamps:
|
||||
idx = int(ts * fps)
|
||||
@@ -94,7 +100,11 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t
|
||||
|
||||
|
||||
def save_decoded_frames(
|
||||
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
|
||||
imgs_dir: Path,
|
||||
save_dir: Path,
|
||||
frames: torch.Tensor,
|
||||
timestamps: list[float],
|
||||
fps: int,
|
||||
) -> None:
|
||||
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
|
||||
return
|
||||
@@ -104,7 +114,10 @@ def save_decoded_frames(
|
||||
idx = int(ts * fps)
|
||||
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
|
||||
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
|
||||
shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
|
||||
shutil.copyfile(
|
||||
imgs_dir / f"frame_{idx:06d}.png",
|
||||
save_dir / f"frame_{idx:06d}_original.png",
|
||||
)
|
||||
|
||||
|
||||
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
@@ -116,11 +129,17 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
|
||||
# We only save images from the first camera
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")]
|
||||
img_keys = [
|
||||
key for key in hf_dataset.features if key.startswith("observation.image")
|
||||
]
|
||||
imgs_dataset = hf_dataset.select_columns(img_keys[0])
|
||||
|
||||
for i, item in enumerate(
|
||||
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
|
||||
tqdm(
|
||||
imgs_dataset,
|
||||
desc=f"saving {dataset.repo_id} first episode images",
|
||||
leave=False,
|
||||
)
|
||||
):
|
||||
img = item[img_keys[0]]
|
||||
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
|
||||
@@ -129,7 +148,9 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
break
|
||||
|
||||
|
||||
def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
|
||||
def sample_timestamps(
|
||||
timestamps_mode: str, ep_num_images: int, fps: int
|
||||
) -> list[float]:
|
||||
# Start at 5 to allow for 2_frames_4_space and 6_frames
|
||||
idx = random.randint(5, ep_num_images - 1)
|
||||
match timestamps_mode:
|
||||
@@ -154,7 +175,9 @@ def decode_video_frames(
|
||||
backend: str,
|
||||
) -> torch.Tensor:
|
||||
if backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
return decode_video_frames_torchvision(
|
||||
video_path, timestamps, tolerance_s, backend
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(backend)
|
||||
|
||||
@@ -181,7 +204,9 @@ def benchmark_decoding(
|
||||
}
|
||||
|
||||
with time_benchmark:
|
||||
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
|
||||
frames = decode_video_frames(
|
||||
video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend
|
||||
)
|
||||
result["load_time_video_ms"] = time_benchmark.result_ms / num_frames
|
||||
|
||||
with time_benchmark:
|
||||
@@ -190,12 +215,18 @@ def benchmark_decoding(
|
||||
|
||||
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
|
||||
for i in range(num_frames):
|
||||
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
|
||||
result["mse_values"].append(
|
||||
mean_squared_error(original_frames_np[i], frames_np[i])
|
||||
)
|
||||
result["psnr_values"].append(
|
||||
peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
|
||||
peak_signal_noise_ratio(
|
||||
original_frames_np[i], frames_np[i], data_range=1.0
|
||||
)
|
||||
)
|
||||
result["ssim_values"].append(
|
||||
structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
|
||||
structural_similarity(
|
||||
original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0
|
||||
)
|
||||
)
|
||||
|
||||
if save_frames and sample == 0:
|
||||
@@ -215,7 +246,9 @@ def benchmark_decoding(
|
||||
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(process_sample, i) for i in range(num_samples)]
|
||||
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
|
||||
for future in tqdm(
|
||||
as_completed(futures), total=num_samples, desc="samples", leave=False
|
||||
):
|
||||
result = future.result()
|
||||
load_times_video_ms.append(result["load_time_video_ms"])
|
||||
load_times_images_ms.append(result["load_time_images_ms"])
|
||||
@@ -275,9 +308,13 @@ def benchmark_encoding_decoding(
|
||||
random.seed(seed)
|
||||
benchmark_table = []
|
||||
for timestamps_mode in tqdm(
|
||||
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
|
||||
decoding_cfg["timestamps_modes"],
|
||||
desc="decodings (timestamps_modes)",
|
||||
leave=False,
|
||||
):
|
||||
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
|
||||
for backend in tqdm(
|
||||
decoding_cfg["backends"], desc="decodings (backends)", leave=False
|
||||
):
|
||||
benchmark_row = benchmark_decoding(
|
||||
imgs_dir,
|
||||
video_path,
|
||||
@@ -355,14 +392,23 @@ def main(
|
||||
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
|
||||
# We only use the first episode
|
||||
save_first_episode(imgs_dir, dataset)
|
||||
for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False):
|
||||
for key, values in tqdm(
|
||||
encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False
|
||||
):
|
||||
for value in tqdm(values, desc=f"encodings ({key})", leave=False):
|
||||
encoding_cfg = BASE_ENCODING.copy()
|
||||
encoding_cfg["vcodec"] = video_codec
|
||||
encoding_cfg["pix_fmt"] = pixel_format
|
||||
encoding_cfg[key] = value
|
||||
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
|
||||
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
|
||||
args_path = Path(
|
||||
"_".join(str(value) for value in encoding_cfg.values())
|
||||
)
|
||||
video_path = (
|
||||
output_dir
|
||||
/ "videos"
|
||||
/ args_path
|
||||
/ f"{repo_id.replace('/', '_')}.mp4"
|
||||
)
|
||||
benchmark_table += benchmark_encoding_decoding(
|
||||
dataset,
|
||||
video_path,
|
||||
@@ -388,7 +434,9 @@ def main(
|
||||
# Concatenate all results
|
||||
df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
|
||||
concatenated_df = pd.concat(df_list, ignore_index=True)
|
||||
concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
|
||||
concatenated_path = (
|
||||
output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
|
||||
)
|
||||
concatenated_df.to_csv(concatenated_path, header=True, index=False)
|
||||
|
||||
|
||||
|
||||
11
docker/lerobot-gpu-mani-skill/Dockerfile
Normal file
11
docker/lerobot-gpu-mani-skill/Dockerfile
Normal file
@@ -0,0 +1,11 @@
|
||||
FROM huggingface/lerobot-gpu:latest
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libvulkan1 vulkan-tools \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[mani-skill]"
|
||||
|
||||
# Set EGL as the rendering backend for MuJoCo
|
||||
ENV MUJOCO_GL="egl"
|
||||
@@ -16,9 +16,9 @@ On your computer:
|
||||
mkdir -p ~/miniconda3
|
||||
# Linux:
|
||||
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
|
||||
# Mac M-series:
|
||||
# Mac M-series:
|
||||
# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
|
||||
# Mac Intel:
|
||||
# Mac Intel:
|
||||
# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o ~/miniconda3/miniconda.sh
|
||||
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
|
||||
rm ~/miniconda3/miniconda.sh
|
||||
@@ -98,7 +98,7 @@ sudo chmod 666 /dev/ttyACM1
|
||||
|
||||
#### d. Update YAML file
|
||||
|
||||
Now that you have the ports, modify the *port* sections in `so100.yaml`
|
||||
Now that you have the ports, modify the *port* sections in `so100.yaml`
|
||||
|
||||
### 2. Configure the motors
|
||||
|
||||
|
||||
@@ -81,3 +81,14 @@ You can also log sample predictions during evaluation. Each logged sample will i
|
||||
- The **classifier's "confidence" (logits/probability)**.
|
||||
|
||||
These logs can be useful for diagnosing and debugging performance issues.
|
||||
|
||||
|
||||
#### Generate protobuf files
|
||||
|
||||
```bash
|
||||
python -m grpc_tools.protoc \
|
||||
-I lerobot/scripts/server \
|
||||
--python_out=lerobot/scripts/server \
|
||||
--grpc_python_out=lerobot/scripts/server \
|
||||
lerobot/scripts/server/hilserl.proto
|
||||
```
|
||||
|
||||
@@ -18,7 +18,10 @@ import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
)
|
||||
|
||||
# We ported a number of existing datasets ourselves, use this to see the list:
|
||||
print("List of available datasets:")
|
||||
@@ -26,7 +29,10 @@ pprint(lerobot.available_datasets)
|
||||
|
||||
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
|
||||
hub_api = HfApi()
|
||||
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
|
||||
repo_ids = [
|
||||
info.id
|
||||
for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])
|
||||
]
|
||||
pprint(repo_ids)
|
||||
|
||||
# Or simply explore them in your web browser directly at:
|
||||
@@ -41,7 +47,9 @@ ds_meta = LeRobotDatasetMetadata(repo_id)
|
||||
# structure of the dataset without downloading the actual data yet (only metadata files — which are
|
||||
# lightweight).
|
||||
print(f"Total number of episodes: {ds_meta.total_episodes}")
|
||||
print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
|
||||
print(
|
||||
f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}"
|
||||
)
|
||||
print(f"Frames per second used during data collection: {ds_meta.fps}")
|
||||
print(f"Robot type: {ds_meta.robot_type}")
|
||||
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")
|
||||
|
||||
@@ -32,7 +32,9 @@ if torch.cuda.is_available():
|
||||
print("GPU is available. Device set to:", device)
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
print(f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU.")
|
||||
print(
|
||||
f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU."
|
||||
)
|
||||
# Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed)
|
||||
policy.diffusion.num_inference_steps = 10
|
||||
|
||||
|
||||
@@ -31,7 +31,24 @@ delta_timestamps = {
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to supervise the policy.
|
||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
"action": [
|
||||
-0.1,
|
||||
0.0,
|
||||
0.1,
|
||||
0.2,
|
||||
0.3,
|
||||
0.4,
|
||||
0.5,
|
||||
0.6,
|
||||
0.7,
|
||||
0.8,
|
||||
0.9,
|
||||
1.0,
|
||||
1.1,
|
||||
1.2,
|
||||
1.3,
|
||||
1.4,
|
||||
],
|
||||
}
|
||||
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
|
||||
|
||||
|
||||
@@ -34,10 +34,14 @@ transforms = v2.Compose(
|
||||
)
|
||||
|
||||
# Create another LeRobotDataset with the defined transformations
|
||||
transformed_dataset = LeRobotDataset(dataset_repo_id, episodes=[0], image_transforms=transforms)
|
||||
transformed_dataset = LeRobotDataset(
|
||||
dataset_repo_id, episodes=[0], image_transforms=transforms
|
||||
)
|
||||
|
||||
# Get a frame from the transformed dataset
|
||||
transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]]
|
||||
transformed_frame = transformed_dataset[first_idx][
|
||||
transformed_dataset.meta.camera_keys[0]
|
||||
]
|
||||
|
||||
# Create a directory to store output images
|
||||
output_dir = Path("outputs/image_transforms")
|
||||
|
||||
@@ -14,7 +14,10 @@ from pathlib import Path
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
)
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
device = torch.device("cuda")
|
||||
@@ -37,7 +40,24 @@ delta_timestamps = {
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to calculate the loss.
|
||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
"action": [
|
||||
-0.1,
|
||||
0.0,
|
||||
0.1,
|
||||
0.2,
|
||||
0.3,
|
||||
0.4,
|
||||
0.5,
|
||||
0.6,
|
||||
0.7,
|
||||
0.8,
|
||||
0.9,
|
||||
1.0,
|
||||
1.1,
|
||||
1.2,
|
||||
1.3,
|
||||
1.4,
|
||||
],
|
||||
}
|
||||
|
||||
# Load the last 10% of episodes of the dataset as a validation set.
|
||||
@@ -53,8 +73,12 @@ print(f"Number of episodes in full dataset: {total_episodes}")
|
||||
print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}")
|
||||
print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}")
|
||||
# - Load train an val datasets
|
||||
train_dataset = LeRobotDataset("lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps)
|
||||
val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps)
|
||||
train_dataset = LeRobotDataset(
|
||||
"lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps
|
||||
)
|
||||
val_dataset = LeRobotDataset(
|
||||
"lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps
|
||||
)
|
||||
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
|
||||
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
|
||||
|
||||
|
||||
@@ -69,7 +69,9 @@ def load_raw_dataset(zarr_path: Path):
|
||||
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
||||
print(
|
||||
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
|
||||
)
|
||||
raise e
|
||||
|
||||
zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
|
||||
@@ -81,7 +83,9 @@ def calculate_coverage(zarr_data):
|
||||
import pymunk
|
||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
||||
except ModuleNotFoundError as e:
|
||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
||||
print(
|
||||
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
|
||||
)
|
||||
raise e
|
||||
|
||||
block_pos = zarr_data["state"][:, 2:4]
|
||||
@@ -111,7 +115,9 @@ def calculate_coverage(zarr_data):
|
||||
]
|
||||
space.add(*walls)
|
||||
|
||||
block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
block_body, block_shapes = PushTEnv.add_tee(
|
||||
space, block_pos[i].tolist(), block_angle[i].item()
|
||||
)
|
||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
|
||||
@@ -182,7 +182,11 @@ available_real_world_datasets = [
|
||||
]
|
||||
|
||||
available_datasets = sorted(
|
||||
set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets))
|
||||
set(
|
||||
itertools.chain(
|
||||
*available_datasets_per_env.values(), available_real_world_datasets
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# lists all available policies from `lerobot/common/policies`
|
||||
@@ -224,9 +228,13 @@ available_policies_per_env = {
|
||||
"dora_aloha_real": ["act_aloha_real"],
|
||||
}
|
||||
|
||||
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
|
||||
env_task_pairs = [
|
||||
(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks
|
||||
]
|
||||
env_dataset_pairs = [
|
||||
(env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
|
||||
(env, dataset)
|
||||
for env, datasets in available_datasets_per_env.items()
|
||||
for dataset in datasets
|
||||
]
|
||||
env_dataset_policy_triplets = [
|
||||
(env, dataset, policy)
|
||||
|
||||
@@ -45,12 +45,20 @@ def get_stats_einops_patterns(dataset, num_workers=0):
|
||||
if key in dataset.meta.camera_keys:
|
||||
# sanity check that images are channel first
|
||||
_, c, h, w = batch[key].shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
|
||||
assert (
|
||||
c < h and c < w
|
||||
), f"expect channel first images, but instead {batch[key].shape}"
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
|
||||
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
|
||||
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
|
||||
assert (
|
||||
batch[key].dtype == torch.float32
|
||||
), f"expect torch.float32, but instead {batch[key].dtype=}"
|
||||
assert (
|
||||
batch[key].max() <= 1
|
||||
), f"expect pixels lower than 1, but instead {batch[key].max()=}"
|
||||
assert (
|
||||
batch[key].min() >= 0
|
||||
), f"expect pixels greater than 1, but instead {batch[key].min()=}"
|
||||
|
||||
stats_patterns[key] = "b c h w -> c 1 1"
|
||||
elif batch[key].ndim == 2:
|
||||
@@ -98,7 +106,11 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
|
||||
running_item_count = 0 # for online mean computation
|
||||
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
|
||||
tqdm.tqdm(
|
||||
dataloader,
|
||||
total=ceil(max_num_samples / batch_size),
|
||||
desc="Compute mean, min, max",
|
||||
)
|
||||
):
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
@@ -113,9 +125,16 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
|
||||
# and x is the current batch mean. Some rearrangement is then required to avoid risking
|
||||
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
|
||||
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
|
||||
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
mean[key] = (
|
||||
mean[key]
|
||||
+ this_batch_size * (batch_mean - mean[key]) / running_item_count
|
||||
)
|
||||
max[key] = torch.maximum(
|
||||
max[key], einops.reduce(batch[key], pattern, "max")
|
||||
)
|
||||
min[key] = torch.minimum(
|
||||
min[key], einops.reduce(batch[key], pattern, "min")
|
||||
)
|
||||
|
||||
if i == ceil(max_num_samples / batch_size) - 1:
|
||||
break
|
||||
@@ -124,7 +143,9 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
|
||||
running_item_count = 0 # for online std computation
|
||||
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
|
||||
tqdm.tqdm(
|
||||
dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std"
|
||||
)
|
||||
):
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
@@ -138,7 +159,9 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
|
||||
# Numerically stable update step for mean computation (where the mean is over squared
|
||||
# residuals).See notes in the mean computation loop above.
|
||||
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
|
||||
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
|
||||
std[key] = (
|
||||
std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
|
||||
)
|
||||
|
||||
if i == ceil(max_num_samples / batch_size) - 1:
|
||||
break
|
||||
@@ -177,13 +200,19 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
|
||||
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
|
||||
stats[data_key][stat_key] = einops.reduce(
|
||||
torch.stack(
|
||||
[ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats],
|
||||
[
|
||||
ds.meta.stats[data_key][stat_key]
|
||||
for ds in ls_datasets
|
||||
if data_key in ds.meta.stats
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
"n ... -> ...",
|
||||
stat_key,
|
||||
)
|
||||
total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats)
|
||||
total_samples = sum(
|
||||
d.num_frames for d in ls_datasets if data_key in d.meta.stats
|
||||
)
|
||||
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
|
||||
# dataset, then divide by total_samples to get the overall "mean".
|
||||
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
|
||||
|
||||
@@ -89,7 +89,9 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||
"image_std": None,
|
||||
}
|
||||
)
|
||||
cfg_tf = OmegaConf.merge(OmegaConf.create(default_tf), cfg.training.image_transforms)
|
||||
cfg_tf = OmegaConf.merge(
|
||||
OmegaConf.create(default_tf), cfg.training.image_transforms
|
||||
)
|
||||
|
||||
image_transforms = get_image_transforms(
|
||||
brightness_weight=cfg_tf.brightness.weight,
|
||||
@@ -104,7 +106,9 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||
sharpness_min_max=cfg_tf.sharpness.min_max,
|
||||
max_num_transforms=cfg_tf.max_num_transforms,
|
||||
random_order=cfg_tf.random_order,
|
||||
image_size=(cfg_tf.image_size.height, cfg_tf.image_size.width) if cfg_tf.image_size else None,
|
||||
image_size=(cfg_tf.image_size.height, cfg_tf.image_size.width)
|
||||
if cfg_tf.image_size
|
||||
else None,
|
||||
interpolation=cfg_tf.interpolation,
|
||||
image_mean=cfg_tf.image_mean,
|
||||
image_std=cfg_tf.image_std,
|
||||
@@ -131,6 +135,8 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||
for stats_type, listconfig in stats_dict.items():
|
||||
# example of stats_type: min, max, mean, std
|
||||
stats = OmegaConf.to_container(listconfig, resolve=True)
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(
|
||||
stats, dtype=torch.float32
|
||||
)
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -109,7 +109,9 @@ class AsyncImageWriter:
|
||||
self._stopped = False
|
||||
|
||||
if num_threads <= 0 and num_processes <= 0:
|
||||
raise ValueError("Number of threads and processes must be greater than zero.")
|
||||
raise ValueError(
|
||||
"Number of threads and processes must be greater than zero."
|
||||
)
|
||||
|
||||
if self.num_processes == 0:
|
||||
# Use threading
|
||||
@@ -123,12 +125,16 @@ class AsyncImageWriter:
|
||||
# Use multiprocessing
|
||||
self.queue = multiprocessing.JoinableQueue()
|
||||
for _ in range(self.num_processes):
|
||||
p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
|
||||
p = multiprocessing.Process(
|
||||
target=worker_process, args=(self.queue, self.num_threads)
|
||||
)
|
||||
p.daemon = True
|
||||
p.start()
|
||||
self.processes.append(p)
|
||||
|
||||
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
def save_image(
|
||||
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
|
||||
):
|
||||
if isinstance(image, torch.Tensor):
|
||||
# Convert tensor to numpy array to minimize main process time
|
||||
image = image.cpu().numpy()
|
||||
|
||||
@@ -68,7 +68,9 @@ from lerobot.common.robot_devices.robots.utils import Robot
|
||||
|
||||
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
|
||||
CODEBASE_VERSION = "v2.0"
|
||||
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
|
||||
LEROBOT_HOME = Path(
|
||||
os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")
|
||||
).expanduser()
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
@@ -108,7 +110,11 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
@cached_property
|
||||
def _hub_version(self) -> str | None:
|
||||
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
|
||||
return (
|
||||
None
|
||||
if self.local_files_only
|
||||
else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
|
||||
)
|
||||
|
||||
@property
|
||||
def _version(self) -> str:
|
||||
@@ -122,7 +128,9 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
||||
ep_chunk = self.get_episode_chunk(ep_index)
|
||||
fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
||||
fpath = self.video_path.format(
|
||||
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index
|
||||
)
|
||||
return Path(fpath)
|
||||
|
||||
def get_episode_chunk(self, ep_index: int) -> int:
|
||||
@@ -166,7 +174,11 @@ class LeRobotDatasetMetadata:
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||
return [
|
||||
key
|
||||
for key, ft in self.features.items()
|
||||
if ft["dtype"] in ["video", "image"]
|
||||
]
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
@@ -215,7 +227,9 @@ class LeRobotDatasetMetadata:
|
||||
task_index = self.task_to_task_index.get(task, None)
|
||||
return task_index if task_index is not None else self.total_tasks
|
||||
|
||||
def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
|
||||
def save_episode(
|
||||
self, episode_index: int, episode_length: int, task: str, task_index: int
|
||||
) -> None:
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
|
||||
@@ -257,7 +271,9 @@ class LeRobotDatasetMetadata:
|
||||
"""
|
||||
for key in self.video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
|
||||
video_path = self.root / self.get_video_file_path(
|
||||
ep_index=0, vid_key=key
|
||||
)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
|
||||
write_json(self.info, self.root / INFO_PATH)
|
||||
@@ -308,7 +324,9 @@ class LeRobotDatasetMetadata:
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
|
||||
obj.tasks, obj.stats, obj.episodes = {}, {}, []
|
||||
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
|
||||
obj.info = create_empty_dataset_info(
|
||||
CODEBASE_VERSION, fps, robot_type, features, use_videos
|
||||
)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
@@ -444,7 +462,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# Load metadata
|
||||
self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)
|
||||
self.meta = LeRobotDatasetMetadata(
|
||||
self.repo_id, self.root, self.local_files_only
|
||||
)
|
||||
|
||||
# Check version
|
||||
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
|
||||
@@ -452,10 +472,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Load actual data
|
||||
self.download_episodes(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
self.episode_data_index = get_episode_data_index(
|
||||
self.meta.episodes, self.episodes
|
||||
)
|
||||
|
||||
# Check timestamps
|
||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||
check_timestamps_sync(
|
||||
self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s
|
||||
)
|
||||
|
||||
# Setup delta_indices
|
||||
if self.delta_timestamps is not None:
|
||||
@@ -501,7 +525,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
||||
)
|
||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset")
|
||||
create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
|
||||
create_branch(
|
||||
repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset"
|
||||
)
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
@@ -529,7 +555,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
files = None
|
||||
ignore_patterns = None if download_videos else "videos/"
|
||||
if self.episodes is not None:
|
||||
files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
||||
files = [
|
||||
str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes
|
||||
]
|
||||
if len(self.meta.video_keys) > 0 and download_videos:
|
||||
video_files = [
|
||||
str(self.meta.get_video_file_path(ep_idx, vid_key))
|
||||
@@ -547,7 +575,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
path = str(self.root / "data")
|
||||
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
|
||||
else:
|
||||
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
||||
files = [
|
||||
str(self.root / self.meta.get_data_file_path(ep_idx))
|
||||
for ep_idx in self.episodes
|
||||
]
|
||||
hf_dataset = load_dataset("parquet", data_files=files, split="train")
|
||||
|
||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
||||
@@ -563,12 +594,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
"""Number of frames in selected episodes."""
|
||||
return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames
|
||||
return (
|
||||
len(self.hf_dataset)
|
||||
if self.hf_dataset is not None
|
||||
else self.meta.total_frames
|
||||
)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
"""Number of episodes selected."""
|
||||
return len(self.episodes) if self.episodes is not None else self.meta.total_episodes
|
||||
return (
|
||||
len(self.episodes)
|
||||
if self.episodes is not None
|
||||
else self.meta.total_episodes
|
||||
)
|
||||
|
||||
@property
|
||||
def features(self) -> dict[str, dict]:
|
||||
@@ -582,16 +621,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
return get_hf_features_from_features(self.features)
|
||||
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
def _get_query_indices(
|
||||
self, idx: int, ep_idx: int
|
||||
) -> tuple[dict[str, list[int | bool]]]:
|
||||
ep_start = self.episode_data_index["from"][ep_idx]
|
||||
ep_end = self.episode_data_index["to"][ep_idx]
|
||||
query_indices = {
|
||||
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
|
||||
key: [
|
||||
max(ep_start.item(), min(ep_end.item() - 1, idx + delta))
|
||||
for delta in delta_idx
|
||||
]
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
padding = { # Pad values outside of current episode range
|
||||
f"{key}_is_pad": torch.BoolTensor(
|
||||
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
|
||||
[
|
||||
(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item())
|
||||
for delta in delta_idx
|
||||
]
|
||||
)
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
@@ -619,7 +666,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if key not in self.meta.video_keys
|
||||
}
|
||||
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
|
||||
def _query_videos(
|
||||
self, query_timestamps: dict[str, list[float]], ep_idx: int
|
||||
) -> dict:
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
|
||||
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
|
||||
@@ -649,7 +698,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
query_indices = None
|
||||
if self.delta_indices is not None:
|
||||
current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
|
||||
current_ep_idx = (
|
||||
self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
|
||||
)
|
||||
query_indices, padding = self._get_query_indices(idx, current_ep_idx)
|
||||
query_result = self._query_hf_dataset(query_indices)
|
||||
item = {**item, **padding}
|
||||
@@ -681,19 +732,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
|
||||
current_ep_idx = (
|
||||
self.meta.total_episodes if episode_index is None else episode_index
|
||||
)
|
||||
return {
|
||||
"size": 0,
|
||||
**{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
|
||||
**{
|
||||
key: current_ep_idx if key == "episode_index" else []
|
||||
for key in self.features
|
||||
},
|
||||
}
|
||||
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
def _get_image_file_path(
|
||||
self, episode_index: int, image_key: str, frame_index: int
|
||||
) -> Path:
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self.root / fpath
|
||||
|
||||
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
|
||||
def _save_image(
|
||||
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
|
||||
) -> None:
|
||||
if self.image_writer is None:
|
||||
if isinstance(image, torch.Tensor):
|
||||
image = image.cpu().numpy()
|
||||
@@ -714,7 +774,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
||||
frame_index = self.episode_buffer["size"]
|
||||
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
|
||||
timestamp = (
|
||||
frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
|
||||
)
|
||||
self.episode_buffer["frame_index"].append(frame_index)
|
||||
self.episode_buffer["timestamp"].append(timestamp)
|
||||
|
||||
@@ -723,11 +785,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
raise ValueError(key)
|
||||
|
||||
if self.features[key]["dtype"] not in ["image", "video"]:
|
||||
item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
|
||||
item = (
|
||||
frame[key].numpy()
|
||||
if isinstance(frame[key], torch.Tensor)
|
||||
else frame[key]
|
||||
)
|
||||
self.episode_buffer[key].append(item)
|
||||
elif self.features[key]["dtype"] in ["image", "video"]:
|
||||
img_path = self._get_image_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
|
||||
episode_index=self.episode_buffer["episode_index"],
|
||||
image_key=key,
|
||||
frame_index=frame_index,
|
||||
)
|
||||
if frame_index == 0:
|
||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -736,7 +804,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None:
|
||||
def save_episode(
|
||||
self, task: str, encode_videos: bool = True, episode_data: dict | None = None
|
||||
) -> None:
|
||||
"""
|
||||
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
|
||||
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
|
||||
@@ -803,7 +873,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
|
||||
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
|
||||
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
|
||||
ep_dataset = datasets.Dataset.from_dict(
|
||||
episode_dict, features=self.hf_features, split="train"
|
||||
)
|
||||
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
|
||||
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
write_parquet(ep_dataset, ep_data_path)
|
||||
@@ -875,10 +947,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
return video_paths
|
||||
|
||||
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
|
||||
def consolidate(
|
||||
self, run_compute_stats: bool = True, keep_image_files: bool = False
|
||||
) -> None:
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||
self.episode_data_index = get_episode_data_index(
|
||||
self.meta.episodes, self.episodes
|
||||
)
|
||||
check_timestamps_sync(
|
||||
self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s
|
||||
)
|
||||
|
||||
if len(self.meta.video_keys) > 0:
|
||||
self.encode_videos()
|
||||
@@ -983,7 +1061,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
self.root = Path(root) if root else LEROBOT_HOME
|
||||
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
|
||||
self.tolerances_s = (
|
||||
tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
|
||||
)
|
||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
||||
# are handled by this class.
|
||||
self._datasets = [
|
||||
@@ -1060,7 +1140,13 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
def features(self) -> datasets.Features:
|
||||
features = {}
|
||||
for dataset in self._datasets:
|
||||
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
|
||||
features.update(
|
||||
{
|
||||
k: v
|
||||
for k, v in dataset.hf_features.items()
|
||||
if k not in self.disabled_features
|
||||
}
|
||||
)
|
||||
return features
|
||||
|
||||
@property
|
||||
@@ -1121,7 +1207,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
|
||||
raise AssertionError(
|
||||
"We expect the loop to break out as long as the index is within bounds."
|
||||
)
|
||||
item = self._datasets[dataset_idx][idx - start_idx]
|
||||
item["dataset_index"] = torch.tensor(dataset_idx)
|
||||
for data_key in self.disabled_features:
|
||||
|
||||
@@ -131,7 +131,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
else:
|
||||
self._delta_timestamps = None
|
||||
|
||||
def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
|
||||
def _make_data_spec(
|
||||
self, data_spec: dict[str, Any], buffer_capacity: int
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Makes the data spec for np.memmap."""
|
||||
if any(k.startswith("_") for k in data_spec):
|
||||
raise ValueError(
|
||||
@@ -154,14 +156,32 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
|
||||
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
|
||||
# with real data rather than the dummy initialization.
|
||||
OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.OCCUPANCY_MASK_KEY: {
|
||||
"dtype": np.dtype("?"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.INDEX_KEY: {
|
||||
"dtype": np.dtype("int64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.FRAME_INDEX_KEY: {
|
||||
"dtype": np.dtype("int64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: {
|
||||
"dtype": np.dtype("int64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.TIMESTAMP_KEY: {
|
||||
"dtype": np.dtype("float64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
}
|
||||
for k, v in data_spec.items():
|
||||
complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
|
||||
complete_data_spec[k] = {
|
||||
"dtype": v["dtype"],
|
||||
"shape": (buffer_capacity, *v["shape"]),
|
||||
}
|
||||
return complete_data_spec
|
||||
|
||||
def add_data(self, data: dict[str, np.ndarray]):
|
||||
@@ -188,7 +208,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
|
||||
# Shift the incoming indices if necessary.
|
||||
if self.num_frames > 0:
|
||||
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
|
||||
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][
|
||||
next_index - 1
|
||||
]
|
||||
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
|
||||
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
|
||||
data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
|
||||
@@ -223,7 +245,11 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(
|
||||
np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
|
||||
np.unique(
|
||||
self._data[OnlineBuffer.EPISODE_INDEX_KEY][
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -261,7 +287,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
|
||||
)
|
||||
)[0]
|
||||
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
|
||||
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][
|
||||
episode_data_indices
|
||||
]
|
||||
|
||||
for data_key in self.delta_timestamps:
|
||||
# Note: The logic in this loop is copied from `load_previous_and_future_frames`.
|
||||
@@ -278,7 +306,8 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
|
||||
# Check violated query timestamps are all outside the episode range.
|
||||
assert (
|
||||
(query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
|
||||
(query_ts[is_pad] < episode_timestamps[0])
|
||||
| (episode_timestamps[-1] < query_ts[is_pad])
|
||||
).all(), (
|
||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
|
||||
") inside the episode range."
|
||||
@@ -293,7 +322,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
|
||||
def get_data_by_key(self, key: str) -> torch.Tensor:
|
||||
"""Returns all data for a given data key as a Tensor."""
|
||||
return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
|
||||
return torch.from_numpy(
|
||||
self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]
|
||||
)
|
||||
|
||||
|
||||
def compute_sampler_weights(
|
||||
@@ -324,13 +355,19 @@ def compute_sampler_weights(
|
||||
- Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
|
||||
included here to avoid adding complexity.
|
||||
"""
|
||||
if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
|
||||
raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
|
||||
if len(offline_dataset) == 0 and (
|
||||
online_dataset is None or len(online_dataset) == 0
|
||||
):
|
||||
raise ValueError(
|
||||
"At least one of `offline_dataset` or `online_dataset` should be contain data."
|
||||
)
|
||||
if (online_dataset is None) ^ (online_sampling_ratio is None):
|
||||
raise ValueError(
|
||||
"`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
|
||||
)
|
||||
offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
|
||||
offline_sampling_ratio = (
|
||||
0 if online_sampling_ratio is None else 1 - online_sampling_ratio
|
||||
)
|
||||
|
||||
weights = []
|
||||
|
||||
|
||||
@@ -37,10 +37,16 @@ def check_chunks_compatible(chunks: tuple, shape: tuple):
|
||||
assert c > 0
|
||||
|
||||
|
||||
def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"):
|
||||
def rechunk_recompress_array(
|
||||
group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"
|
||||
):
|
||||
old_arr = group[name]
|
||||
if chunks is None:
|
||||
chunks = (chunk_length,) + old_arr.chunks[1:] if chunk_length is not None else old_arr.chunks
|
||||
chunks = (
|
||||
(chunk_length,) + old_arr.chunks[1:]
|
||||
if chunk_length is not None
|
||||
else old_arr.chunks
|
||||
)
|
||||
check_chunks_compatible(chunks, old_arr.shape)
|
||||
|
||||
if compressor is None:
|
||||
@@ -82,13 +88,18 @@ def get_optimal_chunks(shape, dtype, target_chunk_bytes=2e6, max_chunk_length=No
|
||||
for i in range(len(shape) - 1):
|
||||
this_chunk_bytes = itemsize * np.prod(rshape[:i])
|
||||
next_chunk_bytes = itemsize * np.prod(rshape[: i + 1])
|
||||
if this_chunk_bytes <= target_chunk_bytes and next_chunk_bytes > target_chunk_bytes:
|
||||
if (
|
||||
this_chunk_bytes <= target_chunk_bytes
|
||||
and next_chunk_bytes > target_chunk_bytes
|
||||
):
|
||||
split_idx = i
|
||||
|
||||
rchunks = rshape[:split_idx]
|
||||
item_chunk_bytes = itemsize * np.prod(rshape[:split_idx])
|
||||
this_max_chunk_length = rshape[split_idx]
|
||||
next_chunk_length = min(this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes))
|
||||
next_chunk_length = min(
|
||||
this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes)
|
||||
)
|
||||
rchunks.append(next_chunk_length)
|
||||
len_diff = len(shape) - len(rchunks)
|
||||
rchunks.extend([1] * len_diff)
|
||||
@@ -124,7 +135,13 @@ class ReplayBuffer:
|
||||
root.require_group("data", overwrite=False)
|
||||
meta = root.require_group("meta", overwrite=False)
|
||||
if "episode_ends" not in meta:
|
||||
meta.zeros("episode_ends", shape=(0,), dtype=np.int64, compressor=None, overwrite=False)
|
||||
meta.zeros(
|
||||
"episode_ends",
|
||||
shape=(0,),
|
||||
dtype=np.int64,
|
||||
compressor=None,
|
||||
overwrite=False,
|
||||
)
|
||||
return cls(root=root)
|
||||
|
||||
@classmethod
|
||||
@@ -193,7 +210,11 @@ class ReplayBuffer:
|
||||
root = zarr.group(store=store)
|
||||
# copy without recompression
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
||||
source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
|
||||
source=src_store,
|
||||
dest=store,
|
||||
source_path="/meta",
|
||||
dest_path="/meta",
|
||||
if_exists=if_exists,
|
||||
)
|
||||
data_group = root.create_group("data", overwrite=True)
|
||||
if keys is None:
|
||||
@@ -201,7 +222,9 @@ class ReplayBuffer:
|
||||
for key in keys:
|
||||
value = src_root["data"][key]
|
||||
cks = cls._resolve_array_chunks(chunks=chunks, key=key, array=value)
|
||||
cpr = cls._resolve_array_compressor(compressors=compressors, key=key, array=value)
|
||||
cpr = cls._resolve_array_compressor(
|
||||
compressors=compressors, key=key, array=value
|
||||
)
|
||||
if cks == value.chunks and cpr == value.compressor:
|
||||
# copy without recompression
|
||||
this_path = "/data/" + key
|
||||
@@ -286,13 +309,17 @@ class ReplayBuffer:
|
||||
meta_group = root.create_group("meta", overwrite=True)
|
||||
# save meta, no chunking
|
||||
for key, value in self.root["meta"].items():
|
||||
_ = meta_group.array(name=key, data=value, shape=value.shape, chunks=value.shape)
|
||||
_ = meta_group.array(
|
||||
name=key, data=value, shape=value.shape, chunks=value.shape
|
||||
)
|
||||
|
||||
# save data, chunk
|
||||
data_group = root.create_group("data", overwrite=True)
|
||||
for key, value in self.root["data"].items():
|
||||
cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
|
||||
cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
|
||||
cpr = self._resolve_array_compressor(
|
||||
compressors=compressors, key=key, array=value
|
||||
)
|
||||
if isinstance(value, zarr.Array):
|
||||
if cks == value.chunks and cpr == value.compressor:
|
||||
# copy without recompression
|
||||
@@ -339,13 +366,19 @@ class ReplayBuffer:
|
||||
@staticmethod
|
||||
def resolve_compressor(compressor="default"):
|
||||
if compressor == "default":
|
||||
compressor = numcodecs.Blosc(cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE)
|
||||
compressor = numcodecs.Blosc(
|
||||
cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE
|
||||
)
|
||||
elif compressor == "disk":
|
||||
compressor = numcodecs.Blosc("zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE)
|
||||
compressor = numcodecs.Blosc(
|
||||
"zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE
|
||||
)
|
||||
return compressor
|
||||
|
||||
@classmethod
|
||||
def _resolve_array_compressor(cls, compressors: dict | str | numcodecs.abc.Codec, key, array):
|
||||
def _resolve_array_compressor(
|
||||
cls, compressors: dict | str | numcodecs.abc.Codec, key, array
|
||||
):
|
||||
# allows compressor to be explicitly set to None
|
||||
cpr = "nil"
|
||||
if isinstance(compressors, dict):
|
||||
@@ -404,7 +437,11 @@ class ReplayBuffer:
|
||||
if self.backend == "zarr":
|
||||
for key, value in np_data.items():
|
||||
_ = meta_group.array(
|
||||
name=key, data=value, shape=value.shape, chunks=value.shape, overwrite=True
|
||||
name=key,
|
||||
data=value,
|
||||
shape=value.shape,
|
||||
chunks=value.shape,
|
||||
overwrite=True,
|
||||
)
|
||||
else:
|
||||
meta_group.update(np_data)
|
||||
@@ -514,10 +551,18 @@ class ReplayBuffer:
|
||||
# create array
|
||||
if key not in self.data:
|
||||
if is_zarr:
|
||||
cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
|
||||
cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
|
||||
cks = self._resolve_array_chunks(
|
||||
chunks=chunks, key=key, array=value
|
||||
)
|
||||
cpr = self._resolve_array_compressor(
|
||||
compressors=compressors, key=key, array=value
|
||||
)
|
||||
arr = self.data.zeros(
|
||||
name=key, shape=new_shape, chunks=cks, dtype=value.dtype, compressor=cpr
|
||||
name=key,
|
||||
shape=new_shape,
|
||||
chunks=cks,
|
||||
dtype=value.dtype,
|
||||
compressor=cpr,
|
||||
)
|
||||
else:
|
||||
# copy data to prevent modify
|
||||
@@ -544,7 +589,9 @@ class ReplayBuffer:
|
||||
|
||||
# rechunk
|
||||
if is_zarr and episode_ends.chunks[0] < episode_ends.shape[0]:
|
||||
rechunk_recompress_array(self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5))
|
||||
rechunk_recompress_array(
|
||||
self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5)
|
||||
)
|
||||
|
||||
def drop_episode(self):
|
||||
is_zarr = self.backend == "zarr"
|
||||
|
||||
@@ -38,7 +38,9 @@ import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import AVAILABLE_RAW_REPO_IDS
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import (
|
||||
AVAILABLE_RAW_REPO_IDS,
|
||||
)
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
||||
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
|
||||
|
||||
@@ -73,7 +75,9 @@ def encode_datasets(
|
||||
check_repo_id(raw_repo_id)
|
||||
dataset_repo_id_push = get_push_repo_id_from_raw(raw_repo_id, push_repo)
|
||||
dataset_raw_dir = raw_dir / raw_repo_id
|
||||
dataset_dir = local_dir / dataset_repo_id_push if local_dir is not None else None
|
||||
dataset_dir = (
|
||||
local_dir / dataset_repo_id_push if local_dir is not None else None
|
||||
)
|
||||
encoding = {
|
||||
"vcodec": vcodec,
|
||||
"pix_fmt": pix_fmt,
|
||||
|
||||
@@ -133,7 +133,9 @@ class Jpeg2k(Codec):
|
||||
)
|
||||
|
||||
def decode(self, buf, out=None):
|
||||
return imagecodecs.jpeg2k_decode(buf, verbose=self.verbose, numthreads=self.numthreads, out=out)
|
||||
return imagecodecs.jpeg2k_decode(
|
||||
buf, verbose=self.verbose, numthreads=self.numthreads, out=out
|
||||
)
|
||||
|
||||
|
||||
class JpegXl(Codec):
|
||||
|
||||
@@ -44,7 +44,9 @@ from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
def get_cameras(hdf5_data):
|
||||
# ignore depth channel, not currently handled
|
||||
# TODO(rcadene): add depth
|
||||
rgb_cameras = [key for key in hdf5_data["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
|
||||
rgb_cameras = [
|
||||
key for key in hdf5_data["/observations/images"].keys() if "depth" not in key
|
||||
] # noqa: SIM118
|
||||
return rgb_cameras
|
||||
|
||||
|
||||
@@ -73,7 +75,9 @@ def check_format(raw_dir) -> bool:
|
||||
else:
|
||||
assert data[f"/observations/images/{camera}"].ndim == 4
|
||||
b, h, w, c = data[f"/observations/images/{camera}"].shape
|
||||
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
||||
assert (
|
||||
c < h and c < w
|
||||
), f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
@@ -134,14 +138,17 @@ def load_from_raw(
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = videos_dir / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
|
||||
encode_video_frames(
|
||||
tmp_imgs_dir, video_path, fps, **(encoding or {})
|
||||
)
|
||||
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps}
|
||||
for i in range(num_frames)
|
||||
]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
@@ -181,15 +188,18 @@ def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
features[key] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
length=data_dict["observation.state"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
)
|
||||
if "observation.velocity" in data_dict:
|
||||
features["observation.velocity"] = Sequence(
|
||||
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
length=data_dict["observation.velocity"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
)
|
||||
if "observation.effort" in data_dict:
|
||||
features["observation.effort"] = Sequence(
|
||||
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
length=data_dict["observation.effort"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
)
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
|
||||
@@ -26,7 +26,9 @@ import torch
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
@@ -42,11 +44,19 @@ def check_format(raw_dir) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
|
||||
def load_from_raw(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int,
|
||||
video: bool,
|
||||
episodes: list[int] | None = None,
|
||||
):
|
||||
# Load data stream that will be used as reference for the timestamps synchronization
|
||||
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
|
||||
if len(reference_files) == 0:
|
||||
raise ValueError(f"Missing reference files for camera, starting with in '{raw_dir}'")
|
||||
raise ValueError(
|
||||
f"Missing reference files for camera, starting with in '{raw_dir}'"
|
||||
)
|
||||
# select first camera in alphanumeric order
|
||||
reference_key = sorted(reference_files)[0].stem
|
||||
reference_df = pd.read_parquet(raw_dir / f"{reference_key}.parquet")
|
||||
@@ -107,7 +117,9 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
||||
|
||||
df["timestamp"] = df["timestamp_utc"].map(lambda x: x.timestamp())
|
||||
# each episode starts with timestamp 0 to match the ones from the video
|
||||
df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])
|
||||
df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(
|
||||
lambda x: x - x.iloc[0]
|
||||
)
|
||||
|
||||
del df["timestamp_utc"]
|
||||
|
||||
@@ -120,7 +132,9 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
||||
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
|
||||
expected_ep_ids = list(range(df["episode_index"].max() + 1))
|
||||
if ep_ids != expected_ep_ids:
|
||||
raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")
|
||||
raise ValueError(
|
||||
f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}"
|
||||
)
|
||||
|
||||
# Create symlink to raw videos directory (that needs to be absolute not relative)
|
||||
videos_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -152,7 +166,9 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
||||
data_dict[key] = torch.from_numpy(df[key].values)
|
||||
# is vector
|
||||
elif df[key].iloc[0].shape[0] > 1:
|
||||
data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in df[key].values])
|
||||
data_dict[key] = torch.stack(
|
||||
[torch.from_numpy(x.copy()) for x in df[key].values]
|
||||
)
|
||||
else:
|
||||
raise ValueError(key)
|
||||
|
||||
@@ -170,15 +186,18 @@ def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
features[key] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
length=data_dict["observation.state"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
)
|
||||
if "observation.velocity" in data_dict:
|
||||
features["observation.velocity"] = Sequence(
|
||||
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
length=data_dict["observation.velocity"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
)
|
||||
if "observation.effort" in data_dict:
|
||||
features["observation.effort"] = Sequence(
|
||||
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
length=data_dict["observation.effort"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
)
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
|
||||
@@ -143,7 +143,11 @@ def load_from_raw(
|
||||
else:
|
||||
state_keys.append(key)
|
||||
|
||||
lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None
|
||||
lang_key = (
|
||||
"language_instruction"
|
||||
if "language_instruction" in dataset.element_spec
|
||||
else None
|
||||
)
|
||||
|
||||
print(" - image_keys: ", image_keys)
|
||||
print(" - lang_key: ", lang_key)
|
||||
@@ -202,7 +206,9 @@ def load_from_raw(
|
||||
|
||||
# If lang_key is present, convert the entire tensor at once
|
||||
if lang_key is not None:
|
||||
ep_dict["language_instruction"] = [x.numpy().decode("utf-8") for x in episode[lang_key]]
|
||||
ep_dict["language_instruction"] = [
|
||||
x.numpy().decode("utf-8") for x in episode[lang_key]
|
||||
]
|
||||
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
|
||||
@@ -234,7 +240,8 @@ def load_from_raw(
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps}
|
||||
for i in range(num_frames)
|
||||
]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
@@ -259,7 +266,9 @@ def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
for key in data_dict:
|
||||
# check if vector state obs
|
||||
if key.startswith("observation.") and "observation.images." not in key:
|
||||
features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))
|
||||
features[key] = Sequence(
|
||||
length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
# check if image obs
|
||||
elif "observation.images." in key:
|
||||
if video:
|
||||
|
||||
@@ -56,7 +56,9 @@ def check_format(raw_dir):
|
||||
|
||||
required_datasets.remove("meta/episode_ends")
|
||||
|
||||
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
|
||||
assert all(
|
||||
nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets
|
||||
)
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
@@ -76,7 +78,9 @@ def load_from_raw(
|
||||
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
||||
print(
|
||||
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
|
||||
)
|
||||
raise e
|
||||
# as define in gmy-pusht env: https://github.com/huggingface/gym-pusht/blob/e0684ff988d223808c0a9dcfaba9dc4991791370/gym_pusht/envs/pusht.py#L174
|
||||
success_threshold = 0.95 # 95% coverage,
|
||||
@@ -150,7 +154,9 @@ def load_from_raw(
|
||||
]
|
||||
space.add(*walls)
|
||||
|
||||
block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
block_body, block_shapes = PushTEnv.add_tee(
|
||||
space, block_pos[i].tolist(), block_angle[i].item()
|
||||
)
|
||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
@@ -159,7 +165,9 @@ def load_from_raw(
|
||||
reward[i] = np.clip(coverage / success_threshold, 0, 1)
|
||||
success[i] = coverage > success_threshold
|
||||
if keypoints_instead_of_image:
|
||||
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
|
||||
keypoints[i] = torch.from_numpy(
|
||||
PushTEnv.get_keypoints(block_shapes).flatten()
|
||||
)
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
@@ -184,7 +192,8 @@ def load_from_raw(
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps}
|
||||
for i in range(num_frames)
|
||||
]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
@@ -193,7 +202,9 @@ def load_from_raw(
|
||||
if keypoints_instead_of_image:
|
||||
ep_dict["observation.environment_state"] = keypoints
|
||||
ep_dict["action"] = actions[from_idx:to_idx]
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
|
||||
ep_dict["episode_index"] = torch.tensor(
|
||||
[ep_idx] * num_frames, dtype=torch.int64
|
||||
)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
# ep_dict["next.observation.image"] = image[1:],
|
||||
@@ -220,7 +231,8 @@ def to_hf_dataset(data_dict, video, keypoints_instead_of_image: bool = False):
|
||||
features["observation.image"] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
length=data_dict["observation.state"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
)
|
||||
if keypoints_instead_of_image:
|
||||
features["observation.environment_state"] = Sequence(
|
||||
@@ -261,7 +273,9 @@ def from_raw_to_lerobot_format(
|
||||
if fps is None:
|
||||
fps = 10
|
||||
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding)
|
||||
data_dict = load_from_raw(
|
||||
raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding
|
||||
)
|
||||
hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
|
||||
@@ -26,7 +26,9 @@ from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
|
||||
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import (
|
||||
register_codecs,
|
||||
)
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
@@ -61,7 +63,9 @@ def check_format(raw_dir) -> bool:
|
||||
nb_frames = zarr_data["data/camera0_rgb"].shape[0]
|
||||
|
||||
required_datasets.remove("meta/episode_ends")
|
||||
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
|
||||
assert all(
|
||||
nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets
|
||||
)
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
@@ -79,7 +83,9 @@ def load_from_raw(
|
||||
end_pose = torch.from_numpy(zarr_data["data/robot0_demo_end_pose"][:])
|
||||
start_pos = torch.from_numpy(zarr_data["data/robot0_demo_start_pose"][:])
|
||||
eff_pos = torch.from_numpy(zarr_data["data/robot0_eef_pos"][:])
|
||||
eff_rot_axis_angle = torch.from_numpy(zarr_data["data/robot0_eef_rot_axis_angle"][:])
|
||||
eff_rot_axis_angle = torch.from_numpy(
|
||||
zarr_data["data/robot0_eef_rot_axis_angle"][:]
|
||||
)
|
||||
gripper_width = torch.from_numpy(zarr_data["data/robot0_gripper_width"][:])
|
||||
|
||||
states_pos = torch.cat([eff_pos, eff_rot_axis_angle], dim=1)
|
||||
@@ -129,24 +135,31 @@ def load_from_raw(
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
|
||||
# encode images to a mp4 video
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
|
||||
encode_video_frames(
|
||||
tmp_imgs_dir, video_path, fps, **(encoding or {})
|
||||
)
|
||||
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps}
|
||||
for i in range(num_frames)
|
||||
]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
|
||||
ep_dict["observation.state"] = state
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
|
||||
ep_dict["episode_index"] = torch.tensor(
|
||||
[ep_idx] * num_frames, dtype=torch.int64
|
||||
)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames)
|
||||
ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames)
|
||||
ep_dict["episode_data_index_to"] = torch.tensor(
|
||||
[from_idx + num_frames] * num_frames
|
||||
)
|
||||
ep_dict["end_pose"] = end_pose[from_idx:to_idx]
|
||||
ep_dict["start_pos"] = start_pos[from_idx:to_idx]
|
||||
ep_dict["gripper_width"] = gripper_width[from_idx:to_idx]
|
||||
@@ -172,7 +185,8 @@ def to_hf_dataset(data_dict, video):
|
||||
features["observation.image"] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
length=data_dict["observation.state"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
)
|
||||
features["episode_index"] = Value(dtype="int64", id=None)
|
||||
features["frame_index"] = Value(dtype="int64", id=None)
|
||||
@@ -192,7 +206,8 @@ def to_hf_dataset(data_dict, video):
|
||||
length=data_dict["start_pos"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["gripper_width"] = Sequence(
|
||||
length=data_dict["gripper_width"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
length=data_dict["gripper_width"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
)
|
||||
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
|
||||
|
||||
@@ -45,7 +45,9 @@ def concatenate_episodes(ep_dicts):
|
||||
return data_dict
|
||||
|
||||
|
||||
def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
|
||||
def save_images_concurrently(
|
||||
imgs_array: numpy.array, out_dir: Path, max_workers: int = 4
|
||||
):
|
||||
out_dir = Path(out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -55,7 +57,10 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers
|
||||
|
||||
num_images = len(imgs_array)
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
|
||||
[
|
||||
executor.submit(save_image, imgs_array[i], i, out_dir)
|
||||
for i in range(num_images)
|
||||
]
|
||||
|
||||
|
||||
def get_default_encoding() -> dict:
|
||||
@@ -64,7 +69,8 @@ def get_default_encoding() -> dict:
|
||||
return {
|
||||
k: v.default
|
||||
for k, v in signature.parameters.items()
|
||||
if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
|
||||
if v.default is not inspect.Parameter.empty
|
||||
and k in ["vcodec", "pix_fmt", "g", "crf"]
|
||||
}
|
||||
|
||||
|
||||
@@ -77,7 +83,9 @@ def check_repo_id(repo_id: str) -> None:
|
||||
|
||||
|
||||
# TODO(aliberts): remove
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
|
||||
def calculate_episode_data_index(
|
||||
hf_dataset: datasets.Dataset,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
|
||||
|
||||
|
||||
@@ -40,7 +40,10 @@ from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
||||
def check_format(raw_dir):
|
||||
keys = {"actions", "rewards", "dones"}
|
||||
nested_keys = {"observations": {"rgb", "state"}, "next_observations": {"rgb", "state"}}
|
||||
nested_keys = {
|
||||
"observations": {"rgb", "state"},
|
||||
"next_observations": {"rgb", "state"},
|
||||
}
|
||||
|
||||
xarm_files = list(raw_dir.glob("*.pkl"))
|
||||
assert len(xarm_files) > 0
|
||||
@@ -53,11 +56,17 @@ def check_format(raw_dir):
|
||||
|
||||
# Check for consistent lengths in nested keys
|
||||
expected_len = len(dataset_dict["actions"])
|
||||
assert all(len(dataset_dict[key]) == expected_len for key in keys if key in dataset_dict)
|
||||
assert all(
|
||||
len(dataset_dict[key]) == expected_len for key in keys if key in dataset_dict
|
||||
)
|
||||
|
||||
for key, subkeys in nested_keys.items():
|
||||
nested_dict = dataset_dict.get(key, {})
|
||||
assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict)
|
||||
assert all(
|
||||
len(nested_dict[subkey]) == expected_len
|
||||
for subkey in subkeys
|
||||
if subkey in nested_dict
|
||||
)
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
@@ -122,13 +131,18 @@ def load_from_raw(
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps}
|
||||
for i in range(num_frames)
|
||||
]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
|
||||
ep_dict["observation.state"] = state
|
||||
ep_dict["action"] = action
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
|
||||
ep_dict["episode_index"] = torch.tensor(
|
||||
[ep_idx] * num_frames, dtype=torch.int64
|
||||
)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
# ep_dict["next.observation.image"] = next_image
|
||||
@@ -153,7 +167,8 @@ def to_hf_dataset(data_dict, video):
|
||||
features["observation.image"] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
length=data_dict["observation.state"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
)
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
|
||||
@@ -43,7 +43,10 @@ class EpisodeAwareSampler:
|
||||
):
|
||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||
indices.extend(
|
||||
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
|
||||
range(
|
||||
start_index.item() + drop_n_first_frames,
|
||||
end_index.item() - drop_n_last_frames,
|
||||
)
|
||||
)
|
||||
|
||||
self.indices = indices
|
||||
|
||||
@@ -57,7 +57,9 @@ class RandomSubsetApply(Transform):
|
||||
elif not isinstance(n_subset, int):
|
||||
raise TypeError("n_subset should be an int or None")
|
||||
elif not (1 <= n_subset <= len(transforms)):
|
||||
raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
|
||||
raise ValueError(
|
||||
f"n_subset should be in the interval [1, {len(transforms)}]"
|
||||
)
|
||||
|
||||
self.transforms = transforms
|
||||
total = sum(p)
|
||||
@@ -116,16 +118,22 @@ class SharpnessJitter(Transform):
|
||||
def _check_input(self, sharpness):
|
||||
if isinstance(sharpness, (int, float)):
|
||||
if sharpness < 0:
|
||||
raise ValueError("If sharpness is a single number, it must be non negative.")
|
||||
raise ValueError(
|
||||
"If sharpness is a single number, it must be non negative."
|
||||
)
|
||||
sharpness = [1.0 - sharpness, 1.0 + sharpness]
|
||||
sharpness[0] = max(sharpness[0], 0.0)
|
||||
elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
|
||||
sharpness = [float(v) for v in sharpness]
|
||||
else:
|
||||
raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
|
||||
raise TypeError(
|
||||
f"{sharpness=} should be a single number or a sequence with length 2."
|
||||
)
|
||||
|
||||
if not 0.0 <= sharpness[0] <= sharpness[1]:
|
||||
raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")
|
||||
raise ValueError(
|
||||
f"sharpnesss values should be between (0., inf), but got {sharpness}."
|
||||
)
|
||||
|
||||
return float(sharpness[0]), float(sharpness[1])
|
||||
|
||||
@@ -134,7 +142,9 @@ class SharpnessJitter(Transform):
|
||||
|
||||
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
|
||||
sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
|
||||
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||
return self._call_kernel(
|
||||
F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor
|
||||
)
|
||||
|
||||
|
||||
def get_image_transforms(
|
||||
@@ -185,7 +195,11 @@ def get_image_transforms(
|
||||
raise ValueError("The interpolation passed is not supported")
|
||||
# Weight for resizing is always 1
|
||||
weights.append(1.0)
|
||||
transforms.append(v2.Resize(size=(image_size[0], image_size[1]), interpolation=interpolation_mode))
|
||||
transforms.append(
|
||||
v2.Resize(
|
||||
size=(image_size[0], image_size[1]), interpolation=interpolation_mode
|
||||
)
|
||||
)
|
||||
if brightness_min_max is not None and brightness_weight > 0.0:
|
||||
weights.append(brightness_weight)
|
||||
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
|
||||
@@ -219,4 +233,6 @@ def get_image_transforms(
|
||||
return v2.Identity()
|
||||
else:
|
||||
# TODO(rcadene, aliberts): add v2.ToDtype float16?
|
||||
return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
|
||||
return RandomSubsetApply(
|
||||
transforms, p=weights, n_subset=n_subset, random_order=random_order
|
||||
)
|
||||
|
||||
@@ -43,9 +43,15 @@ EPISODES_PATH = "meta/episodes.jsonl"
|
||||
STATS_PATH = "meta/stats.json"
|
||||
TASKS_PATH = "meta/tasks.jsonl"
|
||||
|
||||
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
DEFAULT_VIDEO_PATH = (
|
||||
"videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
)
|
||||
DEFAULT_PARQUET_PATH = (
|
||||
"data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
)
|
||||
DEFAULT_IMAGE_PATH = (
|
||||
"images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
)
|
||||
|
||||
DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
@@ -99,7 +105,9 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||
|
||||
|
||||
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()}
|
||||
serialized_dict = {
|
||||
key: value.tolist() for key, value in flatten_dict(stats).items()
|
||||
}
|
||||
return unflatten_dict(serialized_dict)
|
||||
|
||||
|
||||
@@ -157,14 +165,19 @@ def load_stats(local_dir: Path) -> dict:
|
||||
|
||||
def load_tasks(local_dir: Path) -> dict:
|
||||
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
||||
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
return {
|
||||
item["task_index"]: item["task"]
|
||||
for item in sorted(tasks, key=lambda x: x["task_index"])
|
||||
}
|
||||
|
||||
|
||||
def load_episodes(local_dir: Path) -> dict:
|
||||
return load_jsonlines(local_dir / EPISODES_PATH)
|
||||
|
||||
|
||||
def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray:
|
||||
def load_image_as_numpy(
|
||||
fpath: str | Path, dtype="float32", channel_first: bool = True
|
||||
) -> np.ndarray:
|
||||
img = PILImage.open(fpath).convert("RGB")
|
||||
img_array = np.array(img, dtype=dtype)
|
||||
if channel_first: # (H, W, C) -> (C, H, W)
|
||||
@@ -222,7 +235,10 @@ class BackwardCompatibilityError(Exception):
|
||||
|
||||
|
||||
def check_version_compatibility(
|
||||
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
|
||||
repo_id: str,
|
||||
version_to_check: str,
|
||||
current_version: str,
|
||||
enforce_breaking_major: bool = True,
|
||||
) -> None:
|
||||
current_major, _ = _get_major_minor(current_version)
|
||||
major_to_check, _ = _get_major_minor(version_to_check)
|
||||
@@ -317,7 +333,9 @@ def create_empty_dataset_info(
|
||||
def get_episode_data_index(
|
||||
episode_dicts: list[dict], episodes: list[int] | None = None
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
|
||||
episode_lengths = {
|
||||
ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)
|
||||
}
|
||||
if episodes is not None:
|
||||
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
|
||||
|
||||
@@ -338,7 +356,9 @@ def calculate_total_episode(
|
||||
return total_episodes
|
||||
|
||||
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
|
||||
def calculate_episode_data_index(
|
||||
hf_dataset: datasets.Dataset,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = []
|
||||
table = hf_dataset.data.table
|
||||
total_episodes = calculate_total_episode(hf_dataset)
|
||||
@@ -380,7 +400,9 @@ def check_timestamps_sync(
|
||||
# Track original indices before masking
|
||||
original_indices = torch.arange(len(diffs))
|
||||
filtered_indices = original_indices[mask]
|
||||
outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance) # .squeeze()
|
||||
outside_tolerance_filtered_indices = torch.nonzero(
|
||||
~filtered_within_tolerance
|
||||
) # .squeeze()
|
||||
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
|
||||
episode_indices = torch.stack(hf_dataset["episode_index"])
|
||||
|
||||
@@ -405,7 +427,10 @@ def check_timestamps_sync(
|
||||
|
||||
|
||||
def check_delta_timestamps(
|
||||
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
||||
delta_timestamps: dict[str, list[float]],
|
||||
fps: int,
|
||||
tolerance_s: float,
|
||||
raise_value_error: bool = True,
|
||||
) -> bool:
|
||||
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
|
||||
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
|
||||
@@ -413,10 +438,14 @@ def check_delta_timestamps(
|
||||
"""
|
||||
outside_tolerance = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
|
||||
within_tolerance = [
|
||||
abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts
|
||||
]
|
||||
if not all(within_tolerance):
|
||||
outside_tolerance[key] = [
|
||||
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
|
||||
ts
|
||||
for ts, is_within in zip(delta_ts, within_tolerance, strict=True)
|
||||
if not is_within
|
||||
]
|
||||
|
||||
if len(outside_tolerance) > 0:
|
||||
@@ -434,7 +463,9 @@ def check_delta_timestamps(
|
||||
return True
|
||||
|
||||
|
||||
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
|
||||
def get_delta_indices(
|
||||
delta_timestamps: dict[str, list[float]], fps: int
|
||||
) -> dict[str, list[int]]:
|
||||
delta_indices = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
delta_indices[key] = (torch.tensor(delta_ts) * fps).long().tolist()
|
||||
@@ -499,7 +530,9 @@ def create_lerobot_dataset_card(
|
||||
],
|
||||
)
|
||||
|
||||
card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text()
|
||||
card_template = (
|
||||
importlib.resources.files("lerobot.common.datasets") / "card_template.md"
|
||||
).read_text()
|
||||
|
||||
return DatasetCard.from_template(
|
||||
card_data=card_data,
|
||||
|
||||
@@ -26,7 +26,10 @@ from pathlib import Path
|
||||
from textwrap import dedent
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset, parse_robot_config
|
||||
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import (
|
||||
convert_dataset,
|
||||
parse_robot_config,
|
||||
)
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
|
||||
@@ -117,7 +120,10 @@ DATASETS = {
|
||||
"single_task": "Place the battery into the slot of the remote controller.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
|
||||
"aloha_static_candy": {
|
||||
"single_task": "Pick up the candy and unwrap it.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_coffee": {
|
||||
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
@@ -166,13 +172,22 @@ DATASETS = {
|
||||
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
|
||||
"aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
|
||||
"aloha_static_ziploc_slide": {
|
||||
"single_task": "Slide open the ziploc bag.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_scripted": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_scripted_image": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
|
||||
"aloha_sim_insertion_human": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_human_image": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
@@ -193,10 +208,19 @@ DATASETS = {
|
||||
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
|
||||
"pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
|
||||
"pusht": {
|
||||
"single_task": "Push the T-shaped block onto the T-shaped target.",
|
||||
**PUSHT_INFO,
|
||||
},
|
||||
"pusht_image": {
|
||||
"single_task": "Push the T-shaped block onto the T-shaped target.",
|
||||
**PUSHT_INFO,
|
||||
},
|
||||
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
|
||||
"unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
|
||||
"unitreeh1_rearrange_objects": {
|
||||
"single_task": "Put the object into the bin.",
|
||||
**UNITREEH_INFO,
|
||||
},
|
||||
"unitreeh1_two_robot_greeting": {
|
||||
"single_task": "Greet the other robot with a high five.",
|
||||
**UNITREEH_INFO,
|
||||
@@ -206,13 +230,31 @@ DATASETS = {
|
||||
**UNITREEH_INFO,
|
||||
},
|
||||
"xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_image": {
|
||||
"single_task": "Pick up the cube and lift it.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_lift_medium_replay": {
|
||||
"single_task": "Pick up the cube and lift it.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_lift_medium_replay_image": {
|
||||
"single_task": "Pick up the cube and lift it.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_image": {
|
||||
"single_task": "Push the cube onto the target.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_push_medium_replay": {
|
||||
"single_task": "Push the cube onto the target.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_push_medium_replay_image": {
|
||||
"single_task": "Push the cube onto the target.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"umi_cup_in_the_wild": {
|
||||
"single_task": "Put the cup on the plate.",
|
||||
"license": "apache-2.0",
|
||||
|
||||
@@ -152,7 +152,9 @@ V1_INFO_PATH = "meta_data/info.json"
|
||||
V1_STATS_PATH = "meta_data/stats.safetensors"
|
||||
|
||||
|
||||
def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
|
||||
def parse_robot_config(
|
||||
config_path: Path, config_overrides: list[str] | None = None
|
||||
) -> tuple[str, dict]:
|
||||
robot_cfg = init_hydra_config(config_path, config_overrides)
|
||||
if robot_cfg["robot_type"] in ["aloha", "koch"]:
|
||||
state_names = [
|
||||
@@ -203,7 +205,9 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
|
||||
torch.testing.assert_close(stats_json[key], stats[key])
|
||||
|
||||
|
||||
def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = None) -> dict[str, list]:
|
||||
def get_features_from_hf_dataset(
|
||||
dataset: Dataset, robot_config: dict | None = None
|
||||
) -> dict[str, list]:
|
||||
features = {}
|
||||
for key, ft in dataset.features.items():
|
||||
if isinstance(ft, datasets.Value):
|
||||
@@ -215,7 +219,9 @@ def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = N
|
||||
dtype = ft.feature.dtype
|
||||
shape = (ft.length,)
|
||||
motor_names = (
|
||||
robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
|
||||
robot_config["names"][key]
|
||||
if robot_config
|
||||
else [f"motor_{i}" for i in range(ft.length)]
|
||||
)
|
||||
assert len(motor_names) == shape[0]
|
||||
names = {"motors": motor_names}
|
||||
@@ -239,11 +245,15 @@ def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = N
|
||||
return features
|
||||
|
||||
|
||||
def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
|
||||
def add_task_index_by_episodes(
|
||||
dataset: Dataset, tasks_by_episodes: dict
|
||||
) -> tuple[Dataset, list[str]]:
|
||||
df = dataset.to_pandas()
|
||||
tasks = list(set(tasks_by_episodes.values()))
|
||||
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
|
||||
episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
episodes_to_task_index = {
|
||||
ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
|
||||
|
||||
features = dataset.features
|
||||
@@ -260,10 +270,19 @@ def add_task_index_from_tasks_col(
|
||||
# HACK: This is to clean some of the instructions in our version of Open X datasets
|
||||
prefix_to_clean = "tf.Tensor(b'"
|
||||
suffix_to_clean = "', shape=(), dtype=string)"
|
||||
df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
|
||||
df[tasks_col] = (
|
||||
df[tasks_col]
|
||||
.str.removeprefix(prefix_to_clean)
|
||||
.str.removesuffix(suffix_to_clean)
|
||||
)
|
||||
|
||||
# Create task_index col
|
||||
tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
|
||||
tasks_by_episode = (
|
||||
df.groupby("episode_index")[tasks_col]
|
||||
.unique()
|
||||
.apply(lambda x: x.tolist())
|
||||
.to_dict()
|
||||
)
|
||||
tasks = df[tasks_col].unique().tolist()
|
||||
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
|
||||
df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
|
||||
@@ -288,7 +307,9 @@ def split_parquet_by_episodes(
|
||||
for ep_chunk in range(total_chunks):
|
||||
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
|
||||
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
|
||||
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
|
||||
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(
|
||||
episode_chunk=ep_chunk
|
||||
)
|
||||
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
|
||||
for ep_idx in range(ep_chunk_start, ep_chunk_end):
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
@@ -320,7 +341,9 @@ def move_videos(
|
||||
videos_moved = False
|
||||
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
|
||||
if len(video_files) == 0:
|
||||
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
|
||||
video_files = [
|
||||
str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")
|
||||
]
|
||||
videos_moved = True # Videos have already been moved
|
||||
|
||||
assert len(video_files) == total_episodes * len(video_keys)
|
||||
@@ -351,7 +374,9 @@ def move_videos(
|
||||
target_path = DEFAULT_VIDEO_PATH.format(
|
||||
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
|
||||
)
|
||||
video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
|
||||
video_file = V1_VIDEO_FILE.format(
|
||||
video_key=vid_key, episode_index=ep_idx
|
||||
)
|
||||
if len(video_dirs) == 1:
|
||||
video_path = video_dirs[0] / video_file
|
||||
else:
|
||||
@@ -368,7 +393,9 @@ def move_videos(
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
|
||||
def fix_lfs_video_files_tracking(
|
||||
work_dir: Path, lfs_untracked_videos: list[str]
|
||||
) -> None:
|
||||
"""
|
||||
HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
|
||||
there's no other option than to download the actual files and reupload them with lfs tracking.
|
||||
@@ -376,7 +403,12 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]
|
||||
for i in range(0, len(lfs_untracked_videos), 100):
|
||||
files = lfs_untracked_videos[i : i + 100]
|
||||
try:
|
||||
subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
|
||||
subprocess.run(
|
||||
["git", "rm", "--cached", *files],
|
||||
cwd=work_dir,
|
||||
capture_output=True,
|
||||
check=True,
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print("git rm --cached ERROR:")
|
||||
print(e.stderr)
|
||||
@@ -387,10 +419,14 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
|
||||
def fix_gitattributes(
|
||||
work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path
|
||||
) -> None:
|
||||
shutil.copyfile(clean_gittatributes, current_gittatributes)
|
||||
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
|
||||
subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
|
||||
subprocess.run(
|
||||
["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True
|
||||
)
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
@@ -399,7 +435,17 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
|
||||
repo_url = f"https://huggingface.co/datasets/{repo_id}"
|
||||
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
|
||||
subprocess.run(
|
||||
["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
|
||||
[
|
||||
"git",
|
||||
"clone",
|
||||
"--branch",
|
||||
branch,
|
||||
"--single-branch",
|
||||
"--depth",
|
||||
"1",
|
||||
repo_url,
|
||||
str(work_dir),
|
||||
],
|
||||
check=True,
|
||||
env=env,
|
||||
)
|
||||
@@ -407,13 +453,19 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
|
||||
|
||||
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
|
||||
lfs_tracked_files = subprocess.run(
|
||||
["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
|
||||
["git", "lfs", "ls-files", "-n"],
|
||||
cwd=work_dir,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
|
||||
return [f for f in video_files if f not in lfs_tracked_files]
|
||||
|
||||
|
||||
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
|
||||
def get_videos_info(
|
||||
repo_id: str, local_dir: Path, video_keys: list[str], branch: str
|
||||
) -> dict:
|
||||
# Assumes first episode
|
||||
video_files = [
|
||||
DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
|
||||
@@ -421,7 +473,11 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch
|
||||
]
|
||||
hub_api = HfApi()
|
||||
hub_api.snapshot_download(
|
||||
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
local_dir=local_dir,
|
||||
revision=branch,
|
||||
allow_patterns=video_files,
|
||||
)
|
||||
videos_info_dict = {}
|
||||
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
|
||||
@@ -448,7 +504,11 @@ def convert_dataset(
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.snapshot_download(
|
||||
repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
revision=v1,
|
||||
local_dir=v1x_dir,
|
||||
ignore_patterns="videos*/",
|
||||
)
|
||||
branch = "main"
|
||||
if test_branch:
|
||||
@@ -480,19 +540,31 @@ def convert_dataset(
|
||||
if single_task:
|
||||
tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
|
||||
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
||||
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
tasks_by_episodes = {
|
||||
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
elif tasks_path:
|
||||
tasks_by_episodes = load_json(tasks_path)
|
||||
tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
|
||||
tasks_by_episodes = {
|
||||
int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
||||
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
tasks_by_episodes = {
|
||||
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
elif tasks_col:
|
||||
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
|
||||
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(
|
||||
dataset, tasks_col
|
||||
)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
|
||||
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
|
||||
assert set(tasks) == {
|
||||
task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks
|
||||
}
|
||||
tasks = [
|
||||
{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)
|
||||
]
|
||||
write_jsonlines(tasks, v20_dir / TASKS_PATH)
|
||||
features["task_index"] = {
|
||||
"dtype": "int64",
|
||||
@@ -506,14 +578,25 @@ def convert_dataset(
|
||||
dataset = dataset.remove_columns(video_keys)
|
||||
clean_gitattr = Path(
|
||||
hub_api.hf_hub_download(
|
||||
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
|
||||
repo_id=GITATTRIBUTES_REF,
|
||||
repo_type="dataset",
|
||||
local_dir=local_dir,
|
||||
filename=".gitattributes",
|
||||
)
|
||||
).absolute()
|
||||
with tempfile.TemporaryDirectory() as tmp_video_dir:
|
||||
move_videos(
|
||||
repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
|
||||
repo_id,
|
||||
video_keys,
|
||||
total_episodes,
|
||||
total_chunks,
|
||||
Path(tmp_video_dir),
|
||||
clean_gitattr,
|
||||
branch,
|
||||
)
|
||||
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
|
||||
videos_info = get_videos_info(
|
||||
repo_id, v1x_dir, video_keys=video_keys, branch=branch
|
||||
)
|
||||
for key in video_keys:
|
||||
features[key]["shape"] = (
|
||||
videos_info[key].pop("video.height"),
|
||||
@@ -521,15 +604,22 @@ def convert_dataset(
|
||||
videos_info[key].pop("video.channels"),
|
||||
)
|
||||
features[key]["video_info"] = videos_info[key]
|
||||
assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
|
||||
assert math.isclose(
|
||||
videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3
|
||||
)
|
||||
if "encoding" in metadata_v1:
|
||||
assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
|
||||
assert (
|
||||
videos_info[key]["video.pix_fmt"]
|
||||
== metadata_v1["encoding"]["pix_fmt"]
|
||||
)
|
||||
else:
|
||||
assert metadata_v1.get("video", 0) == 0
|
||||
videos_info = None
|
||||
|
||||
# Split data into 1 parquet file by episode
|
||||
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
|
||||
episode_lengths = split_parquet_by_episodes(
|
||||
dataset, total_episodes, total_chunks, v20_dir
|
||||
)
|
||||
|
||||
if robot_config is not None:
|
||||
robot_type = robot_config["robot_type"]
|
||||
@@ -540,7 +630,11 @@ def convert_dataset(
|
||||
|
||||
# Episodes
|
||||
episodes = [
|
||||
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
|
||||
{
|
||||
"episode_index": ep_idx,
|
||||
"tasks": tasks_by_episodes[ep_idx],
|
||||
"length": episode_lengths[ep_idx],
|
||||
}
|
||||
for ep_idx in episode_indices
|
||||
]
|
||||
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
|
||||
@@ -563,16 +657,27 @@ def convert_dataset(
|
||||
}
|
||||
write_json(metadata_v2_0, v20_dir / INFO_PATH)
|
||||
convert_stats_to_json(v1x_dir, v20_dir)
|
||||
card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs
|
||||
)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
|
||||
hub_api.delete_folder(
|
||||
repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch
|
||||
)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
|
||||
hub_api.delete_folder(
|
||||
repo_id=repo_id,
|
||||
path_in_repo="meta_data",
|
||||
repo_type="dataset",
|
||||
revision=branch,
|
||||
)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
|
||||
hub_api.delete_folder(
|
||||
repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch
|
||||
)
|
||||
|
||||
hub_api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
@@ -655,7 +760,11 @@ def main():
|
||||
if not args.local_dir:
|
||||
args.local_dir = Path("/tmp/lerobot_dataset_v2")
|
||||
|
||||
robot_config = parse_robot_config(args.robot_config, args.robot_overrides) if args.robot_config else None
|
||||
robot_config = (
|
||||
parse_robot_config(args.robot_config, args.robot_overrides)
|
||||
if args.robot_config
|
||||
else None
|
||||
)
|
||||
del args.robot_config, args.robot_overrides
|
||||
|
||||
convert_dataset(**vars(args), robot_config=robot_config)
|
||||
|
||||
@@ -227,7 +227,9 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
result = subprocess.run(
|
||||
ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
@@ -241,7 +243,9 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
"has_audio": True,
|
||||
"audio.channels": audio_stream_info.get("channels", None),
|
||||
"audio.codec": audio_stream_info.get("codec_name", None),
|
||||
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
|
||||
"audio.bit_rate": int(audio_stream_info["bit_rate"])
|
||||
if audio_stream_info.get("bit_rate")
|
||||
else None,
|
||||
"audio.sample_rate": int(audio_stream_info["sample_rate"])
|
||||
if audio_stream_info.get("sample_rate")
|
||||
else None,
|
||||
@@ -263,7 +267,9 @@ def get_video_info(video_path: Path | str) -> dict:
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
result = subprocess.run(
|
||||
ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
|
||||
@@ -35,7 +35,9 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
||||
return
|
||||
|
||||
if "maniskill" in cfg.env.name:
|
||||
env = make_maniskill_env(cfg, n_envs if n_envs is not None else cfg.eval.batch_size)
|
||||
env = make_maniskill_env(
|
||||
cfg, n_envs if n_envs is not None else cfg.eval.batch_size
|
||||
)
|
||||
return env
|
||||
|
||||
package_name = f"gym_{cfg.env.name}"
|
||||
@@ -55,7 +57,11 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
||||
gym_kwgs["max_episode_steps"] = cfg.env.episode_length
|
||||
|
||||
# batched version of the env that returns an observation of shape (b, c)
|
||||
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
|
||||
env_cls = (
|
||||
gym.vector.AsyncVectorEnv
|
||||
if cfg.eval.use_async_envs
|
||||
else gym.vector.SyncVectorEnv
|
||||
)
|
||||
env = env_cls(
|
||||
[
|
||||
lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
|
||||
@@ -66,7 +72,9 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
||||
return env
|
||||
|
||||
|
||||
def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
|
||||
def make_maniskill_env(
|
||||
cfg: DictConfig, n_envs: int | None = None
|
||||
) -> gym.vector.VectorEnv | None:
|
||||
"""Make ManiSkill3 gym environment"""
|
||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||
|
||||
@@ -83,7 +91,9 @@ def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector
|
||||
# state should have the size of 25
|
||||
# env = ConvertToLeRobotEnv(env, n_envs)
|
||||
# env = PixelWrapper(cfg, env, n_envs)
|
||||
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
|
||||
env._max_episode_steps = env.max_episode_steps = (
|
||||
50 # gym_utils.find_max_episode_steps_value(env)
|
||||
)
|
||||
env.unwrapped.metadata["render_fps"] = 20
|
||||
|
||||
return env
|
||||
@@ -110,7 +120,11 @@ class PixelWrapper(gym.Wrapper):
|
||||
def _get_obs(self, obs):
|
||||
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
|
||||
self._frames.append(frame)
|
||||
return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}
|
||||
return {
|
||||
"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(
|
||||
self.env.device
|
||||
)
|
||||
}
|
||||
|
||||
def reset(self, seed):
|
||||
obs, info = self.env.reset() # (seed=seed)
|
||||
@@ -123,6 +137,7 @@ class PixelWrapper(gym.Wrapper):
|
||||
return self._get_obs(obs), reward, terminated, truncated, info
|
||||
|
||||
|
||||
# TODO: Remove this
|
||||
class ConvertToLeRobotEnv(gym.Wrapper):
|
||||
def __init__(self, env, num_envs):
|
||||
super().__init__(env)
|
||||
@@ -144,7 +159,9 @@ class ConvertToLeRobotEnv(gym.Wrapper):
|
||||
|
||||
images = torch.concat(images, axis=-1)
|
||||
# flatten the rest of the data which should just be state data
|
||||
observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device)
|
||||
observation = common.flatten_state_dict(
|
||||
observation, use_torch=True, device=self.base_env.device
|
||||
)
|
||||
ret = dict()
|
||||
ret["state"] = observation
|
||||
ret["pixels"] = images
|
||||
|
||||
@@ -39,7 +39,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
img = img.unsqueeze(0)
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
assert (
|
||||
c < h and c < w
|
||||
), f"expect channel last images, but instead got {img.shape=}"
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
@@ -65,7 +67,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
return return_observations
|
||||
|
||||
|
||||
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||
def preprocess_maniskill_observation(
|
||||
observations: dict[str, np.ndarray],
|
||||
) -> dict[str, Tensor]:
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
|
||||
@@ -84,7 +84,9 @@ class Logger:
|
||||
pretrained_model_dir_name = "pretrained_model"
|
||||
training_state_file_name = "training_state.pth"
|
||||
|
||||
def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
|
||||
def __init__(
|
||||
self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
log_dir: The directory to save all logs and training outputs to.
|
||||
@@ -104,7 +106,9 @@ class Logger:
|
||||
enable_wandb = cfg.get("wandb", {}).get("enable", False)
|
||||
run_offline = not enable_wandb or not project
|
||||
if run_offline:
|
||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
logging.info(
|
||||
colored("Logs will be saved locally.", "yellow", attrs=["bold"])
|
||||
)
|
||||
self._wandb = None
|
||||
else:
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
@@ -130,7 +134,9 @@ class Logger:
|
||||
# Handle custom step key for rl asynchronous training.
|
||||
self._wandb_custom_step_key: set[str] | None = None
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
logging.info(
|
||||
f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}"
|
||||
)
|
||||
self._wandb = wandb
|
||||
|
||||
@classmethod
|
||||
@@ -151,7 +157,9 @@ class Logger:
|
||||
"""
|
||||
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
|
||||
|
||||
def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
|
||||
def save_model(
|
||||
self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None
|
||||
):
|
||||
"""Save the weights of the Policy model using PyTorchModelHubMixin.
|
||||
|
||||
The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
|
||||
@@ -221,22 +229,30 @@ class Logger:
|
||||
else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
|
||||
)
|
||||
self.save_model(
|
||||
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
|
||||
checkpoint_dir / self.pretrained_model_dir_name,
|
||||
policy,
|
||||
wandb_artifact_name=wandb_artifact_name,
|
||||
)
|
||||
self.save_training_state(
|
||||
checkpoint_dir, train_step, optimizer, scheduler, interaction_step
|
||||
)
|
||||
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler, interaction_step)
|
||||
os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
|
||||
|
||||
def load_last_training_state(self, optimizer: Optimizer | dict, scheduler: LRScheduler | None) -> int:
|
||||
def load_last_training_state(
|
||||
self, optimizer: Optimizer | dict, scheduler: LRScheduler | None
|
||||
) -> int:
|
||||
"""
|
||||
Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
|
||||
random state, and return the global training step.
|
||||
"""
|
||||
training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
|
||||
training_state = torch.load(
|
||||
self.last_checkpoint_dir / self.training_state_file_name
|
||||
)
|
||||
# For the case where the optimizer is a dictionary of optimizers (e.g., sac)
|
||||
if type(training_state["optimizer"]) is dict:
|
||||
assert set(training_state["optimizer"].keys()) == set(optimizer.keys()), (
|
||||
"Optimizer dictionaries do not have the same keys during resume!"
|
||||
)
|
||||
assert set(training_state["optimizer"].keys()) == set(
|
||||
optimizer.keys()
|
||||
), "Optimizer dictionaries do not have the same keys during resume!"
|
||||
for k, v in training_state["optimizer"].items():
|
||||
optimizer[k].load_state_dict(v)
|
||||
else:
|
||||
@@ -248,10 +264,18 @@ class Logger:
|
||||
"The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
|
||||
)
|
||||
# Small hack to get the expected keys: use `get_global_random_state`.
|
||||
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
|
||||
set_global_random_state(
|
||||
{k: training_state[k] for k in get_global_random_state()}
|
||||
)
|
||||
return training_state["step"]
|
||||
|
||||
def log_dict(self, d, step: int | None = None, mode="train", custom_step_key: str | None = None):
|
||||
def log_dict(
|
||||
self,
|
||||
d,
|
||||
step: int | None = None,
|
||||
mode="train",
|
||||
custom_step_key: str | None = None,
|
||||
):
|
||||
"""Log a dictionary of metrics to WandB."""
|
||||
assert mode in {"train", "eval"}
|
||||
# TODO(alexander-soare): Add local text log.
|
||||
@@ -280,12 +304,20 @@ class Logger:
|
||||
continue
|
||||
|
||||
# Do not log the custom step key itself.
|
||||
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
|
||||
if (
|
||||
self._wandb_custom_step_key is not None
|
||||
and k in self._wandb_custom_step_key
|
||||
):
|
||||
continue
|
||||
|
||||
if custom_step_key is not None:
|
||||
value_custom_step = d[custom_step_key]
|
||||
self._wandb.log({f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step})
|
||||
self._wandb.log(
|
||||
{
|
||||
f"{mode}/{k}": v,
|
||||
f"{mode}/{custom_step_key}": value_custom_step,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
||||
|
||||
@@ -168,4 +168,6 @@ class ACTConfig:
|
||||
not any(k.startswith("observation.image") for k in self.input_shapes)
|
||||
and "observation.environment_state" not in self.input_shapes
|
||||
):
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
raise ValueError(
|
||||
"You must provide at least one image or the environment state among the inputs."
|
||||
)
|
||||
|
||||
@@ -81,10 +81,14 @@ class ACTPolicy(
|
||||
|
||||
self.model = ACT(config)
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
self.expected_image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
|
||||
if config.temporal_ensemble_coeff is not None:
|
||||
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
|
||||
self.temporal_ensembler = ACTTemporalEnsembler(
|
||||
config.temporal_ensemble_coeff, config.chunk_size
|
||||
)
|
||||
|
||||
self.reset()
|
||||
|
||||
@@ -107,8 +111,12 @@ class ACTPolicy(
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.expected_image_keys], dim=-4
|
||||
)
|
||||
|
||||
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
||||
# we are ensembling over.
|
||||
@@ -135,13 +143,18 @@ class ACTPolicy(
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.expected_image_keys], dim=-4
|
||||
)
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
l1_loss = (
|
||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
F.l1_loss(batch["action"], actions_hat, reduction="none")
|
||||
* ~batch["action_is_pad"].unsqueeze(-1)
|
||||
).mean()
|
||||
|
||||
loss_dict = {"l1_loss": l1_loss.item()}
|
||||
@@ -151,7 +164,12 @@ class ACTPolicy(
|
||||
# KL-divergence per batch element, then take the mean over the batch.
|
||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||
mean_kld = (
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
(
|
||||
-0.5
|
||||
* (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())
|
||||
)
|
||||
.sum(-1)
|
||||
.mean()
|
||||
)
|
||||
loss_dict["kld_loss"] = mean_kld.item()
|
||||
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
||||
@@ -205,7 +223,9 @@ class ACTTemporalEnsembler:
|
||||
```
|
||||
"""
|
||||
self.chunk_size = chunk_size
|
||||
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
|
||||
self.ensemble_weights = torch.exp(
|
||||
-temporal_ensemble_coeff * torch.arange(chunk_size)
|
||||
)
|
||||
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
|
||||
self.reset()
|
||||
|
||||
@@ -221,7 +241,9 @@ class ACTTemporalEnsembler:
|
||||
time steps, and pop/return the next batch of actions in the sequence.
|
||||
"""
|
||||
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
|
||||
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
|
||||
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(
|
||||
device=actions.device
|
||||
)
|
||||
if self.ensembled_actions is None:
|
||||
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
|
||||
# time step of the episode.
|
||||
@@ -229,19 +251,34 @@ class ACTTemporalEnsembler:
|
||||
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
|
||||
# operations later.
|
||||
self.ensembled_actions_count = torch.ones(
|
||||
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
|
||||
(self.chunk_size, 1),
|
||||
dtype=torch.long,
|
||||
device=self.ensembled_actions.device,
|
||||
)
|
||||
else:
|
||||
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
|
||||
# the online update for those entries.
|
||||
self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
|
||||
self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
|
||||
self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
|
||||
self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
|
||||
self.ensembled_actions *= self.ensemble_weights_cumsum[
|
||||
self.ensembled_actions_count - 1
|
||||
]
|
||||
self.ensembled_actions += (
|
||||
actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
|
||||
)
|
||||
self.ensembled_actions /= self.ensemble_weights_cumsum[
|
||||
self.ensembled_actions_count
|
||||
]
|
||||
self.ensembled_actions_count = torch.clamp(
|
||||
self.ensembled_actions_count + 1, max=self.chunk_size
|
||||
)
|
||||
# The last action, which has no prior online average, needs to get concatenated onto the end.
|
||||
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
|
||||
self.ensembled_actions = torch.cat(
|
||||
[self.ensembled_actions, actions[:, -1:]], dim=1
|
||||
)
|
||||
self.ensembled_actions_count = torch.cat(
|
||||
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
|
||||
[
|
||||
self.ensembled_actions_count,
|
||||
torch.ones_like(self.ensembled_actions_count[-1:]),
|
||||
]
|
||||
)
|
||||
# "Consume" the first action.
|
||||
action, self.ensembled_actions, self.ensembled_actions_count = (
|
||||
@@ -293,7 +330,9 @@ class ACT(nn.Module):
|
||||
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
self.use_robot_state = "observation.state" in config.input_shapes
|
||||
self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
|
||||
self.use_images = any(
|
||||
k.startswith("observation.image") for k in config.input_shapes
|
||||
)
|
||||
self.use_env_state = "observation.environment_state" in config.input_shapes
|
||||
if self.config.use_vae:
|
||||
self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
|
||||
@@ -308,7 +347,9 @@ class ACT(nn.Module):
|
||||
config.output_shapes["action"][0], config.dim_model
|
||||
)
|
||||
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
||||
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
|
||||
self.vae_encoder_latent_output_proj = nn.Linear(
|
||||
config.dim_model, config.latent_dim * 2
|
||||
)
|
||||
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
|
||||
# dimension.
|
||||
num_input_token_encoder = 1 + config.chunk_size
|
||||
@@ -316,20 +357,28 @@ class ACT(nn.Module):
|
||||
num_input_token_encoder += 1
|
||||
self.register_buffer(
|
||||
"vae_encoder_pos_enc",
|
||||
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
|
||||
create_sinusoidal_pos_embedding(
|
||||
num_input_token_encoder, config.dim_model
|
||||
).unsqueeze(0),
|
||||
)
|
||||
|
||||
# Backbone for image feature extraction.
|
||||
if self.use_images:
|
||||
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
||||
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
|
||||
replace_stride_with_dilation=[
|
||||
False,
|
||||
False,
|
||||
config.replace_final_stride_with_dilation,
|
||||
],
|
||||
weights=config.pretrained_backbone_weights,
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
|
||||
# feature map).
|
||||
# Note: The forward method of this returns a dict: {"feature_map": output}.
|
||||
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
|
||||
self.backbone = IntermediateLayerGetter(
|
||||
backbone_model, return_layers={"layer4": "feature_map"}
|
||||
)
|
||||
|
||||
# Transformer (acts as VAE decoder when training with the variational objective).
|
||||
self.encoder = ACTEncoder(config)
|
||||
@@ -343,7 +392,8 @@ class ACT(nn.Module):
|
||||
)
|
||||
if self.use_env_state:
|
||||
self.encoder_env_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.environment_state"][0], config.dim_model
|
||||
config.input_shapes["observation.environment_state"][0],
|
||||
config.dim_model,
|
||||
)
|
||||
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
|
||||
if self.use_images:
|
||||
@@ -358,14 +408,18 @@ class ACT(nn.Module):
|
||||
n_1d_tokens += 1
|
||||
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
|
||||
if self.use_images:
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(
|
||||
config.dim_model // 2
|
||||
)
|
||||
|
||||
# Transformer decoder.
|
||||
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
|
||||
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
|
||||
|
||||
# Final action regression head on the output of the transformer's decoder.
|
||||
self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0])
|
||||
self.action_head = nn.Linear(
|
||||
config.dim_model, config.output_shapes["action"][0]
|
||||
)
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
@@ -375,7 +429,9 @@ class ACT(nn.Module):
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
|
||||
def forward(
|
||||
self, batch: dict[str, Tensor]
|
||||
) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
|
||||
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
|
||||
|
||||
`batch` should have the following structure:
|
||||
@@ -412,12 +468,20 @@ class ACT(nn.Module):
|
||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
||||
) # (B, 1, D)
|
||||
if self.use_robot_state:
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(
|
||||
batch["observation.state"]
|
||||
)
|
||||
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(
|
||||
batch["action"]
|
||||
) # (B, S, D)
|
||||
|
||||
if self.use_robot_state:
|
||||
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
||||
vae_encoder_input = [
|
||||
cls_embed,
|
||||
robot_state_embed,
|
||||
action_embed,
|
||||
] # (B, S+2, D)
|
||||
else:
|
||||
vae_encoder_input = [cls_embed, action_embed]
|
||||
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
|
||||
@@ -455,20 +519,26 @@ class ACT(nn.Module):
|
||||
# When not using the VAE encoder, we set the latent to be all zeros.
|
||||
mu = log_sigma_x2 = None
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
||||
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
|
||||
batch["observation.state"].device
|
||||
)
|
||||
latent_sample = torch.zeros(
|
||||
[batch_size, self.config.latent_dim], dtype=torch.float32
|
||||
).to(batch["observation.state"].device)
|
||||
|
||||
# Prepare transformer encoder inputs.
|
||||
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
|
||||
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
|
||||
encoder_in_pos_embed = list(
|
||||
self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)
|
||||
)
|
||||
# Robot state token.
|
||||
if self.use_robot_state:
|
||||
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
|
||||
encoder_in_tokens.append(
|
||||
self.encoder_robot_state_input_proj(batch["observation.state"])
|
||||
)
|
||||
# Environment state token.
|
||||
if self.use_env_state:
|
||||
encoder_in_tokens.append(
|
||||
self.encoder_env_state_input_proj(batch["observation.environment_state"])
|
||||
self.encoder_env_state_input_proj(
|
||||
batch["observation.environment_state"]
|
||||
)
|
||||
)
|
||||
|
||||
# Camera observation features and positional embeddings.
|
||||
@@ -477,19 +547,29 @@ class ACT(nn.Module):
|
||||
all_cam_pos_embeds = []
|
||||
|
||||
for cam_index in range(batch["observation.images"].shape[-4]):
|
||||
cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
|
||||
cam_features = self.backbone(batch["observation.images"][:, cam_index])[
|
||||
"feature_map"
|
||||
]
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
|
||||
# buffer
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(
|
||||
dtype=cam_features.dtype
|
||||
)
|
||||
cam_features = self.encoder_img_feat_input_proj(
|
||||
cam_features
|
||||
) # (B, C, h, w)
|
||||
all_cam_features.append(cam_features)
|
||||
all_cam_pos_embeds.append(cam_pos_embed)
|
||||
# Concatenate camera observation feature maps and positional embeddings along the width dimension,
|
||||
# and move to (sequence, batch, dim).
|
||||
all_cam_features = torch.cat(all_cam_features, axis=-1)
|
||||
encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
|
||||
encoder_in_tokens.extend(
|
||||
einops.rearrange(all_cam_features, "b c h w -> (h w) b c")
|
||||
)
|
||||
all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
|
||||
encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
|
||||
encoder_in_pos_embed.extend(
|
||||
einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c")
|
||||
)
|
||||
|
||||
# Stack all tokens along the sequence dimension.
|
||||
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
|
||||
@@ -524,12 +604,21 @@ class ACTEncoder(nn.Module):
|
||||
def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
|
||||
super().__init__()
|
||||
self.is_vae_encoder = is_vae_encoder
|
||||
num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers
|
||||
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)])
|
||||
num_layers = (
|
||||
config.n_vae_encoder_layers
|
||||
if self.is_vae_encoder
|
||||
else config.n_encoder_layers
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[ACTEncoderLayer(config) for _ in range(num_layers)]
|
||||
)
|
||||
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
|
||||
|
||||
def forward(
|
||||
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
|
||||
self,
|
||||
x: Tensor,
|
||||
pos_embed: Tensor | None = None,
|
||||
key_padding_mask: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
for layer in self.layers:
|
||||
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
|
||||
@@ -540,7 +629,9 @@ class ACTEncoder(nn.Module):
|
||||
class ACTEncoderLayer(nn.Module):
|
||||
def __init__(self, config: ACTConfig):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
|
||||
self.self_attn = nn.MultiheadAttention(
|
||||
config.dim_model, config.n_heads, dropout=config.dropout
|
||||
)
|
||||
|
||||
# Feed forward layers.
|
||||
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
|
||||
@@ -555,7 +646,9 @@ class ACTEncoderLayer(nn.Module):
|
||||
self.activation = get_activation_fn(config.feedforward_activation)
|
||||
self.pre_norm = config.pre_norm
|
||||
|
||||
def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
|
||||
def forward(
|
||||
self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
|
||||
) -> Tensor:
|
||||
skip = x
|
||||
if self.pre_norm:
|
||||
x = self.norm1(x)
|
||||
@@ -580,7 +673,9 @@ class ACTDecoder(nn.Module):
|
||||
def __init__(self, config: ACTConfig):
|
||||
"""Convenience module for running multiple decoder layers followed by normalization."""
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
|
||||
self.layers = nn.ModuleList(
|
||||
[ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]
|
||||
)
|
||||
self.norm = nn.LayerNorm(config.dim_model)
|
||||
|
||||
def forward(
|
||||
@@ -592,7 +687,10 @@ class ACTDecoder(nn.Module):
|
||||
) -> Tensor:
|
||||
for layer in self.layers:
|
||||
x = layer(
|
||||
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
|
||||
x,
|
||||
encoder_out,
|
||||
decoder_pos_embed=decoder_pos_embed,
|
||||
encoder_pos_embed=encoder_pos_embed,
|
||||
)
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
@@ -602,8 +700,12 @@ class ACTDecoder(nn.Module):
|
||||
class ACTDecoderLayer(nn.Module):
|
||||
def __init__(self, config: ACTConfig):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
|
||||
self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
|
||||
self.self_attn = nn.MultiheadAttention(
|
||||
config.dim_model, config.n_heads, dropout=config.dropout
|
||||
)
|
||||
self.multihead_attn = nn.MultiheadAttention(
|
||||
config.dim_model, config.n_heads, dropout=config.dropout
|
||||
)
|
||||
|
||||
# Feed forward layers.
|
||||
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
|
||||
@@ -644,7 +746,9 @@ class ACTDecoderLayer(nn.Module):
|
||||
if self.pre_norm:
|
||||
x = self.norm1(x)
|
||||
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
|
||||
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
|
||||
x = self.self_attn(q, k, value=x)[
|
||||
0
|
||||
] # select just the output, not the attention weights
|
||||
x = skip + self.dropout1(x)
|
||||
if self.pre_norm:
|
||||
skip = x
|
||||
@@ -681,9 +785,14 @@ def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tenso
|
||||
"""
|
||||
|
||||
def get_position_angle_vec(position):
|
||||
return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]
|
||||
return [
|
||||
position / np.power(10000, 2 * (hid_j // 2) / dimension)
|
||||
for hid_j in range(dimension)
|
||||
]
|
||||
|
||||
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
|
||||
sinusoid_table = np.array(
|
||||
[get_position_angle_vec(pos_i) for pos_i in range(num_positions)]
|
||||
)
|
||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||
return torch.from_numpy(sinusoid_table).float()
|
||||
@@ -728,7 +837,9 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
|
||||
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
|
||||
|
||||
inverse_frequency = self._temperature ** (
|
||||
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
|
||||
2
|
||||
* (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2)
|
||||
/ self.dimension
|
||||
)
|
||||
|
||||
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
|
||||
@@ -736,9 +847,15 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
|
||||
|
||||
# Note: this stack then flatten operation results in interleaved sine and cosine terms.
|
||||
# pos_embed_x and pos_embed_y are (1, H, W, C // 2).
|
||||
pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
|
||||
pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
|
||||
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)
|
||||
pos_embed_x = torch.stack(
|
||||
(x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1
|
||||
).flatten(3)
|
||||
pos_embed_y = torch.stack(
|
||||
(y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1
|
||||
).flatten(3)
|
||||
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(
|
||||
0, 3, 1, 2
|
||||
) # (1, C, H, W)
|
||||
|
||||
return pos_embed
|
||||
|
||||
|
||||
@@ -121,7 +121,9 @@ class DiffusionConfig:
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {"action": "min_max"}
|
||||
)
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
@@ -163,8 +165,13 @@ class DiffusionConfig:
|
||||
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
|
||||
if len(image_keys) == 0 and "observation.environment_state" not in self.input_shapes:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
if (
|
||||
len(image_keys) == 0
|
||||
and "observation.environment_state" not in self.input_shapes
|
||||
):
|
||||
raise ValueError(
|
||||
"You must provide at least one image or the environment state among the inputs."
|
||||
)
|
||||
|
||||
if len(image_keys) > 0:
|
||||
if self.crop_shape is not None:
|
||||
|
||||
@@ -88,7 +88,9 @@ class DiffusionPolicy(
|
||||
|
||||
self.diffusion = DiffusionModel(config)
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
self.expected_image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
self.use_env_state = "observation.environment_state" in config.input_shapes
|
||||
|
||||
self.reset()
|
||||
@@ -102,7 +104,9 @@ class DiffusionPolicy(
|
||||
if len(self.expected_image_keys) > 0:
|
||||
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
|
||||
if self.use_env_state:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||
self._queues["observation.environment_state"] = deque(
|
||||
maxlen=self.config.n_obs_steps
|
||||
)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -128,14 +132,22 @@ class DiffusionPolicy(
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.expected_image_keys], dim=-4
|
||||
)
|
||||
# Note: It's important that this happens after stacking the images into a single key.
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
# stack n latest observations from the queue
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
batch = {
|
||||
k: torch.stack(list(self._queues[k]), dim=1)
|
||||
for k in batch
|
||||
if k in self._queues
|
||||
}
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
|
||||
# TODO(rcadene): make above methods return output dictionary?
|
||||
@@ -150,8 +162,12 @@ class DiffusionPolicy(
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.expected_image_keys], dim=-4
|
||||
)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
return {"loss": loss}
|
||||
@@ -177,7 +193,9 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
# Build observation encoders (depending on which observations are provided).
|
||||
global_cond_dim = config.input_shapes["observation.state"][0]
|
||||
num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
|
||||
num_images = len(
|
||||
[k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
)
|
||||
self._use_images = False
|
||||
self._use_env_state = False
|
||||
if num_images > 0:
|
||||
@@ -193,7 +211,9 @@ class DiffusionModel(nn.Module):
|
||||
self._use_env_state = True
|
||||
global_cond_dim += config.input_shapes["observation.environment_state"][0]
|
||||
|
||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||
self.unet = DiffusionConditionalUnet1d(
|
||||
config, global_cond_dim=global_cond_dim * config.n_obs_steps
|
||||
)
|
||||
|
||||
self.noise_scheduler = _make_noise_scheduler(
|
||||
config.noise_scheduler_type,
|
||||
@@ -213,14 +233,21 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
# ========= inference ============
|
||||
def conditional_sample(
|
||||
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
|
||||
self,
|
||||
batch_size: int,
|
||||
global_cond: Tensor | None = None,
|
||||
generator: torch.Generator | None = None,
|
||||
) -> Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
dtype = get_dtype_from_parameters(self)
|
||||
|
||||
# Sample prior.
|
||||
sample = torch.randn(
|
||||
size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]),
|
||||
size=(
|
||||
batch_size,
|
||||
self.config.horizon,
|
||||
self.config.output_shapes["action"][0],
|
||||
),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
@@ -236,7 +263,9 @@ class DiffusionModel(nn.Module):
|
||||
global_cond=global_cond,
|
||||
)
|
||||
# Compute previous image: x_t -> x_t-1
|
||||
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
|
||||
sample = self.noise_scheduler.step(
|
||||
model_output, t, sample, generator=generator
|
||||
).prev_sample
|
||||
|
||||
return sample
|
||||
|
||||
@@ -248,27 +277,39 @@ class DiffusionModel(nn.Module):
|
||||
if self._use_images:
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
|
||||
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
|
||||
images_per_camera = einops.rearrange(
|
||||
batch["observation.images"], "b s n ... -> n (b s) ..."
|
||||
)
|
||||
img_features_list = torch.cat(
|
||||
[
|
||||
encoder(images)
|
||||
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
|
||||
for encoder, images in zip(
|
||||
self.rgb_encoder, images_per_camera, strict=True
|
||||
)
|
||||
]
|
||||
)
|
||||
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
img_features_list,
|
||||
"(n b s) ... -> b s (n ...)",
|
||||
b=batch_size,
|
||||
s=n_obs_steps,
|
||||
)
|
||||
else:
|
||||
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
|
||||
img_features = self.rgb_encoder(
|
||||
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
||||
einops.rearrange(
|
||||
batch["observation.images"], "b s n ... -> (b s n) ..."
|
||||
)
|
||||
)
|
||||
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
img_features,
|
||||
"(b s n) ... -> b s (n ...)",
|
||||
b=batch_size,
|
||||
s=n_obs_steps,
|
||||
)
|
||||
global_cond_feats.append(img_features)
|
||||
|
||||
@@ -354,7 +395,9 @@ class DiffusionModel(nn.Module):
|
||||
elif self.config.prediction_type == "sample":
|
||||
target = batch["action"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
|
||||
raise ValueError(
|
||||
f"Unsupported prediction type {self.config.prediction_type}"
|
||||
)
|
||||
|
||||
loss = F.mse_loss(pred, target, reduction="none")
|
||||
|
||||
@@ -414,7 +457,9 @@ class SpatialSoftmax(nn.Module):
|
||||
|
||||
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
|
||||
# and causes a small degradation in pc_success of pre-trained models.
|
||||
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
|
||||
pos_x, pos_y = np.meshgrid(
|
||||
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
|
||||
)
|
||||
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
|
||||
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
|
||||
# register as buffer so it's moved to the correct device.
|
||||
@@ -456,7 +501,9 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
# Always use center crop for eval
|
||||
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
|
||||
if config.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(
|
||||
config.crop_shape
|
||||
)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
@@ -477,7 +524,9 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
self.backbone = _replace_submodules(
|
||||
root_module=self.backbone,
|
||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||
func=lambda x: nn.GroupNorm(
|
||||
num_groups=x.num_features // 16, num_channels=x.num_features
|
||||
),
|
||||
)
|
||||
|
||||
# Set up pooling and final layers.
|
||||
@@ -485,17 +534,25 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
# The dummy input should take the number of image channels from `config.input_shapes` and it should
|
||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||
# height and width from `config.input_shapes`.
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
# Note: we have a check in the config class to make sure all images have the same shape.
|
||||
image_key = image_keys[0]
|
||||
dummy_input_h_w = (
|
||||
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
|
||||
config.crop_shape
|
||||
if config.crop_shape is not None
|
||||
else config.input_shapes[image_key][1:]
|
||||
)
|
||||
dummy_input = torch.zeros(
|
||||
size=(1, config.input_shapes[image_key][0], *dummy_input_h_w)
|
||||
)
|
||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
|
||||
with torch.inference_mode():
|
||||
dummy_feature_map = self.backbone(dummy_input)
|
||||
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.pool = SpatialSoftmax(
|
||||
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
|
||||
)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
@@ -522,7 +579,9 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
|
||||
|
||||
def _replace_submodules(
|
||||
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||
root_module: nn.Module,
|
||||
predicate: Callable[[nn.Module], bool],
|
||||
func: Callable[[nn.Module], nn.Module],
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Args:
|
||||
@@ -535,7 +594,11 @@ def _replace_submodules(
|
||||
if predicate(root_module):
|
||||
return func(root_module)
|
||||
|
||||
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
||||
replace_list = [
|
||||
k.split(".")
|
||||
for k, m in root_module.named_modules(remove_duplicate=True)
|
||||
if predicate(m)
|
||||
]
|
||||
for *parents, k in replace_list:
|
||||
parent_module = root_module
|
||||
if len(parents) > 0:
|
||||
@@ -550,7 +613,9 @@ def _replace_submodules(
|
||||
else:
|
||||
setattr(parent_module, k, tgt_module)
|
||||
# verify that all BN are replaced
|
||||
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
|
||||
assert not any(
|
||||
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
|
||||
)
|
||||
return root_module
|
||||
|
||||
|
||||
@@ -578,7 +643,9 @@ class DiffusionConv1dBlock(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
|
||||
nn.Conv1d(
|
||||
inp_channels, out_channels, kernel_size, padding=kernel_size // 2
|
||||
),
|
||||
nn.GroupNorm(n_groups, out_channels),
|
||||
nn.Mish(),
|
||||
)
|
||||
@@ -601,9 +668,13 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
# Encoder for the diffusion timestep.
|
||||
self.diffusion_step_encoder = nn.Sequential(
|
||||
DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
|
||||
nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
|
||||
nn.Linear(
|
||||
config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4
|
||||
),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
|
||||
nn.Linear(
|
||||
config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim
|
||||
),
|
||||
)
|
||||
|
||||
# The FiLM conditioning dimension.
|
||||
@@ -628,10 +699,16 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
self.down_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
dim_in, dim_out, **common_res_block_kwargs
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
dim_out, dim_out, **common_res_block_kwargs
|
||||
),
|
||||
# Downsample as long as it is not the last block.
|
||||
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
|
||||
nn.Conv1d(dim_out, dim_out, 3, 2, 1)
|
||||
if not is_last
|
||||
else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
@@ -640,10 +717,14 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
self.mid_modules = nn.ModuleList(
|
||||
[
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
||||
config.down_dims[-1],
|
||||
config.down_dims[-1],
|
||||
**common_res_block_kwargs,
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
||||
config.down_dims[-1],
|
||||
config.down_dims[-1],
|
||||
**common_res_block_kwargs,
|
||||
),
|
||||
]
|
||||
)
|
||||
@@ -656,16 +737,24 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
nn.ModuleList(
|
||||
[
|
||||
# dim_in * 2, because it takes the encoder's skip connection as well
|
||||
DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
dim_in * 2, dim_out, **common_res_block_kwargs
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
dim_out, dim_out, **common_res_block_kwargs
|
||||
),
|
||||
# Upsample as long as it is not the last block.
|
||||
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
|
||||
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1)
|
||||
if not is_last
|
||||
else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
|
||||
DiffusionConv1dBlock(
|
||||
config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size
|
||||
),
|
||||
nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1),
|
||||
)
|
||||
|
||||
@@ -733,17 +822,23 @@ class DiffusionConditionalResidualBlock1d(nn.Module):
|
||||
self.use_film_scale_modulation = use_film_scale_modulation
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
self.conv1 = DiffusionConv1dBlock(
|
||||
in_channels, out_channels, kernel_size, n_groups=n_groups
|
||||
)
|
||||
|
||||
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
|
||||
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
|
||||
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
|
||||
|
||||
self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
self.conv2 = DiffusionConv1dBlock(
|
||||
out_channels, out_channels, kernel_size, n_groups=n_groups
|
||||
)
|
||||
|
||||
# A final convolution for dimension matching the residual (if needed).
|
||||
self.residual_conv = (
|
||||
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
|
||||
nn.Conv1d(in_channels, out_channels, 1)
|
||||
if in_channels != out_channels
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, cond: Tensor) -> Tensor:
|
||||
|
||||
@@ -52,7 +52,9 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
|
||||
|
||||
return TDMPCPolicy, TDMPCConfig
|
||||
elif name == "diffusion":
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import (
|
||||
DiffusionConfig,
|
||||
)
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
return DiffusionPolicy, DiffusionConfig
|
||||
@@ -115,7 +117,9 @@ def make_policy(
|
||||
# huggingface_hub should make it possible to avoid the hack:
|
||||
# https://github.com/huggingface/huggingface_hub/pull/2274.
|
||||
policy = policy_cls(policy_cfg)
|
||||
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
|
||||
policy.load_state_dict(
|
||||
policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict()
|
||||
)
|
||||
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
|
||||
|
||||
@@ -7,7 +7,9 @@ from torch import Tensor, nn
|
||||
|
||||
from .configuration_classifier import ClassifierConfig
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -15,7 +17,10 @@ class ClassifierOutput:
|
||||
"""Wrapper for classifier outputs with additional metadata."""
|
||||
|
||||
def __init__(
|
||||
self, logits: Tensor, probabilities: Optional[Tensor] = None, hidden_states: Optional[Tensor] = None
|
||||
self,
|
||||
logits: Tensor,
|
||||
probabilities: Optional[Tensor] = None,
|
||||
hidden_states: Optional[Tensor] = None,
|
||||
):
|
||||
self.logits = logits
|
||||
self.probabilities = probabilities
|
||||
@@ -43,12 +48,14 @@ class Classifier(
|
||||
name = "classifier"
|
||||
|
||||
def __init__(self, config: ClassifierConfig):
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
from transformers import AutoModel
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
encoder = AutoModel.from_pretrained(
|
||||
self.config.model_name, trust_remote_code=True
|
||||
)
|
||||
# Extract vision model if we're given a multimodal model
|
||||
if hasattr(encoder, "vision_model"):
|
||||
logging.info("Multimodal model detected - using vision encoder only")
|
||||
@@ -74,7 +81,9 @@ class Classifier(
|
||||
self.feature_dim = self.encoder.fc.in_features
|
||||
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
|
||||
elif hasattr(self.encoder.config, "hidden_sizes"):
|
||||
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
||||
self.feature_dim = self.encoder.config.hidden_sizes[
|
||||
-1
|
||||
] # Last channel dimension
|
||||
else:
|
||||
raise ValueError("Unsupported CNN architecture")
|
||||
|
||||
@@ -94,14 +103,19 @@ class Classifier(
|
||||
if hasattr(self.encoder.config, "hidden_size"):
|
||||
input_dim = self.encoder.config.hidden_size
|
||||
else:
|
||||
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
|
||||
raise ValueError(
|
||||
"Unsupported transformer architecture since hidden_size is not found"
|
||||
)
|
||||
|
||||
self.classifier_head = nn.Sequential(
|
||||
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
|
||||
nn.Dropout(self.config.dropout_rate),
|
||||
nn.LayerNorm(self.config.hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
|
||||
nn.Linear(
|
||||
self.config.hidden_dim,
|
||||
1 if self.config.num_classes == 2 else self.config.num_classes,
|
||||
),
|
||||
)
|
||||
self.classifier_head = self.classifier_head.to(self.config.device)
|
||||
|
||||
@@ -127,7 +141,10 @@ class Classifier(
|
||||
return features
|
||||
else: # Transformer models
|
||||
outputs = self.encoder(processed)
|
||||
if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
|
||||
if (
|
||||
hasattr(outputs, "pooler_output")
|
||||
and outputs.pooler_output is not None
|
||||
):
|
||||
return outputs.pooler_output
|
||||
return outputs.last_hidden_state[:, 0, :]
|
||||
|
||||
@@ -143,7 +160,9 @@ class Classifier(
|
||||
else:
|
||||
probabilities = torch.softmax(logits, dim=-1)
|
||||
|
||||
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
|
||||
return ClassifierOutput(
|
||||
logits=logits, probabilities=probabilities, hidden_states=encoder_outputs
|
||||
)
|
||||
|
||||
def predict_reward(self, x, threshold=0.6):
|
||||
if self.config.num_classes == 2:
|
||||
|
||||
@@ -196,7 +196,7 @@ class Unnormalize(nn.Module):
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
# @torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, mode in self.modes.items():
|
||||
|
||||
@@ -41,11 +41,16 @@ class SACConfig:
|
||||
)
|
||||
input_normalization_params: dict[str, dict[str, list[float]]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": {"mean": [[0.485, 0.456, 0.406]], "std": [[0.229, 0.224, 0.225]]},
|
||||
"observation.image": {
|
||||
"mean": [[0.485, 0.456, 0.406]],
|
||||
"std": [[0.229, 0.224, 0.225]],
|
||||
},
|
||||
"observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {"action": "min_max"}
|
||||
)
|
||||
output_normalization_params: dict[str, dict[str, list[float]]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": {"min": [-1, -1], "max": [1, 1]},
|
||||
@@ -54,12 +59,13 @@ class SACConfig:
|
||||
# TODO: Move it outside of the config
|
||||
actor_learner_config: dict[str, str | int] = field(
|
||||
default_factory=lambda: {
|
||||
"actor_ip": "127.0.0.1",
|
||||
"port": 50051,
|
||||
"learner_ip": "127.0.0.1",
|
||||
"learner_host": "127.0.0.1",
|
||||
"learner_port": 50051,
|
||||
}
|
||||
)
|
||||
camera_number: int = 1
|
||||
|
||||
storage_device: str = "cpu"
|
||||
# Add type annotations for these fields:
|
||||
vision_encoder_name: str | None = field(default="helper2424/resnet10")
|
||||
freeze_vision_encoder: bool = True
|
||||
@@ -78,10 +84,12 @@ class SACConfig:
|
||||
latent_dim: int = 256
|
||||
target_entropy: float | None = None
|
||||
use_backup_entropy: bool = True
|
||||
grad_clip_norm: float = 40.0
|
||||
critic_network_kwargs: dict[str, Any] = field(
|
||||
default_factory=lambda: {
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
"final_activation": None,
|
||||
}
|
||||
)
|
||||
actor_network_kwargs: dict[str, Any] = field(
|
||||
@@ -95,6 +103,6 @@ class SACConfig:
|
||||
"use_tanh_squash": True,
|
||||
"log_std_min": -5,
|
||||
"log_std_max": 2,
|
||||
"init_final": 0.005,
|
||||
"init_final": 0.05,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -17,7 +17,9 @@
|
||||
|
||||
# TODO: (1) better device management
|
||||
|
||||
from typing import Callable, Optional, Tuple
|
||||
import math
|
||||
from typing import Callable, Optional, Tuple, Union, Dict, List
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
@@ -57,7 +59,9 @@ class SACPolicy(
|
||||
config.input_normalization_params
|
||||
)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, input_normalization_params
|
||||
config.input_shapes,
|
||||
config.input_normalization_modes,
|
||||
input_normalization_params,
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
@@ -83,52 +87,157 @@ class SACPolicy(
|
||||
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
|
||||
encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
|
||||
|
||||
# Create a list of critic heads
|
||||
critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
|
||||
self.critic_ensemble = CriticEnsemble(
|
||||
encoder=encoder_critic,
|
||||
network_list=nn.ModuleList(
|
||||
[
|
||||
MLP(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
),
|
||||
ensemble=critic_heads,
|
||||
output_normalization=self.normalize_targets,
|
||||
)
|
||||
|
||||
# Create target critic heads as deepcopies of the original critic heads
|
||||
target_critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
|
||||
self.critic_target = CriticEnsemble(
|
||||
encoder=encoder_critic,
|
||||
network_list=nn.ModuleList(
|
||||
[
|
||||
MLP(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
),
|
||||
ensemble=target_critic_heads,
|
||||
output_normalization=self.normalize_targets,
|
||||
)
|
||||
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
self.actor = Policy(
|
||||
encoder=encoder_actor,
|
||||
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
|
||||
network=MLP(
|
||||
input_dim=encoder_actor.output_dim, **config.actor_network_kwargs
|
||||
),
|
||||
action_dim=config.output_shapes["action"][0],
|
||||
encoder_is_shared=config.shared_encoder,
|
||||
**config.policy_kwargs,
|
||||
)
|
||||
if config.target_entropy is None:
|
||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
|
||||
config.target_entropy = (
|
||||
-np.prod(config.output_shapes["action"][0]) / 2
|
||||
) # (-dim(A)/2)
|
||||
|
||||
# TODO (azouitine): Handle the case where the temparameter is a fixed
|
||||
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
|
||||
# it triggers "can't optimize a non-leaf Tensor"
|
||||
self.log_alpha = nn.Parameter(torch.tensor([0.0]))
|
||||
|
||||
temperature_init = config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
def _save_pretrained(self, save_directory):
|
||||
"""Custom save method to handle TensorDict properly"""
|
||||
import os
|
||||
import json
|
||||
from dataclasses import asdict
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
|
||||
from safetensors.torch import save_model
|
||||
|
||||
save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))
|
||||
|
||||
# Save config
|
||||
config_dict = asdict(self.config)
|
||||
with open(os.path.join(save_directory, CONFIG_NAME), "w") as f:
|
||||
json.dump(config_dict, f, indent=2)
|
||||
print(f"Saved config to {os.path.join(save_directory, CONFIG_NAME)}")
|
||||
|
||||
@classmethod
|
||||
def _from_pretrained(
|
||||
cls,
|
||||
*,
|
||||
model_id: str,
|
||||
revision: Optional[str],
|
||||
cache_dir: Optional[Union[str, Path]],
|
||||
force_download: bool,
|
||||
proxies: Optional[Dict],
|
||||
resume_download: Optional[bool],
|
||||
local_files_only: bool,
|
||||
token: Optional[Union[str, bool]],
|
||||
map_location: str = "cpu",
|
||||
strict: bool = False,
|
||||
**model_kwargs,
|
||||
) -> "SACPolicy":
|
||||
"""Custom load method to handle loading SAC policy from saved files"""
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
|
||||
from safetensors.torch import load_model
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
|
||||
# Check if model_id is a local path or a hub model ID
|
||||
if os.path.isdir(model_id):
|
||||
model_path = Path(model_id)
|
||||
safetensors_file = os.path.join(model_path, SAFETENSORS_SINGLE_FILE)
|
||||
config_file = os.path.join(model_path, CONFIG_NAME)
|
||||
else:
|
||||
# Download the safetensors file from the hub
|
||||
safetensors_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=SAFETENSORS_SINGLE_FILE,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
# Download the config file
|
||||
try:
|
||||
config_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=CONFIG_NAME,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except Exception:
|
||||
config_file = None
|
||||
|
||||
# Load or create config
|
||||
if config_file and os.path.exists(config_file):
|
||||
# Load config from file
|
||||
with open(config_file) as f:
|
||||
config_dict = json.load(f)
|
||||
config = SACConfig(**config_dict)
|
||||
else:
|
||||
# Use the provided config or create a default one
|
||||
config = model_kwargs.get("config", SACConfig())
|
||||
|
||||
# Create a new instance with the loaded config
|
||||
model = cls(config=config)
|
||||
|
||||
# Load state dict from safetensors file
|
||||
if os.path.exists(safetensors_file):
|
||||
load_model(model, filename=safetensors_file, device=map_location)
|
||||
|
||||
return model
|
||||
|
||||
def reset(self):
|
||||
"""Reset the policy"""
|
||||
pass
|
||||
@@ -148,7 +257,11 @@ class SACPolicy(
|
||||
return actions
|
||||
|
||||
def critic_forward(
|
||||
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
|
||||
self,
|
||||
observations: dict[str, Tensor],
|
||||
actions: Tensor,
|
||||
use_target: bool = False,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
@@ -161,28 +274,49 @@ class SACPolicy(
|
||||
Tensor of Q-values from all critics
|
||||
"""
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = critics(observations, actions)
|
||||
q_values = critics(observations, actions, observation_features)
|
||||
return q_values
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
|
||||
def update_target_networks(self):
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_param, param in zip(
|
||||
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=False
|
||||
self.critic_target.parameters(),
|
||||
self.critic_ensemble.parameters(),
|
||||
strict=False,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
|
||||
temperature = self.log_alpha.exp().item()
|
||||
def compute_loss_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features: Tensor | None = None,
|
||||
next_observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.actor(next_observations)
|
||||
next_action_preds, next_log_probs, _ = self.actor(
|
||||
next_observations, next_observation_features
|
||||
)
|
||||
|
||||
# TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
|
||||
next_action_preds = self.unnormalize_outputs({"action": next_action_preds})[
|
||||
"action"
|
||||
]
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(
|
||||
observations=next_observations, actions=next_action_preds, use_target=True
|
||||
observations=next_observations,
|
||||
actions=next_action_preds,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
@@ -194,12 +328,17 @@ class SACPolicy(
|
||||
# critics subsample size
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (temperature * next_log_probs)
|
||||
min_q = min_q - (self.temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||
q_preds = self.critic_forward(
|
||||
observations,
|
||||
actions,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
@@ -214,23 +353,37 @@ class SACPolicy(
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def compute_loss_temperature(self, observations) -> Tensor:
|
||||
def compute_loss_temperature(
|
||||
self, observations, observation_features: Tensor | None = None
|
||||
) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.actor(observations)
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
|
||||
_, log_probs, _ = self.actor(observations, observation_features)
|
||||
temperature_loss = (
|
||||
-self.log_alpha.exp() * (log_probs + self.config.target_entropy)
|
||||
).mean()
|
||||
return temperature_loss
|
||||
|
||||
def compute_loss_actor(self, observations) -> Tensor:
|
||||
temperature = self.log_alpha.exp().item()
|
||||
def compute_loss_actor(
|
||||
self, observations, observation_features: Tensor | None = None
|
||||
) -> Tensor:
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
actions_pi, log_probs, _ = self.actor(observations)
|
||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||
|
||||
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
|
||||
# TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
|
||||
actions_pi = self.unnormalize_outputs({"action": actions_pi})["action"]
|
||||
|
||||
q_preds = self.critic_forward(
|
||||
observations,
|
||||
actions_pi,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
|
||||
@@ -242,6 +395,7 @@ class MLP(nn.Module):
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: Optional[float] = None,
|
||||
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.activate_final = activate_final
|
||||
@@ -254,7 +408,11 @@ class MLP(nn.Module):
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(hidden_dims[0]))
|
||||
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
|
||||
layers.append(
|
||||
activations
|
||||
if isinstance(activations, nn.Module)
|
||||
else getattr(nn, activations)()
|
||||
)
|
||||
|
||||
# Rest of the layers
|
||||
for i in range(1, len(hidden_dims)):
|
||||
@@ -264,9 +422,24 @@ class MLP(nn.Module):
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(hidden_dims[i]))
|
||||
layers.append(
|
||||
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
|
||||
)
|
||||
|
||||
# If we're at the final layer and a final activation is specified, use it
|
||||
if (
|
||||
i + 1 == len(hidden_dims)
|
||||
and activate_final
|
||||
and final_activation is not None
|
||||
):
|
||||
layers.append(
|
||||
final_activation
|
||||
if isinstance(final_activation, nn.Module)
|
||||
else getattr(nn, final_activation)()
|
||||
)
|
||||
else:
|
||||
layers.append(
|
||||
activations
|
||||
if isinstance(activations, nn.Module)
|
||||
else getattr(nn, activations)()
|
||||
)
|
||||
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
@@ -274,6 +447,37 @@ class MLP(nn.Module):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class CriticHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
hidden_dims: list[int],
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: Optional[float] = None,
|
||||
init_final: Optional[float] = None,
|
||||
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.net = MLP(
|
||||
input_dim=input_dim,
|
||||
hidden_dims=hidden_dims,
|
||||
activations=activations,
|
||||
activate_final=activate_final,
|
||||
dropout_rate=dropout_rate,
|
||||
final_activation=final_activation,
|
||||
)
|
||||
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.output_layer(self.net(x))
|
||||
|
||||
|
||||
class CriticEnsemble(nn.Module):
|
||||
"""
|
||||
┌──────────────────┬─────────────────────────────────────────────────────────┐
|
||||
@@ -316,50 +520,27 @@ class CriticEnsemble(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
network_list: nn.ModuleList,
|
||||
ensemble: List[CriticHead],
|
||||
output_normalization: nn.Module,
|
||||
init_final: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.network_list = network_list
|
||||
self.init_final = init_final
|
||||
self.output_normalization = output_normalization
|
||||
self.critics = nn.ModuleList(ensemble)
|
||||
|
||||
self.parameters_to_optimize = []
|
||||
# Handle the case where a part of the encoder if frozen
|
||||
if self.encoder is not None:
|
||||
self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
|
||||
|
||||
self.parameters_to_optimize += list(self.network_list.parameters())
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network_list[0].net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
|
||||
# Output layer
|
||||
self.output_layers = []
|
||||
if init_final is not None:
|
||||
for _ in network_list:
|
||||
output_layer = nn.Linear(out_features, 1)
|
||||
nn.init.uniform_(output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(output_layer.bias, -init_final, init_final)
|
||||
self.output_layers.append(output_layer)
|
||||
else:
|
||||
self.output_layers = []
|
||||
for _ in network_list:
|
||||
output_layer = nn.Linear(out_features, 1)
|
||||
orthogonal_init()(output_layer.weight)
|
||||
self.output_layers.append(output_layer)
|
||||
|
||||
self.output_layers = nn.ModuleList(self.output_layers)
|
||||
self.parameters_to_optimize += list(self.output_layers.parameters())
|
||||
self.parameters_to_optimize += list(self.critics.parameters())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
actions: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device
|
||||
@@ -370,15 +551,22 @@ class CriticEnsemble(nn.Module):
|
||||
actions = self.output_normalization(actions)["action"]
|
||||
actions = actions.to(device)
|
||||
|
||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||
obs_enc = (
|
||||
observation_features
|
||||
if observation_features is not None
|
||||
else (observations if self.encoder is None else self.encoder(observations))
|
||||
)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
list_q_values = []
|
||||
for network, output_layer in zip(self.network_list, self.output_layers, strict=False):
|
||||
x = network(inputs)
|
||||
value = output_layer(x)
|
||||
list_q_values.append(value.squeeze(-1))
|
||||
return torch.stack(list_q_values)
|
||||
|
||||
# Loop through critics and collect outputs
|
||||
q_values = []
|
||||
for critic in self.critics:
|
||||
q_values.append(critic(inputs))
|
||||
|
||||
# Stack outputs to match expected shape [num_critics, batch_size]
|
||||
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
|
||||
return q_values
|
||||
|
||||
|
||||
class Policy(nn.Module):
|
||||
@@ -435,9 +623,14 @@ class Policy(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Encode observations if encoder exists
|
||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||
obs_enc = (
|
||||
observation_features
|
||||
if observation_features is not None
|
||||
else (observations if self.encoder is None else self.encoder(observations))
|
||||
)
|
||||
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
@@ -446,11 +639,15 @@ class Policy(nn.Module):
|
||||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
log_std = self.std_layer(outputs)
|
||||
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
|
||||
assert not torch.isnan(
|
||||
log_std
|
||||
).any(), "[ERROR] log_std became NaN after std_layer!"
|
||||
|
||||
if self.use_tanh_squash:
|
||||
log_std = torch.tanh(log_std)
|
||||
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
|
||||
log_std = self.log_std_min + 0.5 * (
|
||||
self.log_std_max - self.log_std_min
|
||||
) * (log_std + 1.0)
|
||||
else:
|
||||
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
|
||||
else:
|
||||
@@ -463,7 +660,9 @@ class Policy(nn.Module):
|
||||
|
||||
if self.use_tanh_squash:
|
||||
actions = torch.tanh(x_t)
|
||||
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
|
||||
log_probs -= torch.log(
|
||||
(1 - actions.pow(2)) + 1e-6
|
||||
) # Adjust log-probs for Tanh
|
||||
else:
|
||||
actions = x_t # No Tanh; raw Gaussian sample
|
||||
|
||||
@@ -510,11 +709,15 @@ class SACObservationEncoder(nn.Module):
|
||||
freeze_image_encoder(self.image_enc_layers)
|
||||
else:
|
||||
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
||||
self.all_image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
|
||||
if "observation.state" in config.input_shapes:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=config.input_shapes["observation.state"][0], out_features=config.latent_dim
|
||||
in_features=config.input_shapes["observation.state"][0],
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
@@ -535,7 +738,9 @@ class SACObservationEncoder(nn.Module):
|
||||
self.aggregation_size += config.latent_dim
|
||||
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
|
||||
|
||||
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
|
||||
self.aggregation_layer = nn.Linear(
|
||||
in_features=self.aggregation_size, out_features=config.latent_dim
|
||||
)
|
||||
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
|
||||
|
||||
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
|
||||
@@ -546,16 +751,21 @@ class SACObservationEncoder(nn.Module):
|
||||
"""
|
||||
feat = []
|
||||
obs_dict = self.input_normalization(obs_dict)
|
||||
# Concatenate all images along the channel dimension.
|
||||
image_keys = [k for k in obs_dict if k.startswith("observation.image")]
|
||||
for image_key in image_keys:
|
||||
enc_feat = self.image_enc_layers(obs_dict[image_key])
|
||||
# Batch all images along the batch dimension, then encode them.
|
||||
if len(self.all_image_keys) > 0:
|
||||
images_batched = torch.cat(
|
||||
[obs_dict[key] for key in self.all_image_keys], dim=0
|
||||
)
|
||||
images_batched = self.image_enc_layers(images_batched)
|
||||
embeddings_chunks = torch.chunk(
|
||||
images_batched, dim=0, chunks=len(self.all_image_keys)
|
||||
)
|
||||
feat.extend(embeddings_chunks)
|
||||
|
||||
# if not self.has_pretrained_vision_encoder:
|
||||
# enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
|
||||
feat.append(enc_feat)
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
feat.append(
|
||||
self.env_state_enc_layers(obs_dict["observation.environment_state"])
|
||||
)
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
|
||||
@@ -623,7 +833,9 @@ class PretrainedImageEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||
self.image_enc_layers, self.image_enc_out_shape = (
|
||||
self._load_pretrained_vision_encoder(config)
|
||||
)
|
||||
self.image_enc_proj = nn.Sequential(
|
||||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
@@ -634,15 +846,21 @@ class PretrainedImageEncoder(nn.Module):
|
||||
"""Set up CNN encoder"""
|
||||
from transformers import AutoModel
|
||||
|
||||
self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
|
||||
self.image_enc_layers = AutoModel.from_pretrained(
|
||||
config.vision_encoder_name, trust_remote_code=True
|
||||
)
|
||||
# self.image_enc_layers.pooler = Identity()
|
||||
|
||||
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
|
||||
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
|
||||
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[
|
||||
-1
|
||||
] # Last channel dimension
|
||||
elif hasattr(self.image_enc_layers, "fc"):
|
||||
self.image_enc_out_shape = self.image_enc_layers.fc.in_features
|
||||
else:
|
||||
raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
|
||||
raise ValueError(
|
||||
"Unsupported vision encoder architecture, make sure you are using a CNN"
|
||||
)
|
||||
return self.image_enc_layers, self.image_enc_out_shape
|
||||
|
||||
def forward(self, x):
|
||||
@@ -665,34 +883,12 @@ def orthogonal_init():
|
||||
|
||||
class Identity(nn.Module):
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
# TODO (azouitine): I think in our case this function is not usefull we should remove it
|
||||
# after some investigation
|
||||
# borrowed from tdmpc
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
||||
Args:
|
||||
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
|
||||
(B, *), where * is any number of dimensions.
|
||||
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
|
||||
can be more than 1 dimensions, generally different from *.
|
||||
Returns:
|
||||
A return value from the callable reshaped to (**, *).
|
||||
"""
|
||||
if image_tensor.ndim == 4:
|
||||
return fn(image_tensor)
|
||||
start_dims = image_tensor.shape[:-3]
|
||||
inp = torch.flatten(image_tensor, end_dim=-4)
|
||||
flat_out = fn(inp)
|
||||
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
||||
|
||||
|
||||
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
||||
converted_params = {}
|
||||
for outer_key, inner_dict in normalization_params.items():
|
||||
@@ -700,57 +896,86 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
||||
for key, value in inner_dict.items():
|
||||
converted_params[outer_key][key] = torch.tensor(value)
|
||||
if "image" in outer_key:
|
||||
converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
|
||||
converted_params[outer_key][key] = converted_params[outer_key][
|
||||
key
|
||||
].view(3, 1, 1)
|
||||
|
||||
return converted_params
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Test the SACObservationEncoder
|
||||
# Benchmark the CriticEnsemble performance
|
||||
import time
|
||||
|
||||
config = SACConfig()
|
||||
config.num_critics = 10
|
||||
encoder = SACObservationEncoder(config)
|
||||
actor_encoder = SACObservationEncoder(config)
|
||||
encoder = torch.compile(encoder)
|
||||
# Configuration
|
||||
num_critics = 10
|
||||
batch_size = 32
|
||||
action_dim = 7
|
||||
obs_dim = 64
|
||||
hidden_dims = [256, 256]
|
||||
num_iterations = 100
|
||||
|
||||
print("Creating test environment...")
|
||||
|
||||
# Create a simple dummy encoder
|
||||
class DummyEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.output_dim = obs_dim
|
||||
self.parameters_to_optimize = []
|
||||
|
||||
def forward(self, obs):
|
||||
# Just return a random tensor of the right shape
|
||||
# In practice, this would encode the observations
|
||||
return torch.randn(batch_size, obs_dim, device=device)
|
||||
|
||||
# Create critic heads
|
||||
print(f"Creating {num_critics} critic heads...")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=obs_dim + action_dim,
|
||||
hidden_dims=hidden_dims,
|
||||
).to(device)
|
||||
for _ in range(num_critics)
|
||||
]
|
||||
|
||||
# Create the critic ensemble
|
||||
print("Creating CriticEnsemble...")
|
||||
critic_ensemble = CriticEnsemble(
|
||||
encoder=encoder,
|
||||
network_list=nn.ModuleList(
|
||||
[
|
||||
MLP(
|
||||
input_dim=encoder.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
),
|
||||
)
|
||||
actor = Policy(
|
||||
encoder=actor_encoder,
|
||||
network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
|
||||
action_dim=config.output_shapes["action"][0],
|
||||
encoder_is_shared=config.shared_encoder,
|
||||
**config.policy_kwargs,
|
||||
)
|
||||
encoder = encoder.to("cuda:0")
|
||||
critic_ensemble = torch.compile(critic_ensemble)
|
||||
critic_ensemble = critic_ensemble.to("cuda:0")
|
||||
actor = torch.compile(actor)
|
||||
actor = actor.to("cuda:0")
|
||||
encoder=DummyEncoder().to(device),
|
||||
ensemble=critic_heads,
|
||||
output_normalization=nn.Identity(),
|
||||
).to(device)
|
||||
|
||||
# Create random input data
|
||||
print("Creating input data...")
|
||||
obs_dict = {
|
||||
"observation.image": torch.randn(1, 3, 84, 84),
|
||||
"observation.state": torch.randn(1, 4),
|
||||
"observation.state": torch.randn(batch_size, obs_dim, device=device),
|
||||
}
|
||||
actions = torch.randn(1, 2).to("cuda:0")
|
||||
obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
|
||||
print("compiling...")
|
||||
# q_value = critic_ensemble(obs_dict, actions)
|
||||
action = actor(obs_dict)
|
||||
print("compiled")
|
||||
start = time.perf_counter()
|
||||
for _ in range(1000):
|
||||
# features = encoder(obs_dict)
|
||||
action = actor(obs_dict)
|
||||
# q_value = critic_ensemble(obs_dict, actions)
|
||||
print("Time taken:", time.perf_counter() - start)
|
||||
actions = torch.randn(batch_size, action_dim, device=device)
|
||||
|
||||
# Warmup run
|
||||
print("Warming up...")
|
||||
_ = critic_ensemble(obs_dict, actions)
|
||||
|
||||
# Time the forward pass
|
||||
print(f"Running benchmark with {num_iterations} iterations...")
|
||||
start_time = time.perf_counter()
|
||||
for _ in range(num_iterations):
|
||||
q_values = critic_ensemble(obs_dict, actions)
|
||||
end_time = time.perf_counter()
|
||||
|
||||
# Print results
|
||||
elapsed_time = end_time - start_time
|
||||
print(f"Total time: {elapsed_time:.4f} seconds")
|
||||
print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
|
||||
print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size]
|
||||
|
||||
# Verify that all critic heads produce different outputs
|
||||
# This confirms each critic head is unique
|
||||
# print("\nVerifying critic outputs are different:")
|
||||
# for i in range(num_critics):
|
||||
# for j in range(i + 1, num_critics):
|
||||
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
|
||||
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
|
||||
|
||||
@@ -191,6 +191,10 @@ class TDMPCConfig:
|
||||
"If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
|
||||
)
|
||||
if not self.use_mpc:
|
||||
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
|
||||
raise ValueError(
|
||||
"If `n_action_steps > 1`, `use_mpc` must be set to `True`."
|
||||
)
|
||||
if self.n_action_steps > self.horizon:
|
||||
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
|
||||
raise ValueError(
|
||||
"`n_action_steps` must be less than or equal to `horizon`."
|
||||
)
|
||||
|
||||
@@ -68,7 +68,9 @@ class TDMPCPolicy(
|
||||
name = "tdmpc"
|
||||
|
||||
def __init__(
|
||||
self, config: TDMPCConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
||||
self,
|
||||
config: TDMPCConfig | None = None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -100,7 +102,9 @@ class TDMPCPolicy(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||
self._use_image = False
|
||||
self._use_env_state = False
|
||||
@@ -120,7 +124,9 @@ class TDMPCPolicy(
|
||||
"""
|
||||
self._queues = {
|
||||
"observation.state": deque(maxlen=1),
|
||||
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
|
||||
"action": deque(
|
||||
maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)
|
||||
),
|
||||
}
|
||||
if self._use_image:
|
||||
self._queues["observation.image"] = deque(maxlen=1)
|
||||
@@ -135,7 +141,9 @@ class TDMPCPolicy(
|
||||
"""Select a single action given environment observations."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self._use_image:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
@@ -209,13 +217,20 @@ class TDMPCPolicy(
|
||||
|
||||
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
|
||||
# trajectories.
|
||||
z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
|
||||
z = einops.repeat(
|
||||
z,
|
||||
"b d -> n b d",
|
||||
n=self.config.n_gaussian_samples + self.config.n_pi_samples,
|
||||
)
|
||||
|
||||
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
|
||||
# algorithm.
|
||||
# The initial mean and standard deviation for the cross-entropy method (CEM).
|
||||
mean = torch.zeros(
|
||||
self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device
|
||||
self.config.horizon,
|
||||
batch_size,
|
||||
self.config.output_shapes["action"][0],
|
||||
device=device,
|
||||
)
|
||||
# Maybe warm start CEM with the mean from the previous step.
|
||||
if self._prev_mean is not None:
|
||||
@@ -231,35 +246,47 @@ class TDMPCPolicy(
|
||||
self.config.output_shapes["action"][0],
|
||||
device=std.device,
|
||||
)
|
||||
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
|
||||
gaussian_actions = torch.clamp(
|
||||
mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1
|
||||
)
|
||||
|
||||
# Compute elite actions.
|
||||
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
|
||||
value = self.estimate_value(z, actions).nan_to_num_(0)
|
||||
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch)
|
||||
elite_idxs = torch.topk(
|
||||
value, self.config.n_elites, dim=0
|
||||
).indices # (n_elites, batch)
|
||||
elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch)
|
||||
# (horizon, n_elites, batch, action_dim)
|
||||
elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)
|
||||
elite_actions = actions.take_along_dim(
|
||||
einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1
|
||||
)
|
||||
|
||||
# Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
|
||||
max_value = elite_value.max(0, keepdim=True)[0] # (1, batch)
|
||||
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
|
||||
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
|
||||
# makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
|
||||
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
|
||||
score = torch.exp(
|
||||
self.config.elite_weighting_temperature * (elite_value - max_value)
|
||||
)
|
||||
score /= score.sum(axis=0, keepdim=True)
|
||||
# (horizon, batch, action_dim)
|
||||
_mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
|
||||
_mean = torch.sum(
|
||||
einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1
|
||||
)
|
||||
_std = torch.sqrt(
|
||||
torch.sum(
|
||||
einops.rearrange(score, "n b -> n b 1")
|
||||
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
|
||||
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d"))
|
||||
** 2,
|
||||
dim=1,
|
||||
)
|
||||
)
|
||||
# Update mean with an exponential moving average, and std with a direct replacement.
|
||||
mean = (
|
||||
self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
|
||||
self.config.gaussian_mean_momentum * mean
|
||||
+ (1 - self.config.gaussian_mean_momentum) * _mean
|
||||
)
|
||||
std = _std.clamp_(self.config.min_std, self.config.max_std)
|
||||
|
||||
@@ -268,7 +295,9 @@ class TDMPCPolicy(
|
||||
|
||||
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
|
||||
# scores from the last iteration.
|
||||
actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]
|
||||
actions = elite_actions[
|
||||
:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)
|
||||
]
|
||||
|
||||
return actions
|
||||
|
||||
@@ -291,7 +320,8 @@ class TDMPCPolicy(
|
||||
# of the FOWM paper.
|
||||
if self.config.uncertainty_regularizer_coeff > 0:
|
||||
regularization = -(
|
||||
self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0)
|
||||
self.config.uncertainty_regularizer_coeff
|
||||
* self.model.Qs(z, actions[t]).std(0)
|
||||
)
|
||||
else:
|
||||
regularization = 0
|
||||
@@ -311,15 +341,22 @@ class TDMPCPolicy(
|
||||
if self.config.q_ensemble_size > 2:
|
||||
G += (
|
||||
running_discount
|
||||
* torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
|
||||
0
|
||||
]
|
||||
* torch.min(
|
||||
terminal_values[
|
||||
torch.randint(0, self.config.q_ensemble_size, size=(2,))
|
||||
],
|
||||
dim=0,
|
||||
)[0]
|
||||
)
|
||||
else:
|
||||
G += running_discount * torch.min(terminal_values, dim=0)[0]
|
||||
# Finally, also regularize the terminal value.
|
||||
if self.config.uncertainty_regularizer_coeff > 0:
|
||||
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
|
||||
G -= (
|
||||
running_discount
|
||||
* self.config.uncertainty_regularizer_coeff
|
||||
* terminal_values.std(0)
|
||||
)
|
||||
return G
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
@@ -331,7 +368,9 @@ class TDMPCPolicy(
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self._use_image:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
@@ -349,7 +388,10 @@ class TDMPCPolicy(
|
||||
# Apply random image augmentations.
|
||||
if self._use_image and self.config.max_random_shift_ratio > 0:
|
||||
observations["observation.image"] = flatten_forward_unflatten(
|
||||
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
|
||||
partial(
|
||||
random_shifts_aug,
|
||||
max_random_shift_ratio=self.config.max_random_shift_ratio,
|
||||
),
|
||||
observations["observation.image"],
|
||||
)
|
||||
|
||||
@@ -367,14 +409,20 @@ class TDMPCPolicy(
|
||||
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
|
||||
# gives us a next `z`.
|
||||
batch_size = batch["index"].shape[0]
|
||||
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
|
||||
z_preds = torch.empty(
|
||||
horizon + 1, batch_size, self.config.latent_dim, device=device
|
||||
)
|
||||
z_preds[0] = self.model.encode(current_observation)
|
||||
reward_preds = torch.empty_like(reward, device=device)
|
||||
for t in range(horizon):
|
||||
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
|
||||
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(
|
||||
z_preds[t], action[t]
|
||||
)
|
||||
|
||||
# Compute Q and V value predictions based on the latent rollout.
|
||||
q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch)
|
||||
q_preds_ensemble = self.model.Qs(
|
||||
z_preds[:-1], action
|
||||
) # (ensemble, horizon, batch)
|
||||
v_preds = self.model.V(z_preds[:-1])
|
||||
info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})
|
||||
|
||||
@@ -388,10 +436,14 @@ class TDMPCPolicy(
|
||||
# actions (not actions estimated by π).
|
||||
# Note: Here we do not use self.model_target, but self.model. This is to follow the original code
|
||||
# and the FOWM paper.
|
||||
q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
|
||||
q_targets = reward + self.config.discount * self.model.V(
|
||||
self.model.encode(next_observations)
|
||||
)
|
||||
# From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
|
||||
# are using them to compute loss for V.
|
||||
v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)
|
||||
v_targets = self.model_target.Qs(
|
||||
z_preds[:-1].detach(), action, return_min=True
|
||||
)
|
||||
|
||||
# Compute losses.
|
||||
# Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
|
||||
@@ -434,7 +486,9 @@ class TDMPCPolicy(
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(
|
||||
q_preds_ensemble,
|
||||
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
|
||||
einops.repeat(
|
||||
q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]
|
||||
),
|
||||
reduction="none",
|
||||
).sum(0) # sum over ensemble
|
||||
# `q_preds_ensemble` depends on the first observation and the actions.
|
||||
@@ -472,12 +526,14 @@ class TDMPCPolicy(
|
||||
z_preds = z_preds.detach()
|
||||
# Use stopgrad for the advantage calculation.
|
||||
with torch.no_grad():
|
||||
advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V(
|
||||
z_preds[:-1]
|
||||
)
|
||||
advantage = self.model_target.Qs(
|
||||
z_preds[:-1], action, return_min=True
|
||||
) - self.model.V(z_preds[:-1])
|
||||
info["advantage"] = advantage[0]
|
||||
# (t, b)
|
||||
exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0)
|
||||
exp_advantage = torch.clamp(
|
||||
torch.exp(advantage * self.config.advantage_scaling), max=100.0
|
||||
)
|
||||
action_preds = self.model.pi(z_preds[:-1]) # (t, b, a)
|
||||
# Calculate the MSE between the actions and the action predictions.
|
||||
# Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
|
||||
@@ -532,7 +588,9 @@ class TDMPCPolicy(
|
||||
# Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
|
||||
# update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
|
||||
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
|
||||
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
|
||||
update_ema_parameters(
|
||||
self.model_target, self.model, self.config.target_model_momentum
|
||||
)
|
||||
|
||||
|
||||
class TDMPCTOLD(nn.Module):
|
||||
@@ -543,7 +601,9 @@ class TDMPCTOLD(nn.Module):
|
||||
self.config = config
|
||||
self._encoder = TDMPCObservationEncoder(config)
|
||||
self._dynamics = nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
nn.Linear(
|
||||
config.latent_dim + config.output_shapes["action"][0], config.mlp_dim
|
||||
),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -554,7 +614,9 @@ class TDMPCTOLD(nn.Module):
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
self._reward = nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
nn.Linear(
|
||||
config.latent_dim + config.output_shapes["action"][0], config.mlp_dim
|
||||
),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -574,7 +636,10 @@ class TDMPCTOLD(nn.Module):
|
||||
self._Qs = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
nn.Linear(
|
||||
config.latent_dim + config.output_shapes["action"][0],
|
||||
config.mlp_dim,
|
||||
),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Tanh(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -619,7 +684,9 @@ class TDMPCTOLD(nn.Module):
|
||||
m[-1], nn.Linear
|
||||
), "Sanity check. The last linear layer needs 0 initialization on weights."
|
||||
nn.init.zeros_(m[-1].weight)
|
||||
nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure
|
||||
nn.init.zeros_(
|
||||
m[-1].bias
|
||||
) # this has already been done, but keep this line here for good measure
|
||||
|
||||
def encode(self, obs: dict[str, Tensor]) -> Tensor:
|
||||
"""Encodes an observation into its latent representation."""
|
||||
@@ -717,14 +784,32 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
if "observation.image" in config.input_shapes:
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
|
||||
config.input_shapes["observation.image"][0],
|
||||
config.image_encoder_hidden_dim,
|
||||
7,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
|
||||
nn.Conv2d(
|
||||
config.image_encoder_hidden_dim,
|
||||
config.image_encoder_hidden_dim,
|
||||
5,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.Conv2d(
|
||||
config.image_encoder_hidden_dim,
|
||||
config.image_encoder_hidden_dim,
|
||||
3,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.Conv2d(
|
||||
config.image_encoder_hidden_dim,
|
||||
config.image_encoder_hidden_dim,
|
||||
3,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
)
|
||||
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
|
||||
@@ -740,7 +825,10 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
)
|
||||
if "observation.state" in config.input_shapes:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
|
||||
nn.Linear(
|
||||
config.input_shapes["observation.state"][0],
|
||||
config.state_encoder_hidden_dim,
|
||||
),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
@@ -749,7 +837,8 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
|
||||
config.input_shapes["observation.environment_state"][0],
|
||||
config.state_encoder_hidden_dim,
|
||||
),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
@@ -766,9 +855,15 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
feat = []
|
||||
# NOTE: Order of observations matters here.
|
||||
if "observation.image" in self.config.input_shapes:
|
||||
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"]))
|
||||
feat.append(
|
||||
flatten_forward_unflatten(
|
||||
self.image_enc_layers, obs_dict["observation.image"]
|
||||
)
|
||||
)
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
feat.append(
|
||||
self.env_state_enc_layers(obs_dict["observation.environment_state"])
|
||||
)
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
@@ -811,12 +906,17 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
|
||||
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
|
||||
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
|
||||
for (n_p_ema, p_ema), (n_p, p) in zip(
|
||||
ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
|
||||
ema_module.named_parameters(recurse=False),
|
||||
module.named_parameters(recurse=False),
|
||||
strict=True,
|
||||
):
|
||||
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
|
||||
if isinstance(p, dict):
|
||||
raise RuntimeError("Dict parameter not supported")
|
||||
if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
|
||||
if (
|
||||
isinstance(module, nn.modules.batchnorm._BatchNorm)
|
||||
or not p.requires_grad
|
||||
):
|
||||
# Copy BatchNorm parameters, and non-trainable parameters directly.
|
||||
p_ema.copy_(p.to(dtype=p_ema.dtype).data)
|
||||
with torch.no_grad():
|
||||
@@ -824,7 +924,9 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
|
||||
p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)
|
||||
|
||||
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
def flatten_forward_unflatten(
|
||||
fn: Callable[[Tensor], Tensor], image_tensor: Tensor
|
||||
) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -109,7 +109,9 @@ class VQBeTConfig:
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {"action": "min_max"}
|
||||
)
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
|
||||
@@ -79,7 +79,9 @@ class VQBeTPolicy(
|
||||
|
||||
self.vqbet = VQBeTModel(config)
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
self.expected_image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
|
||||
self.reset()
|
||||
|
||||
@@ -104,8 +106,12 @@ class VQBeTPolicy(
|
||||
"""
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.expected_image_keys], dim=-4
|
||||
)
|
||||
# Note: It's important that this happens after stacking the images into a single key.
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -116,8 +122,14 @@ class VQBeTPolicy(
|
||||
)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
|
||||
batch = {
|
||||
k: torch.stack(list(self._queues[k]), dim=1)
|
||||
for k in batch
|
||||
if k in self._queues
|
||||
}
|
||||
actions = self.vqbet(batch, rollout=True)[
|
||||
:, : self.config.action_chunk_size
|
||||
]
|
||||
|
||||
# the dimension of returned action is (batch_size, action_chunk_size, action_dim)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
@@ -130,8 +142,12 @@ class VQBeTPolicy(
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.expected_image_keys], dim=-4
|
||||
)
|
||||
batch = self.normalize_targets(batch)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
|
||||
if not self.vqbet.action_head.vqvae_model.discretized.item():
|
||||
@@ -139,7 +155,9 @@ class VQBeTPolicy(
|
||||
# n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
|
||||
# n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree).
|
||||
loss, n_different_codes, n_different_combinations, recon_l1_error = (
|
||||
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
|
||||
self.vqbet.action_head.discretize(
|
||||
self.config.n_vqvae_training_steps, batch["action"]
|
||||
)
|
||||
)
|
||||
return {
|
||||
"loss": loss,
|
||||
@@ -196,7 +214,9 @@ class SpatialSoftmax(nn.Module):
|
||||
|
||||
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
|
||||
# and causes a small degradation in pc_success of pre-trained models.
|
||||
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
|
||||
pos_x, pos_y = np.meshgrid(
|
||||
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
|
||||
)
|
||||
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
|
||||
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
|
||||
# register as buffer so it's moved to the correct device.
|
||||
@@ -288,14 +308,17 @@ class VQBeTModel(nn.Module):
|
||||
self.config = config
|
||||
|
||||
self.rgb_encoder = VQBeTRgbEncoder(config)
|
||||
self.num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
|
||||
self.num_images = len(
|
||||
[k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
)
|
||||
# This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above.
|
||||
# Note: During the forward pass, this token is repeated as many times as needed. The authors also experimented with initializing the necessary number of tokens independently and observed inferior results.
|
||||
self.action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
|
||||
|
||||
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
|
||||
self.state_projector = MLP(
|
||||
config.input_shapes["observation.state"][0], hidden_channels=[self.config.gpt_input_dim]
|
||||
config.input_shapes["observation.state"][0],
|
||||
hidden_channels=[self.config.gpt_input_dim],
|
||||
)
|
||||
self.rgb_feature_projector = MLP(
|
||||
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
|
||||
@@ -310,7 +333,12 @@ class VQBeTModel(nn.Module):
|
||||
num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
|
||||
self.register_buffer(
|
||||
"select_target_actions_indices",
|
||||
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
|
||||
torch.row_stack(
|
||||
[
|
||||
torch.arange(i, i + self.config.action_chunk_size)
|
||||
for i in range(num_tokens)
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
|
||||
@@ -325,7 +353,11 @@ class VQBeTModel(nn.Module):
|
||||
)
|
||||
# Separate batch and sequence dims.
|
||||
img_features = einops.rearrange(
|
||||
img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
|
||||
img_features,
|
||||
"(b s n) ... -> b s n ...",
|
||||
b=batch_size,
|
||||
s=n_obs_steps,
|
||||
n=self.num_images,
|
||||
)
|
||||
|
||||
# Arrange prior and current observation step tokens as shown in the class docstring.
|
||||
@@ -337,13 +369,19 @@ class VQBeTModel(nn.Module):
|
||||
input_tokens.append(
|
||||
self.state_projector(batch["observation.state"])
|
||||
) # (batch, obs_step, projection dims)
|
||||
input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
|
||||
input_tokens.append(
|
||||
einops.repeat(
|
||||
self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps
|
||||
)
|
||||
)
|
||||
# Interleave tokens by stacking and rearranging.
|
||||
input_tokens = torch.stack(input_tokens, dim=2)
|
||||
input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")
|
||||
|
||||
len_additional_action_token = self.config.n_action_pred_token - 1
|
||||
future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1)
|
||||
future_action_tokens = self.action_token.repeat(
|
||||
batch_size, len_additional_action_token, 1
|
||||
)
|
||||
|
||||
# add additional action query tokens for predicting future action chunks
|
||||
input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)
|
||||
@@ -352,9 +390,9 @@ class VQBeTModel(nn.Module):
|
||||
features = self.policy(input_tokens)
|
||||
# len(self.config.input_shapes) is the number of different observation modes.
|
||||
# this line gets the index of action prompt tokens.
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_shapes) + 1) + len(
|
||||
self.config.input_shapes
|
||||
)
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (
|
||||
len(self.config.input_shapes) + 1
|
||||
) + len(self.config.input_shapes)
|
||||
|
||||
# only extract the output tokens at the position of action query:
|
||||
# Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
|
||||
@@ -362,7 +400,11 @@ class VQBeTModel(nn.Module):
|
||||
# Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
|
||||
if len_additional_action_token > 0:
|
||||
features = torch.cat(
|
||||
[features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
|
||||
[
|
||||
features[:, historical_act_pred_index],
|
||||
features[:, -len_additional_action_token:],
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
else:
|
||||
features = features[:, historical_act_pred_index]
|
||||
@@ -370,13 +412,15 @@ class VQBeTModel(nn.Module):
|
||||
action_head_output = self.action_head(features)
|
||||
# if rollout, VQ-BeT don't calculate loss
|
||||
if rollout:
|
||||
return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape(
|
||||
batch_size, self.config.action_chunk_size, -1
|
||||
)
|
||||
return action_head_output["predicted_action"][
|
||||
:, n_obs_steps - 1, :
|
||||
].reshape(batch_size, self.config.action_chunk_size, -1)
|
||||
# else, it calculate overall loss (bin prediction loss, and offset loss)
|
||||
else:
|
||||
output = batch["action"][:, self.select_target_actions_indices]
|
||||
loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
|
||||
loss = self.action_head.loss_fn(
|
||||
action_head_output, output, reduction="mean"
|
||||
)
|
||||
return action_head_output, loss
|
||||
|
||||
|
||||
@@ -411,7 +455,9 @@ class VQBeTHead(nn.Module):
|
||||
else:
|
||||
self.map_to_cbet_preds_bin = MLP(
|
||||
in_channels=config.gpt_output_dim,
|
||||
hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed],
|
||||
hidden_channels=[
|
||||
self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed
|
||||
],
|
||||
)
|
||||
self.map_to_cbet_preds_offset = MLP(
|
||||
in_channels=config.gpt_output_dim,
|
||||
@@ -438,7 +484,10 @@ class VQBeTHead(nn.Module):
|
||||
|
||||
loss, metric = self.vqvae_model.vqvae_forward(actions)
|
||||
n_different_codes = sum(
|
||||
[len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)]
|
||||
[
|
||||
len(torch.unique(metric[2][:, i]))
|
||||
for i in range(self.vqvae_model.vqvae_num_layers)
|
||||
]
|
||||
)
|
||||
n_different_combinations = len(torch.unique(metric[2], dim=0))
|
||||
recon_l1_error = metric[0].detach().cpu().item()
|
||||
@@ -485,7 +534,13 @@ class VQBeTHead(nn.Module):
|
||||
|
||||
cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
|
||||
torch.cat(
|
||||
(x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)),
|
||||
(
|
||||
x,
|
||||
F.one_hot(
|
||||
sampled_primary_centers,
|
||||
num_classes=self.config.vqvae_n_embed,
|
||||
),
|
||||
),
|
||||
axis=1,
|
||||
)
|
||||
)
|
||||
@@ -493,19 +548,29 @@ class VQBeTHead(nn.Module):
|
||||
cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1
|
||||
)
|
||||
sampled_secondary_centers = einops.rearrange(
|
||||
torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1),
|
||||
torch.multinomial(
|
||||
cbet_secondary_probs.view(-1, choices), num_samples=1
|
||||
),
|
||||
"(NT) 1 -> NT",
|
||||
NT=NT,
|
||||
)
|
||||
sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1)
|
||||
cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1)
|
||||
sampled_centers = torch.stack(
|
||||
(sampled_primary_centers, sampled_secondary_centers), axis=1
|
||||
)
|
||||
cbet_logits = torch.stack(
|
||||
[cbet_primary_logits, cbet_secondary_logits], dim=1
|
||||
)
|
||||
# if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once.
|
||||
else:
|
||||
cbet_logits = self.map_to_cbet_preds_bin(x)
|
||||
cbet_logits = einops.rearrange(
|
||||
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
|
||||
cbet_logits,
|
||||
"(NT) (G C) -> (NT) G C",
|
||||
G=self.vqvae_model.vqvae_num_layers,
|
||||
)
|
||||
cbet_probs = torch.softmax(
|
||||
cbet_logits / self.config.bet_softmax_temperature, dim=-1
|
||||
)
|
||||
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
|
||||
NT, G, choices = cbet_probs.shape
|
||||
sampled_centers = einops.rearrange(
|
||||
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
|
||||
@@ -525,9 +590,17 @@ class VQBeTHead(nn.Module):
|
||||
sampled_offsets = sampled_offsets.sum(dim=1)
|
||||
with torch.no_grad():
|
||||
# Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder
|
||||
return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach()
|
||||
return_decoder_input = (
|
||||
self.vqvae_model.get_embeddings_from_code(sampled_centers)
|
||||
.clone()
|
||||
.detach()
|
||||
)
|
||||
# pass the centroids through decoder to get actions.
|
||||
decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach()
|
||||
decoded_action = (
|
||||
self.vqvae_model.get_action_from_latent(return_decoder_input)
|
||||
.clone()
|
||||
.detach()
|
||||
)
|
||||
# reshaped extracted offset to match with decoded centroids
|
||||
sampled_offsets = einops.rearrange(
|
||||
sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
|
||||
@@ -576,7 +649,9 @@ class VQBeTHead(nn.Module):
|
||||
# Figure out the loss for the actions.
|
||||
# First, we need to find the closest cluster center for each ground truth action.
|
||||
with torch.no_grad():
|
||||
state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
|
||||
state_vq, action_bins = self.vqvae_model.get_code(
|
||||
action_seq
|
||||
) # action_bins: NT, G
|
||||
|
||||
# Now we can compute the loss.
|
||||
|
||||
@@ -599,8 +674,12 @@ class VQBeTHead(nn.Module):
|
||||
+ cbet_loss2 * self.config.secondary_code_loss_weight
|
||||
)
|
||||
|
||||
equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT)
|
||||
equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT)
|
||||
equal_primary_code_rate = torch.sum(
|
||||
(action_bins[:, 0] == sampled_centers[:, 0]).int()
|
||||
) / (NT)
|
||||
equal_secondary_code_rate = torch.sum(
|
||||
(action_bins[:, 1] == sampled_centers[:, 1]).int()
|
||||
) / (NT)
|
||||
|
||||
action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
|
||||
vq_action_error = torch.mean(torch.abs(action_seq - decoded_action))
|
||||
@@ -614,7 +693,9 @@ class VQBeTHead(nn.Module):
|
||||
"classification_loss": cbet_loss.detach().cpu().item(),
|
||||
"offset_loss": offset_loss.detach().cpu().item(),
|
||||
"equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
|
||||
"equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(),
|
||||
"equal_secondary_code_rate": equal_secondary_code_rate.detach()
|
||||
.cpu()
|
||||
.item(),
|
||||
"vq_action_error": vq_action_error.detach().cpu().item(),
|
||||
"offset_action_error": offset_action_error.detach().cpu().item(),
|
||||
"action_error_max": action_error_max.detach().cpu().item(),
|
||||
@@ -643,11 +724,17 @@ class VQBeTOptimizer(torch.optim.Adam):
|
||||
if cfg.policy.sequentially_select:
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
|
||||
+ list(
|
||||
policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters()
|
||||
)
|
||||
+ list(
|
||||
policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters()
|
||||
)
|
||||
)
|
||||
else:
|
||||
decay_params = decay_params + list(policy.vqbet.action_head.map_to_cbet_preds_bin.parameters())
|
||||
decay_params = decay_params + list(
|
||||
policy.vqbet.action_head.map_to_cbet_preds_bin.parameters()
|
||||
)
|
||||
|
||||
optim_groups = [
|
||||
{
|
||||
@@ -693,7 +780,11 @@ class VQBeTScheduler(nn.Module):
|
||||
progress = float(current_step - num_warmup_steps) / float(
|
||||
max(1, num_training_steps - num_warmup_steps)
|
||||
)
|
||||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
||||
return max(
|
||||
0.0,
|
||||
0.5
|
||||
* (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
|
||||
)
|
||||
|
||||
self.lr_scheduler = LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
@@ -717,7 +808,9 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
# Always use center crop for eval
|
||||
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
|
||||
if config.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(
|
||||
config.crop_shape
|
||||
)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
@@ -738,7 +831,9 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
self.backbone = _replace_submodules(
|
||||
root_module=self.backbone,
|
||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||
func=lambda x: nn.GroupNorm(
|
||||
num_groups=x.num_features // 16, num_channels=x.num_features
|
||||
),
|
||||
)
|
||||
|
||||
# Set up pooling and final layers.
|
||||
@@ -746,17 +841,25 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
# The dummy input should take the number of image channels from `config.input_shapes` and it should
|
||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||
# height and width from `config.input_shapes`.
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
assert len(image_keys) == 1
|
||||
image_key = image_keys[0]
|
||||
dummy_input_h_w = (
|
||||
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
|
||||
config.crop_shape
|
||||
if config.crop_shape is not None
|
||||
else config.input_shapes[image_key][1:]
|
||||
)
|
||||
dummy_input = torch.zeros(
|
||||
size=(1, config.input_shapes[image_key][0], *dummy_input_h_w)
|
||||
)
|
||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
|
||||
with torch.inference_mode():
|
||||
dummy_feature_map = self.backbone(dummy_input)
|
||||
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.pool = SpatialSoftmax(
|
||||
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
|
||||
)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
@@ -783,7 +886,9 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
|
||||
|
||||
def _replace_submodules(
|
||||
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||
root_module: nn.Module,
|
||||
predicate: Callable[[nn.Module], bool],
|
||||
func: Callable[[nn.Module], nn.Module],
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Args:
|
||||
@@ -796,7 +901,11 @@ def _replace_submodules(
|
||||
if predicate(root_module):
|
||||
return func(root_module)
|
||||
|
||||
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
||||
replace_list = [
|
||||
k.split(".")
|
||||
for k, m in root_module.named_modules(remove_duplicate=True)
|
||||
if predicate(m)
|
||||
]
|
||||
for *parents, k in replace_list:
|
||||
parent_module = root_module
|
||||
if len(parents) > 0:
|
||||
@@ -811,7 +920,9 @@ def _replace_submodules(
|
||||
else:
|
||||
setattr(parent_module, k, tgt_module)
|
||||
# verify that all BN are replaced
|
||||
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
|
||||
assert not any(
|
||||
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
|
||||
)
|
||||
return root_module
|
||||
|
||||
|
||||
@@ -844,7 +955,8 @@ class VqVae(nn.Module):
|
||||
)
|
||||
|
||||
self.encoder = MLP(
|
||||
in_channels=self.config.output_shapes["action"][0] * self.config.action_chunk_size,
|
||||
in_channels=self.config.output_shapes["action"][0]
|
||||
* self.config.action_chunk_size,
|
||||
hidden_channels=[
|
||||
config.vqvae_enc_hidden_dim,
|
||||
config.vqvae_enc_hidden_dim,
|
||||
@@ -872,9 +984,13 @@ class VqVae(nn.Module):
|
||||
# given latent vector, this function outputs the decoded action.
|
||||
output = self.decoder(latent)
|
||||
if self.config.action_chunk_size == 1:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
|
||||
return einops.rearrange(
|
||||
output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0]
|
||||
)
|
||||
else:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
|
||||
return einops.rearrange(
|
||||
output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0]
|
||||
)
|
||||
|
||||
def get_code(self, state):
|
||||
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)
|
||||
|
||||
@@ -123,9 +123,15 @@ class CausalSelfAttention(nn.Module):
|
||||
|
||||
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
||||
q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2)
|
||||
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
|
||||
1, 2
|
||||
) # (B, nh, T, hs)
|
||||
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
|
||||
1, 2
|
||||
) # (B, nh, T, hs)
|
||||
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
|
||||
1, 2
|
||||
) # (B, nh, T, hs)
|
||||
|
||||
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
||||
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
||||
@@ -133,7 +139,9 @@ class CausalSelfAttention(nn.Module):
|
||||
att = F.softmax(att, dim=-1)
|
||||
att = self.attn_dropout(att)
|
||||
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
||||
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
||||
y = (
|
||||
y.transpose(1, 2).contiguous().view(B, T, C)
|
||||
) # re-assemble all head outputs side by side
|
||||
|
||||
# output projection
|
||||
y = self.resid_dropout(self.c_proj(y))
|
||||
@@ -189,12 +197,16 @@ class GPT(nn.Module):
|
||||
"ln_f": nn.LayerNorm(config.gpt_hidden_dim),
|
||||
}
|
||||
)
|
||||
self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False)
|
||||
self.lm_head = nn.Linear(
|
||||
config.gpt_hidden_dim, config.gpt_output_dim, bias=False
|
||||
)
|
||||
# init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
|
||||
self.apply(self._init_weights)
|
||||
for pn, p in self.named_parameters():
|
||||
if pn.endswith("c_proj.weight"):
|
||||
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer))
|
||||
torch.nn.init.normal_(
|
||||
p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)
|
||||
)
|
||||
|
||||
# report number of parameters
|
||||
n_params = sum(p.numel() for p in self.parameters())
|
||||
@@ -208,11 +220,17 @@ class GPT(nn.Module):
|
||||
), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
|
||||
|
||||
# positional encodings that are added to the input embeddings
|
||||
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
|
||||
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(
|
||||
0
|
||||
) # shape (1, t)
|
||||
|
||||
# forward the GPT model itself
|
||||
tok_emb = self.transformer.wte(input) # token embeddings of shape (b, t, gpt_hidden_dim)
|
||||
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim)
|
||||
tok_emb = self.transformer.wte(
|
||||
input
|
||||
) # token embeddings of shape (b, t, gpt_hidden_dim)
|
||||
pos_emb = self.transformer.wpe(
|
||||
pos
|
||||
) # position embeddings of shape (1, t, gpt_hidden_dim)
|
||||
x = self.transformer.drop(tok_emb + pos_emb)
|
||||
for block in self.transformer.h:
|
||||
x = block(x)
|
||||
@@ -237,7 +255,9 @@ class GPT(nn.Module):
|
||||
# but want to use a smaller block size for some smaller, simpler model
|
||||
assert gpt_block_size <= self.config.gpt_block_size
|
||||
self.config.gpt_block_size = gpt_block_size
|
||||
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size])
|
||||
self.transformer.wpe.weight = nn.Parameter(
|
||||
self.transformer.wpe.weight[:gpt_block_size]
|
||||
)
|
||||
for block in self.transformer.h:
|
||||
block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]
|
||||
|
||||
@@ -270,7 +290,9 @@ class GPT(nn.Module):
|
||||
param_dict = dict(self.named_parameters())
|
||||
inter_params = decay & no_decay
|
||||
union_params = decay | no_decay
|
||||
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
|
||||
assert (
|
||||
len(inter_params) == 0
|
||||
), "parameters {} made it into both decay/no_decay sets!".format(
|
||||
str(inter_params)
|
||||
)
|
||||
assert (
|
||||
@@ -368,8 +390,12 @@ class ResidualVQ(nn.Module):
|
||||
codebook_input_dim = codebook_dim * heads
|
||||
|
||||
requires_projection = codebook_input_dim != dim
|
||||
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
|
||||
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
|
||||
self.project_in = (
|
||||
nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
self.project_out = (
|
||||
nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
|
||||
self.num_quantizers = num_quantizers
|
||||
|
||||
@@ -377,7 +403,10 @@ class ResidualVQ(nn.Module):
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
VectorQuantize(
|
||||
dim=codebook_dim, codebook_dim=codebook_dim, accept_image_fmap=accept_image_fmap, **kwargs
|
||||
dim=codebook_dim,
|
||||
codebook_dim=codebook_dim,
|
||||
accept_image_fmap=accept_image_fmap,
|
||||
**kwargs,
|
||||
)
|
||||
for _ in range(num_quantizers)
|
||||
]
|
||||
@@ -448,7 +477,9 @@ class ResidualVQ(nn.Module):
|
||||
|
||||
return all_codes
|
||||
|
||||
def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=None):
|
||||
def forward(
|
||||
self, x, indices=None, return_all_codes=False, sample_codebook_temp=None
|
||||
):
|
||||
"""
|
||||
For given input tensor x, this function will return the quantized output, the indices of the quantized output, and the loss.
|
||||
First, the input tensor x is projected to the codebook dimension. Then, the input tensor x is passed through Nq layers of VectorQuantize.
|
||||
@@ -477,13 +508,17 @@ class ResidualVQ(nn.Module):
|
||||
), "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
|
||||
ce_losses = []
|
||||
|
||||
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
|
||||
should_quantize_dropout = (
|
||||
self.training and self.quantize_dropout and not return_loss
|
||||
)
|
||||
|
||||
# sample a layer index at which to dropout further residual quantization
|
||||
# also prepare null indices and loss
|
||||
|
||||
if should_quantize_dropout:
|
||||
rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)
|
||||
rand_quantize_dropout_index = randrange(
|
||||
self.quantize_dropout_cutoff_index, num_quant
|
||||
)
|
||||
|
||||
if quant_dropout_multiple_of != 1:
|
||||
rand_quantize_dropout_index = (
|
||||
@@ -492,14 +527,23 @@ class ResidualVQ(nn.Module):
|
||||
- 1
|
||||
)
|
||||
|
||||
null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
|
||||
null_indices = torch.full(null_indices_shape, -1.0, device=device, dtype=torch.long)
|
||||
null_indices_shape = (
|
||||
(x.shape[0], *x.shape[-2:])
|
||||
if self.accept_image_fmap
|
||||
else tuple(x.shape[:2])
|
||||
)
|
||||
null_indices = torch.full(
|
||||
null_indices_shape, -1.0, device=device, dtype=torch.long
|
||||
)
|
||||
null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype)
|
||||
|
||||
# go through the layers
|
||||
|
||||
for quantizer_index, layer in enumerate(self.layers):
|
||||
if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
|
||||
if (
|
||||
should_quantize_dropout
|
||||
and quantizer_index > rand_quantize_dropout_index
|
||||
):
|
||||
all_indices.append(null_indices)
|
||||
all_losses.append(null_loss)
|
||||
continue
|
||||
@@ -539,7 +583,9 @@ class ResidualVQ(nn.Module):
|
||||
|
||||
# stack all losses and indices
|
||||
|
||||
all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices))
|
||||
all_losses, all_indices = map(
|
||||
partial(torch.stack, dim=-1), (all_losses, all_indices)
|
||||
)
|
||||
|
||||
ret = (quantized_out, all_indices, all_losses)
|
||||
|
||||
@@ -599,8 +645,12 @@ class VectorQuantize(nn.Module):
|
||||
codebook_input_dim = codebook_dim * heads
|
||||
|
||||
requires_projection = codebook_input_dim != dim
|
||||
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
|
||||
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
|
||||
self.project_in = (
|
||||
nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
self.project_out = (
|
||||
nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
|
||||
self.eps = eps
|
||||
self.commitment_weight = commitment_weight
|
||||
@@ -614,10 +664,14 @@ class VectorQuantize(nn.Module):
|
||||
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
|
||||
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
|
||||
|
||||
assert not (ema_update and learnable_codebook), "learnable codebook not compatible with EMA update"
|
||||
assert not (
|
||||
ema_update and learnable_codebook
|
||||
), "learnable codebook not compatible with EMA update"
|
||||
|
||||
assert 0 <= sync_update_v <= 1.0
|
||||
assert not (sync_update_v > 0.0 and not learnable_codebook), "learnable codebook must be turned on"
|
||||
assert not (
|
||||
sync_update_v > 0.0 and not learnable_codebook
|
||||
), "learnable codebook must be turned on"
|
||||
|
||||
self.sync_update_v = sync_update_v
|
||||
|
||||
@@ -629,7 +683,9 @@ class VectorQuantize(nn.Module):
|
||||
)
|
||||
|
||||
if sync_codebook is None:
|
||||
sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1
|
||||
sync_codebook = (
|
||||
distributed.is_initialized() and distributed.get_world_size() > 1
|
||||
)
|
||||
|
||||
codebook_kwargs = {
|
||||
"dim": codebook_dim,
|
||||
@@ -794,11 +850,17 @@ class VectorQuantize(nn.Module):
|
||||
|
||||
# quantize again
|
||||
|
||||
quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
|
||||
quantize, embed_ind, distances = self._codebook(
|
||||
x, **codebook_forward_kwargs
|
||||
)
|
||||
|
||||
if self.training:
|
||||
# determine code to use for commitment loss
|
||||
maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity
|
||||
maybe_detach = (
|
||||
torch.detach
|
||||
if not self.learnable_codebook or freeze_codebook
|
||||
else identity
|
||||
)
|
||||
|
||||
commit_quantize = maybe_detach(quantize)
|
||||
|
||||
@@ -808,7 +870,9 @@ class VectorQuantize(nn.Module):
|
||||
|
||||
if self.sync_update_v > 0.0:
|
||||
# (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
|
||||
quantize = quantize + self.sync_update_v * (quantize - quantize.detach())
|
||||
quantize = quantize + self.sync_update_v * (
|
||||
quantize - quantize.detach()
|
||||
)
|
||||
|
||||
# function for calculating cross entropy loss to distance matrix
|
||||
# used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss
|
||||
@@ -841,7 +905,9 @@ class VectorQuantize(nn.Module):
|
||||
embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads)
|
||||
|
||||
if self.accept_image_fmap:
|
||||
embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width)
|
||||
embed_ind = rearrange(
|
||||
embed_ind, "b (h w) ... -> b h w ...", h=height, w=width
|
||||
)
|
||||
|
||||
if only_one:
|
||||
embed_ind = rearrange(embed_ind, "b 1 -> b")
|
||||
@@ -895,8 +961,12 @@ class VectorQuantize(nn.Module):
|
||||
|
||||
num_codes = codebook.shape[-2]
|
||||
|
||||
if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes:
|
||||
rand_ids = torch.randperm(num_codes, device=device)[: self.orthogonal_reg_max_codes]
|
||||
if (
|
||||
self.orthogonal_reg_max_codes is not None
|
||||
) and num_codes > self.orthogonal_reg_max_codes:
|
||||
rand_ids = torch.randperm(num_codes, device=device)[
|
||||
: self.orthogonal_reg_max_codes
|
||||
]
|
||||
codebook = codebook[:, rand_ids]
|
||||
|
||||
orthogonal_reg_loss = orthogonal_loss_fn(codebook)
|
||||
@@ -928,7 +998,9 @@ class VectorQuantize(nn.Module):
|
||||
# if masking, only return quantized for where mask has True
|
||||
|
||||
if mask is not None:
|
||||
quantize = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input)
|
||||
quantize = torch.where(
|
||||
rearrange(mask, "... -> ... 1"), quantize, orig_input
|
||||
)
|
||||
|
||||
return quantize, embed_ind, loss
|
||||
|
||||
@@ -1038,7 +1110,9 @@ def sample_vectors(samples, num):
|
||||
|
||||
|
||||
def batched_sample_vectors(samples, num):
|
||||
return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0)
|
||||
return torch.stack(
|
||||
[sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0
|
||||
)
|
||||
|
||||
|
||||
def pad_shape(shape, size, dim=0):
|
||||
@@ -1089,7 +1163,9 @@ def sample_vectors_distributed(local_samples, num):
|
||||
all_num_samples = all_gather_sizes(local_samples, dim=0)
|
||||
|
||||
if rank == 0:
|
||||
samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum())
|
||||
samples_per_rank = sample_multinomial(
|
||||
num, all_num_samples / all_num_samples.sum()
|
||||
)
|
||||
else:
|
||||
samples_per_rank = torch.empty_like(all_num_samples)
|
||||
|
||||
@@ -1202,7 +1278,9 @@ class EuclideanCodebook(nn.Module):
|
||||
self.eps = eps
|
||||
self.threshold_ema_dead_code = threshold_ema_dead_code
|
||||
self.reset_cluster_size = (
|
||||
reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code
|
||||
reset_cluster_size
|
||||
if (reset_cluster_size is not None)
|
||||
else threshold_ema_dead_code
|
||||
)
|
||||
|
||||
assert callable(gumbel_sample)
|
||||
@@ -1213,8 +1291,14 @@ class EuclideanCodebook(nn.Module):
|
||||
use_ddp and num_codebooks > 1 and kmeans_init
|
||||
), "kmeans init is not compatible with multiple codebooks in distributed environment for now"
|
||||
|
||||
self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
|
||||
self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
|
||||
self.sample_fn = (
|
||||
sample_vectors_distributed
|
||||
if use_ddp and sync_kmeans
|
||||
else batched_sample_vectors
|
||||
)
|
||||
self.kmeans_all_reduce_fn = (
|
||||
distributed.all_reduce if use_ddp and sync_kmeans else noop
|
||||
)
|
||||
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
|
||||
|
||||
self.register_buffer("initted", torch.Tensor([not kmeans_init]))
|
||||
@@ -1353,7 +1437,9 @@ class EuclideanCodebook(nn.Module):
|
||||
distributed.all_reduce(variance_numer)
|
||||
batch_variance = variance_numer / num_vectors
|
||||
|
||||
self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)
|
||||
self.update_with_decay(
|
||||
"batch_variance", batch_variance, self.affine_param_batch_decay
|
||||
)
|
||||
|
||||
def replace(self, batch_samples, batch_mask):
|
||||
for ind, (samples, mask) in enumerate(
|
||||
@@ -1362,7 +1448,9 @@ class EuclideanCodebook(nn.Module):
|
||||
if not torch.any(mask):
|
||||
continue
|
||||
|
||||
sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item())
|
||||
sampled = self.sample_fn(
|
||||
rearrange(samples, "... -> 1 ..."), mask.sum().item()
|
||||
)
|
||||
sampled = rearrange(sampled, "1 ... -> ...")
|
||||
|
||||
self.embed.data[ind][mask] = sampled
|
||||
@@ -1386,7 +1474,9 @@ class EuclideanCodebook(nn.Module):
|
||||
def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
|
||||
needs_codebook_dim = x.ndim < 4
|
||||
sample_codebook_temp = (
|
||||
sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp
|
||||
sample_codebook_temp
|
||||
if (sample_codebook_temp is not None)
|
||||
else self.sample_codebook_temp
|
||||
)
|
||||
|
||||
x = x.float()
|
||||
@@ -1414,7 +1504,9 @@ class EuclideanCodebook(nn.Module):
|
||||
if self.affine_param:
|
||||
codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt()
|
||||
batch_std = self.batch_variance.clamp(min=1e-5).sqrt()
|
||||
embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean
|
||||
embed = (embed - self.codebook_mean) * (
|
||||
batch_std / codebook_std
|
||||
) + self.batch_mean
|
||||
|
||||
dist = -cdist(flatten, embed)
|
||||
|
||||
@@ -1432,7 +1524,9 @@ class EuclideanCodebook(nn.Module):
|
||||
|
||||
if self.training and self.ema_update and not freeze_codebook:
|
||||
if self.affine_param:
|
||||
flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
|
||||
flatten = (flatten - self.batch_mean) * (
|
||||
codebook_std / batch_std
|
||||
) + self.codebook_mean
|
||||
|
||||
if mask is not None:
|
||||
embed_onehot[~mask] = 0.0
|
||||
@@ -1455,7 +1549,9 @@ class EuclideanCodebook(nn.Module):
|
||||
self.expire_codes_(x)
|
||||
|
||||
if needs_codebook_dim:
|
||||
quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind))
|
||||
quantize, embed_ind = tuple(
|
||||
rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind)
|
||||
)
|
||||
|
||||
dist = unpack_one(dist, ps, "h * d")
|
||||
|
||||
|
||||
@@ -65,7 +65,9 @@ def save_image(img_array, serial_number, frame_index, images_dir):
|
||||
img.save(str(path), quality=100)
|
||||
logging.info(f"Saved image: {path}")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}")
|
||||
logging.error(
|
||||
f"Failed to save image for camera {serial_number} frame {frame_index}: {e}"
|
||||
)
|
||||
|
||||
|
||||
def save_images_from_cameras(
|
||||
@@ -94,7 +96,9 @@ def save_images_from_cameras(
|
||||
cameras = []
|
||||
for cam_sn in serial_numbers:
|
||||
print(f"{cam_sn=}")
|
||||
camera = IntelRealSenseCamera(cam_sn, fps=fps, width=width, height=height, mock=mock)
|
||||
camera = IntelRealSenseCamera(
|
||||
cam_sn, fps=fps, width=width, height=height, mock=mock
|
||||
)
|
||||
camera.connect()
|
||||
print(
|
||||
f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
|
||||
@@ -140,7 +144,9 @@ def save_images_from_cameras(
|
||||
if time.perf_counter() - start_time > record_time_s:
|
||||
break
|
||||
|
||||
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
|
||||
print(
|
||||
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
|
||||
)
|
||||
|
||||
frame_index += 1
|
||||
finally:
|
||||
@@ -182,8 +188,12 @@ class IntelRealSenseCameraConfig:
|
||||
|
||||
self.channels = 3
|
||||
|
||||
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
|
||||
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
|
||||
at_least_one_is_not_none = (
|
||||
self.fps is not None or self.width is not None or self.height is not None
|
||||
)
|
||||
at_least_one_is_none = (
|
||||
self.fps is None or self.width is None or self.height is None
|
||||
)
|
||||
if at_least_one_is_not_none and at_least_one_is_none:
|
||||
raise ValueError(
|
||||
"For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
|
||||
@@ -191,7 +201,9 @@ class IntelRealSenseCameraConfig:
|
||||
)
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
||||
raise ValueError(
|
||||
f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
|
||||
)
|
||||
|
||||
|
||||
class IntelRealSenseCamera:
|
||||
@@ -286,7 +298,9 @@ class IntelRealSenseCamera:
|
||||
self.rotation = cv2.ROTATE_180
|
||||
|
||||
@classmethod
|
||||
def init_from_name(cls, name: str, config: IntelRealSenseCameraConfig | None = None, **kwargs):
|
||||
def init_from_name(
|
||||
cls, name: str, config: IntelRealSenseCameraConfig | None = None, **kwargs
|
||||
):
|
||||
camera_infos = find_cameras()
|
||||
camera_names = [cam["name"] for cam in camera_infos]
|
||||
this_name_count = Counter(camera_names)[name]
|
||||
@@ -296,7 +310,9 @@ class IntelRealSenseCamera:
|
||||
f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them."
|
||||
)
|
||||
|
||||
name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos}
|
||||
name_to_serial_dict = {
|
||||
cam["name"]: cam["serial_number"] for cam in camera_infos
|
||||
}
|
||||
cam_sn = name_to_serial_dict[name]
|
||||
|
||||
if config is None:
|
||||
@@ -323,13 +339,17 @@ class IntelRealSenseCamera:
|
||||
|
||||
if self.fps and self.width and self.height:
|
||||
# TODO(rcadene): can we set rgb8 directly?
|
||||
config.enable_stream(rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps)
|
||||
config.enable_stream(
|
||||
rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps
|
||||
)
|
||||
else:
|
||||
config.enable_stream(rs.stream.color)
|
||||
|
||||
if self.use_depth:
|
||||
if self.fps and self.width and self.height:
|
||||
config.enable_stream(rs.stream.depth, self.width, self.height, rs.format.z16, self.fps)
|
||||
config.enable_stream(
|
||||
rs.stream.depth, self.width, self.height, rs.format.z16, self.fps
|
||||
)
|
||||
else:
|
||||
config.enable_stream(rs.stream.depth)
|
||||
|
||||
@@ -362,7 +382,9 @@ class IntelRealSenseCamera:
|
||||
actual_height = color_profile.height()
|
||||
|
||||
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
|
||||
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
|
||||
if self.fps is not None and not math.isclose(
|
||||
self.fps, actual_fps, rel_tol=1e-3
|
||||
):
|
||||
# Using `OSError` since it's a broad that encompasses issues related to device communication
|
||||
raise OSError(
|
||||
f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
|
||||
@@ -382,7 +404,9 @@ class IntelRealSenseCamera:
|
||||
|
||||
self.is_connected = True
|
||||
|
||||
def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
|
||||
def read(
|
||||
self, temporary_color: str | None = None
|
||||
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
|
||||
"""Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3)
|
||||
of type `np.uint8`, contrarily to the pytorch format which is float channel first.
|
||||
|
||||
@@ -409,11 +433,15 @@ class IntelRealSenseCamera:
|
||||
color_frame = frame.get_color_frame()
|
||||
|
||||
if not color_frame:
|
||||
raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).")
|
||||
raise OSError(
|
||||
f"Can't capture color image from IntelRealSenseCamera({self.serial_number})."
|
||||
)
|
||||
|
||||
color_image = np.asanyarray(color_frame.get_data())
|
||||
|
||||
requested_color_mode = self.color_mode if temporary_color is None else temporary_color
|
||||
requested_color_mode = (
|
||||
self.color_mode if temporary_color is None else temporary_color
|
||||
)
|
||||
if requested_color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided."
|
||||
@@ -441,7 +469,9 @@ class IntelRealSenseCamera:
|
||||
if self.use_depth:
|
||||
depth_frame = frame.get_depth_frame()
|
||||
if not depth_frame:
|
||||
raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).")
|
||||
raise OSError(
|
||||
f"Can't capture depth image from IntelRealSenseCamera({self.serial_number})."
|
||||
)
|
||||
|
||||
depth_map = np.asanyarray(depth_frame.get_data())
|
||||
|
||||
@@ -483,7 +513,9 @@ class IntelRealSenseCamera:
|
||||
# TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here
|
||||
num_tries += 1
|
||||
time.sleep(1 / self.fps)
|
||||
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
|
||||
if num_tries > self.fps and (
|
||||
self.thread.ident is None or not self.thread.is_alive()
|
||||
):
|
||||
raise Exception(
|
||||
"The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
|
||||
)
|
||||
|
||||
@@ -31,10 +31,14 @@ from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
MAX_OPENCV_INDEX = 60
|
||||
|
||||
|
||||
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
|
||||
def find_cameras(
|
||||
raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False
|
||||
) -> list[dict]:
|
||||
cameras = []
|
||||
if platform.system() == "Linux":
|
||||
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
|
||||
print(
|
||||
"Linux detected. Finding available camera indices through scanning '/dev/video*' ports"
|
||||
)
|
||||
possible_ports = [str(port) for port in Path("/dev").glob("video*")]
|
||||
ports = _find_cameras(possible_ports, mock=mock)
|
||||
for port in ports:
|
||||
@@ -165,7 +169,9 @@ def save_images_from_cameras(
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
|
||||
print(
|
||||
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
|
||||
)
|
||||
|
||||
if time.perf_counter() - start_time > record_time_s:
|
||||
break
|
||||
@@ -205,7 +211,9 @@ class OpenCVCameraConfig:
|
||||
self.channels = 3
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
||||
raise ValueError(
|
||||
f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
|
||||
)
|
||||
|
||||
|
||||
class OpenCVCamera:
|
||||
@@ -247,7 +255,12 @@ class OpenCVCamera:
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, camera_index: int | str, config: OpenCVCameraConfig | None = None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
camera_index: int | str,
|
||||
config: OpenCVCameraConfig | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if config is None:
|
||||
config = OpenCVCameraConfig()
|
||||
|
||||
@@ -261,12 +274,16 @@ class OpenCVCamera:
|
||||
if platform.system() == "Linux":
|
||||
if isinstance(self.camera_index, int):
|
||||
self.port = Path(f"/dev/video{self.camera_index}")
|
||||
elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index):
|
||||
elif isinstance(self.camera_index, str) and is_valid_unix_path(
|
||||
self.camera_index
|
||||
):
|
||||
self.port = Path(self.camera_index)
|
||||
# Retrieve the camera index from a potentially symlinked path
|
||||
self.camera_index = get_camera_index_from_unix_port(self.port)
|
||||
else:
|
||||
raise ValueError(f"Please check the provided camera_index: {camera_index}")
|
||||
raise ValueError(
|
||||
f"Please check the provided camera_index: {camera_index}"
|
||||
)
|
||||
|
||||
self.fps = config.fps
|
||||
self.width = config.width
|
||||
@@ -298,7 +315,9 @@ class OpenCVCamera:
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
|
||||
raise RobotDeviceAlreadyConnectedError(
|
||||
f"OpenCVCamera({self.camera_index}) is already connected."
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_cv2 as cv2
|
||||
@@ -309,7 +328,11 @@ class OpenCVCamera:
|
||||
# when other threads are used to save the images.
|
||||
cv2.setNumThreads(1)
|
||||
|
||||
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
|
||||
camera_idx = (
|
||||
f"/dev/video{self.camera_index}"
|
||||
if platform.system() == "Linux"
|
||||
else self.camera_index
|
||||
)
|
||||
# First create a temporary camera trying to access `camera_index`,
|
||||
# and verify it is a valid camera by calling `isOpened`.
|
||||
tmp_camera = cv2.VideoCapture(camera_idx)
|
||||
@@ -349,16 +372,22 @@ class OpenCVCamera:
|
||||
actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
|
||||
|
||||
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
|
||||
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
|
||||
if self.fps is not None and not math.isclose(
|
||||
self.fps, actual_fps, rel_tol=1e-3
|
||||
):
|
||||
# Using `OSError` since it's a broad that encompasses issues related to device communication
|
||||
raise OSError(
|
||||
f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
|
||||
)
|
||||
if self.width is not None and not math.isclose(self.width, actual_width, rel_tol=1e-3):
|
||||
if self.width is not None and not math.isclose(
|
||||
self.width, actual_width, rel_tol=1e-3
|
||||
):
|
||||
raise OSError(
|
||||
f"Can't set {self.width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
|
||||
)
|
||||
if self.height is not None and not math.isclose(self.height, actual_height, rel_tol=1e-3):
|
||||
if self.height is not None and not math.isclose(
|
||||
self.height, actual_height, rel_tol=1e-3
|
||||
):
|
||||
raise OSError(
|
||||
f"Can't set {self.height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
|
||||
)
|
||||
@@ -388,7 +417,9 @@ class OpenCVCamera:
|
||||
if not ret:
|
||||
raise OSError(f"Can't capture color image from camera {self.camera_index}.")
|
||||
|
||||
requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode
|
||||
requested_color_mode = (
|
||||
self.color_mode if temporary_color_mode is None else temporary_color_mode
|
||||
)
|
||||
|
||||
if requested_color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
|
||||
@@ -23,11 +23,17 @@ from lerobot.common.datasets.utils import get_features_from_robot
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
|
||||
from lerobot.common.utils.utils import (
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.scripts.eval import get_pretrained_policy_path
|
||||
|
||||
|
||||
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
|
||||
def log_control_info(
|
||||
robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None
|
||||
):
|
||||
log_items = []
|
||||
if episode_index is not None:
|
||||
log_items.append(f"ep:{episode_index}")
|
||||
@@ -98,7 +104,9 @@ def predict_action(observation, policy, device, use_amp):
|
||||
observation = copy(observation)
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
||||
torch.autocast(device_type=device.type)
|
||||
if device.type == "cuda" and use_amp
|
||||
else nullcontext(),
|
||||
):
|
||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||
for name in observation:
|
||||
@@ -154,7 +162,9 @@ def init_keyboard_listener(assign_rewards=False):
|
||||
print("Right arrow key pressed. Exiting loop...")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||
print(
|
||||
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
|
||||
)
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
@@ -180,8 +190,12 @@ def init_keyboard_listener(assign_rewards=False):
|
||||
def init_policy(pretrained_policy_name_or_path, policy_overrides):
|
||||
"""Instantiate the policy and load fps, device and use_amp from config yaml"""
|
||||
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
|
||||
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||
hydra_cfg = init_hydra_config(
|
||||
pretrained_policy_path / "config.yaml", policy_overrides
|
||||
)
|
||||
policy = make_policy(
|
||||
hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path
|
||||
)
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
@@ -270,7 +284,9 @@ def control_loop(
|
||||
raise ValueError("When `teleoperate` is True, `policy` should be None.")
|
||||
|
||||
if dataset is not None and fps is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
||||
raise ValueError(
|
||||
f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps})."
|
||||
)
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
@@ -297,7 +313,9 @@ def control_loop(
|
||||
frame = {**observation, **action}
|
||||
if "next.reward" in events:
|
||||
frame["next.reward"] = events["next.reward"]
|
||||
frame["next.done"] = (events["next.reward"] == 1) or (events["exit_early"])
|
||||
frame["next.done"] = (events["next.reward"] == 1) or (
|
||||
events["exit_early"]
|
||||
)
|
||||
dataset.add_frame(frame)
|
||||
|
||||
# if frame["next.done"]:
|
||||
@@ -306,7 +324,9 @@ def control_loop(
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.imshow(
|
||||
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
|
||||
)
|
||||
cv2.waitKey(1)
|
||||
|
||||
if fps is not None:
|
||||
@@ -347,7 +367,7 @@ def reset_environment(robot, events, reset_time_s):
|
||||
def reset_follower_position(robot: Robot, target_position):
|
||||
current_position = robot.follower_arms["main"].read("Present_Position")
|
||||
trajectory = torch.from_numpy(
|
||||
np.linspace(current_position, target_position, 30)
|
||||
np.linspace(current_position, target_position, 50)
|
||||
) # NOTE: 30 is just an aribtrary number
|
||||
for pose in trajectory:
|
||||
robot.send_action(pose)
|
||||
@@ -384,7 +404,11 @@ def sanity_check_dataset_name(repo_id, policy):
|
||||
|
||||
|
||||
def sanity_check_dataset_robot_compatibility(
|
||||
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool, extra_features: dict = None
|
||||
dataset: LeRobotDataset,
|
||||
robot: Robot,
|
||||
fps: int,
|
||||
use_videos: bool,
|
||||
extra_features: dict = None,
|
||||
) -> None:
|
||||
features_from_robot = get_features_from_robot(robot, use_videos)
|
||||
if extra_features is not None:
|
||||
@@ -398,11 +422,14 @@ def sanity_check_dataset_robot_compatibility(
|
||||
|
||||
mismatches = []
|
||||
for field, dataset_value, present_value in fields:
|
||||
diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"])
|
||||
diff = DeepDiff(
|
||||
dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]
|
||||
)
|
||||
if diff:
|
||||
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
|
||||
|
||||
if mismatches:
|
||||
raise ValueError(
|
||||
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
|
||||
"Dataset metadata compatibility check failed with mismatches:\n"
|
||||
+ "\n".join(mismatches)
|
||||
)
|
||||
|
||||
@@ -8,7 +8,10 @@ from copy import deepcopy
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
)
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
PROTOCOL_VERSION = 2.0
|
||||
@@ -143,7 +146,9 @@ NUM_READ_RETRY = 10
|
||||
NUM_WRITE_RETRY = 10
|
||||
|
||||
|
||||
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
|
||||
def convert_degrees_to_steps(
|
||||
degrees: float | np.ndarray, models: str | list[str]
|
||||
) -> np.ndarray:
|
||||
"""This function converts the degree range to the step range for indicating motors rotation.
|
||||
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
|
||||
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
|
||||
@@ -378,7 +383,9 @@ class DynamixelMotorsBus:
|
||||
indices = []
|
||||
for idx in tqdm.tqdm(possible_ids):
|
||||
try:
|
||||
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
|
||||
present_idx = self.read_with_motor_ids(
|
||||
self.motor_models, [idx], "ID", num_retry=num_retry
|
||||
)[0]
|
||||
except ConnectionError:
|
||||
continue
|
||||
|
||||
@@ -394,7 +401,9 @@ class DynamixelMotorsBus:
|
||||
def set_bus_baudrate(self, baudrate):
|
||||
present_bus_baudrate = self.port_handler.getBaudRate()
|
||||
if present_bus_baudrate != baudrate:
|
||||
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
|
||||
print(
|
||||
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
|
||||
)
|
||||
self.port_handler.setBaudRate(baudrate)
|
||||
|
||||
if self.port_handler.getBaudRate() != baudrate:
|
||||
@@ -415,7 +424,9 @@ class DynamixelMotorsBus:
|
||||
def set_calibration(self, calibration: dict[str, list]):
|
||||
self.calibration = calibration
|
||||
|
||||
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
def apply_calibration_autocorrect(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
"""This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct.
|
||||
|
||||
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
|
||||
@@ -428,7 +439,9 @@ class DynamixelMotorsBus:
|
||||
values = self.apply_calibration(values, motor_names)
|
||||
return values
|
||||
|
||||
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
def apply_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
|
||||
a "zero position" at 0 degree.
|
||||
|
||||
@@ -503,7 +516,9 @@ class DynamixelMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
def autocorrect_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
|
||||
|
||||
Some motors might have values outside of expected maximum bounds after calibration.
|
||||
@@ -545,15 +560,23 @@ class DynamixelMotorsBus:
|
||||
values[i] *= -1
|
||||
|
||||
# Convert from initial range to range [-180, 180] degrees
|
||||
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
|
||||
calib_val = (
|
||||
(values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
)
|
||||
in_range = (calib_val > LOWER_BOUND_DEGREE) and (
|
||||
calib_val < UPPER_BOUND_DEGREE
|
||||
)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
|
||||
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
|
||||
# (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution
|
||||
low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution
|
||||
upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution
|
||||
low_factor = (
|
||||
-(resolution // 2) - values[i] - homing_offset
|
||||
) / resolution
|
||||
upp_factor = (
|
||||
(resolution // 2) - values[i] - homing_offset
|
||||
) / resolution
|
||||
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
start_pos = self.calibration["start_pos"][calib_idx]
|
||||
@@ -561,7 +584,9 @@ class DynamixelMotorsBus:
|
||||
|
||||
# Convert from initial range to range [0, 100] in %
|
||||
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
|
||||
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
|
||||
in_range = (calib_val > LOWER_BOUND_LINEAR) and (
|
||||
calib_val < UPPER_BOUND_LINEAR
|
||||
)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [0, 100] %
|
||||
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
|
||||
@@ -577,19 +602,27 @@ class DynamixelMotorsBus:
|
||||
factor = math.ceil(low_factor)
|
||||
|
||||
if factor > upp_factor:
|
||||
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
||||
raise ValueError(
|
||||
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
|
||||
)
|
||||
else:
|
||||
factor = math.ceil(upp_factor)
|
||||
|
||||
if factor > low_factor:
|
||||
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
||||
raise ValueError(
|
||||
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
|
||||
)
|
||||
|
||||
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
||||
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
out_of_range_str = (
|
||||
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
)
|
||||
in_range_str = (
|
||||
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
)
|
||||
|
||||
logging.warning(
|
||||
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
|
||||
@@ -599,7 +632,9 @@ class DynamixelMotorsBus:
|
||||
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||
self.calibration["homing_offset"][calib_idx] += resolution * factor
|
||||
|
||||
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
def revert_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
"""Inverse of `apply_calibration`."""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
@@ -638,7 +673,9 @@ class DynamixelMotorsBus:
|
||||
values = np.round(values).astype(np.int32)
|
||||
return values
|
||||
|
||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||
def read_with_motor_ids(
|
||||
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
|
||||
):
|
||||
if self.mock:
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
@@ -740,7 +777,9 @@ class DynamixelMotorsBus:
|
||||
values = self.apply_calibration_autocorrect(values, motor_names)
|
||||
|
||||
# log the number of seconds it took to read the data from the motors
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
|
||||
delta_ts_name = get_log_name(
|
||||
"delta_timestamp_s", "read", data_name, motor_names
|
||||
)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# log the utc time at which the data was received
|
||||
@@ -749,7 +788,9 @@ class DynamixelMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||
def write_with_motor_ids(
|
||||
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
|
||||
):
|
||||
if self.mock:
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
@@ -778,7 +819,12 @@ class DynamixelMotorsBus:
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
|
||||
def write(
|
||||
self,
|
||||
data_name,
|
||||
values: int | float | np.ndarray,
|
||||
motor_names: str | list[str] | None = None,
|
||||
):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
||||
@@ -839,7 +885,9 @@ class DynamixelMotorsBus:
|
||||
)
|
||||
|
||||
# log the number of seconds it took to write the data to the motors
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
|
||||
delta_ts_name = get_log_name(
|
||||
"delta_timestamp_s", "write", data_name, motor_names
|
||||
)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# TODO(rcadene): should we log the time before sending the write command?
|
||||
|
||||
@@ -8,7 +8,10 @@ from copy import deepcopy
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
)
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
PROTOCOL_VERSION = 0
|
||||
@@ -122,7 +125,9 @@ NUM_READ_RETRY = 20
|
||||
NUM_WRITE_RETRY = 20
|
||||
|
||||
|
||||
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
|
||||
def convert_degrees_to_steps(
|
||||
degrees: float | np.ndarray, models: str | list[str]
|
||||
) -> np.ndarray:
|
||||
"""This function converts the degree range to the step range for indicating motors rotation.
|
||||
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
|
||||
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
|
||||
@@ -358,7 +363,9 @@ class FeetechMotorsBus:
|
||||
indices = []
|
||||
for idx in tqdm.tqdm(possible_ids):
|
||||
try:
|
||||
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
|
||||
present_idx = self.read_with_motor_ids(
|
||||
self.motor_models, [idx], "ID", num_retry=num_retry
|
||||
)[0]
|
||||
except ConnectionError:
|
||||
continue
|
||||
|
||||
@@ -374,7 +381,9 @@ class FeetechMotorsBus:
|
||||
def set_bus_baudrate(self, baudrate):
|
||||
present_bus_baudrate = self.port_handler.getBaudRate()
|
||||
if present_bus_baudrate != baudrate:
|
||||
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
|
||||
print(
|
||||
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
|
||||
)
|
||||
self.port_handler.setBaudRate(baudrate)
|
||||
|
||||
if self.port_handler.getBaudRate() != baudrate:
|
||||
@@ -395,7 +404,9 @@ class FeetechMotorsBus:
|
||||
def set_calibration(self, calibration: dict[str, list]):
|
||||
self.calibration = calibration
|
||||
|
||||
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
def apply_calibration_autocorrect(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
"""This function apply the calibration, automatically detects out of range errors for motors values and attempt to correct.
|
||||
|
||||
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
|
||||
@@ -408,7 +419,9 @@ class FeetechMotorsBus:
|
||||
values = self.apply_calibration(values, motor_names)
|
||||
return values
|
||||
|
||||
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
def apply_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
|
||||
a "zero position" at 0 degree.
|
||||
|
||||
@@ -482,7 +495,9 @@ class FeetechMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
def autocorrect_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
|
||||
|
||||
Some motors might have values outside of expected maximum bounds after calibration.
|
||||
@@ -521,18 +536,26 @@ class FeetechMotorsBus:
|
||||
values[i] *= -1
|
||||
|
||||
# Convert from initial range to range [-180, 180] degrees
|
||||
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
|
||||
calib_val = (
|
||||
(values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
)
|
||||
in_range = (calib_val > LOWER_BOUND_DEGREE) and (
|
||||
calib_val < UPPER_BOUND_DEGREE
|
||||
)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
|
||||
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
|
||||
# (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution
|
||||
low_factor = (
|
||||
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
|
||||
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
|
||||
- values[i]
|
||||
- homing_offset
|
||||
) / resolution
|
||||
upp_factor = (
|
||||
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
|
||||
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
|
||||
- values[i]
|
||||
- homing_offset
|
||||
) / resolution
|
||||
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
@@ -541,7 +564,9 @@ class FeetechMotorsBus:
|
||||
|
||||
# Convert from initial range to range [0, 100] in %
|
||||
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
|
||||
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
|
||||
in_range = (calib_val > LOWER_BOUND_LINEAR) and (
|
||||
calib_val < UPPER_BOUND_LINEAR
|
||||
)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [0, 100] %
|
||||
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
|
||||
@@ -557,19 +582,27 @@ class FeetechMotorsBus:
|
||||
factor = math.ceil(low_factor)
|
||||
|
||||
if factor > upp_factor:
|
||||
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
||||
raise ValueError(
|
||||
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
|
||||
)
|
||||
else:
|
||||
factor = math.ceil(upp_factor)
|
||||
|
||||
if factor > low_factor:
|
||||
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
||||
raise ValueError(
|
||||
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
|
||||
)
|
||||
|
||||
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
||||
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
out_of_range_str = (
|
||||
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
)
|
||||
in_range_str = (
|
||||
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
)
|
||||
|
||||
logging.warning(
|
||||
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
|
||||
@@ -579,7 +612,9 @@ class FeetechMotorsBus:
|
||||
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||
self.calibration["homing_offset"][calib_idx] += resolution * factor
|
||||
|
||||
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
def revert_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
"""Inverse of `apply_calibration`."""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
@@ -655,7 +690,9 @@ class FeetechMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||
def read_with_motor_ids(
|
||||
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
|
||||
):
|
||||
if self.mock:
|
||||
import tests.mock_scservo_sdk as scs
|
||||
else:
|
||||
@@ -760,7 +797,9 @@ class FeetechMotorsBus:
|
||||
values = self.apply_calibration_autocorrect(values, motor_names)
|
||||
|
||||
# log the number of seconds it took to read the data from the motors
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
|
||||
delta_ts_name = get_log_name(
|
||||
"delta_timestamp_s", "read", data_name, motor_names
|
||||
)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# log the utc time at which the data was received
|
||||
@@ -769,7 +808,9 @@ class FeetechMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||
def write_with_motor_ids(
|
||||
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
|
||||
):
|
||||
if self.mock:
|
||||
import tests.mock_scservo_sdk as scs
|
||||
else:
|
||||
@@ -798,7 +839,12 @@ class FeetechMotorsBus:
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
|
||||
def write(
|
||||
self,
|
||||
data_name,
|
||||
values: int | float | np.ndarray,
|
||||
motor_names: str | list[str] | None = None,
|
||||
):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
||||
@@ -859,7 +905,9 @@ class FeetechMotorsBus:
|
||||
)
|
||||
|
||||
# log the number of seconds it took to write the data to the motors
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
|
||||
delta_ts_name = get_log_name(
|
||||
"delta_timestamp_s", "write", data_name, motor_names
|
||||
)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# TODO(rcadene): should we log the time before sending the write command?
|
||||
|
||||
@@ -10,9 +10,7 @@ from lerobot.common.robot_devices.motors.dynamixel import (
|
||||
)
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
|
||||
URL_TEMPLATE = (
|
||||
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
||||
)
|
||||
URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
||||
|
||||
# The following positions are provided in nominal degree range ]-180, +180[
|
||||
# For more info on these constants, see comments in the code where they get used.
|
||||
@@ -23,7 +21,9 @@ ROTATED_POSITION_DEGREE = 90
|
||||
def assert_drive_mode(drive_mode):
|
||||
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
|
||||
if not np.all(np.isin(drive_mode, [0, 1])):
|
||||
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
|
||||
raise ValueError(
|
||||
f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
|
||||
)
|
||||
|
||||
|
||||
def apply_drive_mode(position, drive_mode):
|
||||
@@ -64,12 +64,16 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
```
|
||||
"""
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError("To run calibration, the torque must be disabled on all motors.")
|
||||
raise ValueError(
|
||||
"To run calibration, the torque must be disabled on all motors."
|
||||
)
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to zero position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
|
||||
print(
|
||||
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
|
||||
)
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
|
||||
@@ -90,10 +94,15 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
|
||||
# of the previous motor in the kinetic chain.
|
||||
print("\nMove arm to rotated target position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
|
||||
print(
|
||||
"See: "
|
||||
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
|
||||
)
|
||||
input("Press Enter to continue...")
|
||||
|
||||
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
|
||||
rotated_target_pos = convert_degrees_to_steps(
|
||||
ROTATED_POSITION_DEGREE, arm.motor_models
|
||||
)
|
||||
|
||||
# Find drive mode by rotating each motor by a quarter of a turn.
|
||||
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
|
||||
@@ -102,11 +111,15 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
|
||||
# Re-compute homing offset to take into account drive mode
|
||||
rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
|
||||
rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models)
|
||||
rotated_nearest_pos = compute_nearest_rounded_position(
|
||||
rotated_drived_pos, arm.motor_models
|
||||
)
|
||||
homing_offset = rotated_target_pos - rotated_nearest_pos
|
||||
|
||||
print("\nMove arm to rest position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
|
||||
print(
|
||||
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
|
||||
)
|
||||
input("Press Enter to continue...")
|
||||
print()
|
||||
|
||||
|
||||
@@ -12,9 +12,7 @@ from lerobot.common.robot_devices.motors.feetech import (
|
||||
)
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
|
||||
URL_TEMPLATE = (
|
||||
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
||||
)
|
||||
URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
||||
|
||||
# The following positions are provided in nominal degree range ]-180, +180[
|
||||
# For more info on these constants, see comments in the code where they get used.
|
||||
@@ -25,7 +23,9 @@ ROTATED_POSITION_DEGREE = 90
|
||||
def assert_drive_mode(drive_mode):
|
||||
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
|
||||
if not np.all(np.isin(drive_mode, [0, 1])):
|
||||
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
|
||||
raise ValueError(
|
||||
f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
|
||||
)
|
||||
|
||||
|
||||
def apply_drive_mode(position, drive_mode):
|
||||
@@ -126,7 +126,9 @@ def apply_offset(calib, offset):
|
||||
return calib
|
||||
|
||||
|
||||
def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
|
||||
def run_arm_auto_calibration(
|
||||
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
|
||||
):
|
||||
if robot_type == "so100":
|
||||
return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type)
|
||||
elif robot_type == "moss":
|
||||
@@ -135,18 +137,27 @@ def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm
|
||||
raise ValueError(robot_type)
|
||||
|
||||
|
||||
def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
|
||||
def run_arm_auto_calibration_so100(
|
||||
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
|
||||
):
|
||||
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError("To run calibration, the torque must be disabled on all motors.")
|
||||
raise ValueError(
|
||||
"To run calibration, the torque must be disabled on all motors."
|
||||
)
|
||||
|
||||
if not (robot_type == "so100" and arm_type == "follower"):
|
||||
raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.")
|
||||
raise NotImplementedError(
|
||||
"Auto calibration only supports the follower of so100 arms for now."
|
||||
)
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to initial position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
|
||||
print(
|
||||
"See: "
|
||||
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
|
||||
)
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# Lower the acceleration of the motors (in [0,254])
|
||||
@@ -193,11 +204,16 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
||||
|
||||
print("Calibrate elbow_flex")
|
||||
calib["elbow_flex"] = move_to_calibrate(
|
||||
arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook
|
||||
arm,
|
||||
"elbow_flex",
|
||||
positive_first=False,
|
||||
in_between_move_hook=in_between_move_hook,
|
||||
)
|
||||
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
|
||||
|
||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
|
||||
arm.write(
|
||||
"Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex"
|
||||
)
|
||||
time.sleep(1)
|
||||
|
||||
def in_between_move_hook():
|
||||
@@ -225,18 +241,30 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
||||
}
|
||||
arm.write("Goal_Position", list(positions.values()), list(positions.keys()))
|
||||
|
||||
arm.write("Goal_Position", round(calib["shoulder_lift"]["zero_pos"] - 1600), "shoulder_lift")
|
||||
arm.write(
|
||||
"Goal_Position",
|
||||
round(calib["shoulder_lift"]["zero_pos"] - 1600),
|
||||
"shoulder_lift",
|
||||
)
|
||||
time.sleep(2)
|
||||
arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex")
|
||||
arm.write(
|
||||
"Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex"
|
||||
)
|
||||
time.sleep(2)
|
||||
arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex")
|
||||
arm.write(
|
||||
"Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex"
|
||||
)
|
||||
time.sleep(2)
|
||||
arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper")
|
||||
time.sleep(2)
|
||||
|
||||
print("Calibrate wrist_roll")
|
||||
calib["wrist_roll"] = move_to_calibrate(
|
||||
arm, "wrist_roll", invert_drive_mode=True, positive_first=False, while_move_hook=while_move_hook
|
||||
arm,
|
||||
"wrist_roll",
|
||||
invert_drive_mode=True,
|
||||
positive_first=False,
|
||||
while_move_hook=while_move_hook,
|
||||
)
|
||||
|
||||
arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll")
|
||||
@@ -246,7 +274,9 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
||||
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex")
|
||||
time.sleep(1)
|
||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex")
|
||||
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift")
|
||||
arm.write(
|
||||
"Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift"
|
||||
)
|
||||
time.sleep(1)
|
||||
arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
|
||||
time.sleep(1)
|
||||
@@ -275,18 +305,27 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
||||
return calib_dict
|
||||
|
||||
|
||||
def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
|
||||
def run_arm_auto_calibration_moss(
|
||||
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
|
||||
):
|
||||
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError("To run calibration, the torque must be disabled on all motors.")
|
||||
raise ValueError(
|
||||
"To run calibration, the torque must be disabled on all motors."
|
||||
)
|
||||
|
||||
if not (robot_type == "moss" and arm_type == "follower"):
|
||||
raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.")
|
||||
raise NotImplementedError(
|
||||
"Auto calibration only supports the follower of moss arms for now."
|
||||
)
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to initial position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
|
||||
print(
|
||||
"See: "
|
||||
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
|
||||
)
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# Lower the acceleration of the motors (in [0,254])
|
||||
@@ -370,8 +409,12 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str
|
||||
|
||||
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
|
||||
time.sleep(1)
|
||||
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift")
|
||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex")
|
||||
arm.write(
|
||||
"Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift"
|
||||
)
|
||||
arm.write(
|
||||
"Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex"
|
||||
)
|
||||
time.sleep(2)
|
||||
|
||||
calib_modes = []
|
||||
@@ -398,7 +441,9 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str
|
||||
return calib_dict
|
||||
|
||||
|
||||
def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
|
||||
def run_arm_manual_calibration(
|
||||
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
|
||||
):
|
||||
"""This function ensures that a neural network trained on data collected on a given robot
|
||||
can work on another robot. For instance before calibration, setting a same goal position
|
||||
for each motor of two different robots will get two very different positions. But after calibration,
|
||||
@@ -421,12 +466,16 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
|
||||
```
|
||||
"""
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError("To run calibration, the torque must be disabled on all motors.")
|
||||
raise ValueError(
|
||||
"To run calibration, the torque must be disabled on all motors."
|
||||
)
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to zero position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
|
||||
print(
|
||||
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
|
||||
)
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
|
||||
@@ -446,10 +495,15 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
|
||||
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
|
||||
# of the previous motor in the kinetic chain.
|
||||
print("\nMove arm to rotated target position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
|
||||
print(
|
||||
"See: "
|
||||
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
|
||||
)
|
||||
input("Press Enter to continue...")
|
||||
|
||||
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
|
||||
rotated_target_pos = convert_degrees_to_steps(
|
||||
ROTATED_POSITION_DEGREE, arm.motor_models
|
||||
)
|
||||
|
||||
# Find drive mode by rotating each motor by a quarter of a turn.
|
||||
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
|
||||
@@ -461,7 +515,9 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
|
||||
homing_offset = rotated_target_pos - rotated_drived_pos
|
||||
|
||||
print("\nMove arm to rest position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
|
||||
print(
|
||||
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
|
||||
)
|
||||
input("Press Enter to continue...")
|
||||
print()
|
||||
|
||||
|
||||
@@ -18,11 +18,16 @@ import torch
|
||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
from lerobot.common.robot_devices.robots.utils import get_arm_id
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
)
|
||||
|
||||
|
||||
def ensure_safe_goal_position(
|
||||
goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
|
||||
goal_pos: torch.Tensor,
|
||||
present_pos: torch.Tensor,
|
||||
max_relative_target: float | list[float],
|
||||
):
|
||||
# Cap relative action target magnitude for safety.
|
||||
diff = goal_pos - present_pos
|
||||
@@ -70,7 +75,11 @@ class ManipulatorRobotConfig:
|
||||
joint_position_relative_bounds: dict[np.ndarray] | None = None
|
||||
|
||||
def __setattr__(self, prop: str, val):
|
||||
if prop == "max_relative_target" and val is not None and isinstance(val, Sequence):
|
||||
if (
|
||||
prop == "max_relative_target"
|
||||
and val is not None
|
||||
and isinstance(val, Sequence)
|
||||
):
|
||||
for name in self.follower_arms:
|
||||
if len(self.follower_arms[name].motors) != len(val):
|
||||
raise ValueError(
|
||||
@@ -87,7 +96,9 @@ class ManipulatorRobotConfig:
|
||||
|
||||
def __post_init__(self):
|
||||
if self.robot_type not in ["koch", "koch_bimanual", "aloha", "so100", "moss"]:
|
||||
raise ValueError(f"Provided robot type ({self.robot_type}) is not supported.")
|
||||
raise ValueError(
|
||||
f"Provided robot type ({self.robot_type}) is not supported."
|
||||
)
|
||||
|
||||
|
||||
class ManipulatorRobot:
|
||||
@@ -341,7 +352,9 @@ class ManipulatorRobot:
|
||||
# to squeeze the gripper and have it spring back to an open position on its own.
|
||||
for name in self.leader_arms:
|
||||
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
|
||||
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
|
||||
self.leader_arms[name].write(
|
||||
"Goal_Position", self.config.gripper_open_degree, "gripper"
|
||||
)
|
||||
|
||||
# Check both arms can be read
|
||||
for name in self.follower_arms:
|
||||
@@ -373,18 +386,26 @@ class ManipulatorRobot:
|
||||
print(f"Missing calibration file '{arm_calib_path}'")
|
||||
|
||||
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
|
||||
from lerobot.common.robot_devices.robots.dynamixel_calibration import run_arm_calibration
|
||||
from lerobot.common.robot_devices.robots.dynamixel_calibration import (
|
||||
run_arm_calibration,
|
||||
)
|
||||
|
||||
calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)
|
||||
calibration = run_arm_calibration(
|
||||
arm, self.robot_type, name, arm_type
|
||||
)
|
||||
|
||||
elif self.robot_type in ["so100", "moss"]:
|
||||
from lerobot.common.robot_devices.robots.feetech_calibration import (
|
||||
run_arm_manual_calibration,
|
||||
)
|
||||
|
||||
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
|
||||
calibration = run_arm_manual_calibration(
|
||||
arm, self.robot_type, name, arm_type
|
||||
)
|
||||
|
||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
||||
print(
|
||||
f"Calibration is done! Saving calibration file '{arm_calib_path}'"
|
||||
)
|
||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(arm_calib_path, "w") as f:
|
||||
json.dump(calibration, f)
|
||||
@@ -403,13 +424,17 @@ class ManipulatorRobot:
|
||||
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
|
||||
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError("To run set robot preset, the torque must be disabled on all motors.")
|
||||
raise ValueError(
|
||||
"To run set robot preset, the torque must be disabled on all motors."
|
||||
)
|
||||
|
||||
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
|
||||
# rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm,
|
||||
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
|
||||
all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"]
|
||||
all_motors_except_gripper = [
|
||||
name for name in arm.motor_names if name != "gripper"
|
||||
]
|
||||
if len(all_motors_except_gripper) > 0:
|
||||
# 4 corresponds to Extended Position on Koch motors
|
||||
arm.write("Operating_Mode", 4, all_motors_except_gripper)
|
||||
@@ -438,7 +463,9 @@ class ManipulatorRobot:
|
||||
# Enable torque on the gripper of the leader arms, and move it to 45 degrees,
|
||||
# so that we can use it as a trigger to close the gripper of the follower arms.
|
||||
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
|
||||
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
|
||||
self.leader_arms[name].write(
|
||||
"Goal_Position", self.config.gripper_open_degree, "gripper"
|
||||
)
|
||||
|
||||
def set_aloha_robot_preset(self):
|
||||
def set_shadow_(arm):
|
||||
@@ -468,11 +495,15 @@ class ManipulatorRobot:
|
||||
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
|
||||
all_motors_except_gripper = [
|
||||
name for name in self.follower_arms[name].motor_names if name != "gripper"
|
||||
name
|
||||
for name in self.follower_arms[name].motor_names
|
||||
if name != "gripper"
|
||||
]
|
||||
if len(all_motors_except_gripper) > 0:
|
||||
# 4 corresponds to Extended Position on Aloha motors
|
||||
self.follower_arms[name].write("Operating_Mode", 4, all_motors_except_gripper)
|
||||
self.follower_arms[name].write(
|
||||
"Operating_Mode", 4, all_motors_except_gripper
|
||||
)
|
||||
|
||||
# Use 'position control current based' for follower gripper to be limited by the limit of the current.
|
||||
# It can grasp an object without forcing too much even tho,
|
||||
@@ -520,7 +551,9 @@ class ManipulatorRobot:
|
||||
before_lread_t = time.perf_counter()
|
||||
leader_pos[name] = self.leader_arms[name].read("Present_Position")
|
||||
leader_pos[name] = torch.from_numpy(leader_pos[name])
|
||||
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t
|
||||
self.logs[f"read_leader_{name}_pos_dt_s"] = (
|
||||
time.perf_counter() - before_lread_t
|
||||
)
|
||||
|
||||
# Send goal position to the follower
|
||||
follower_goal_pos = {}
|
||||
@@ -541,14 +574,18 @@ class ManipulatorRobot:
|
||||
if self.config.max_relative_target is not None:
|
||||
present_pos = self.follower_arms[name].read("Present_Position")
|
||||
present_pos = torch.from_numpy(present_pos)
|
||||
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
|
||||
goal_pos = ensure_safe_goal_position(
|
||||
goal_pos, present_pos, self.config.max_relative_target
|
||||
)
|
||||
|
||||
# Used when record_data=True
|
||||
follower_goal_pos[name] = goal_pos
|
||||
|
||||
goal_pos = goal_pos.numpy().astype(np.int32)
|
||||
self.follower_arms[name].write("Goal_Position", goal_pos)
|
||||
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
|
||||
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = (
|
||||
time.perf_counter() - before_fwrite_t
|
||||
)
|
||||
|
||||
# Early exit when recording data is not requested
|
||||
if not record_data:
|
||||
@@ -561,7 +598,9 @@ class ManipulatorRobot:
|
||||
before_fread_t = time.perf_counter()
|
||||
follower_pos[name] = self.follower_arms[name].read("Present_Position")
|
||||
follower_pos[name] = torch.from_numpy(follower_pos[name])
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = (
|
||||
time.perf_counter() - before_fread_t
|
||||
)
|
||||
|
||||
# Create state by concatenating follower current position
|
||||
state = []
|
||||
@@ -583,8 +622,12 @@ class ManipulatorRobot:
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
images[name] = torch.from_numpy(images[name])
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
|
||||
"delta_timestamp_s"
|
||||
]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = (
|
||||
time.perf_counter() - before_camread_t
|
||||
)
|
||||
|
||||
# Populate output dictionnaries
|
||||
obs_dict, action_dict = {}, {}
|
||||
@@ -608,7 +651,9 @@ class ManipulatorRobot:
|
||||
before_fread_t = time.perf_counter()
|
||||
follower_pos[name] = self.follower_arms[name].read("Present_Position")
|
||||
follower_pos[name] = torch.from_numpy(follower_pos[name])
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = (
|
||||
time.perf_counter() - before_fread_t
|
||||
)
|
||||
|
||||
# Create state by concatenating follower current position
|
||||
state = []
|
||||
@@ -623,8 +668,12 @@ class ManipulatorRobot:
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
images[name] = torch.from_numpy(images[name])
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
|
||||
"delta_timestamp_s"
|
||||
]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = (
|
||||
time.perf_counter() - before_camread_t
|
||||
)
|
||||
|
||||
# Populate output dictionnaries and format to pytorch
|
||||
obs_dict = {}
|
||||
@@ -670,7 +719,9 @@ class ManipulatorRobot:
|
||||
if self.config.max_relative_target is not None:
|
||||
present_pos = self.follower_arms[name].read("Present_Position")
|
||||
present_pos = torch.from_numpy(present_pos)
|
||||
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
|
||||
goal_pos = ensure_safe_goal_position(
|
||||
goal_pos, present_pos, self.config.max_relative_target
|
||||
)
|
||||
|
||||
# Save tensor to concat and return
|
||||
action_sent.append(goal_pos)
|
||||
|
||||
@@ -60,7 +60,9 @@ class StretchRobot(StretchAPI):
|
||||
def connect(self) -> None:
|
||||
self.is_connected = self.startup()
|
||||
if not self.is_connected:
|
||||
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
|
||||
print(
|
||||
"Another process is already using Stretch. Try running 'stretch_free_robot_process.py'"
|
||||
)
|
||||
raise ConnectionError()
|
||||
|
||||
for name in self.cameras:
|
||||
@@ -68,7 +70,9 @@ class StretchRobot(StretchAPI):
|
||||
self.is_connected = self.is_connected and self.cameras[name].is_connected
|
||||
|
||||
if not self.is_connected:
|
||||
print("Could not connect to the cameras, check that all cameras are plugged-in.")
|
||||
print(
|
||||
"Could not connect to the cameras, check that all cameras are plugged-in."
|
||||
)
|
||||
raise ConnectionError()
|
||||
|
||||
self.run_calibration()
|
||||
@@ -113,8 +117,12 @@ class StretchRobot(StretchAPI):
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
images[name] = torch.from_numpy(images[name])
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
|
||||
"delta_timestamp_s"
|
||||
]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = (
|
||||
time.perf_counter() - before_camread_t
|
||||
)
|
||||
|
||||
# Populate output dictionnaries
|
||||
obs_dict, action_dict = {}, {}
|
||||
@@ -158,8 +166,12 @@ class StretchRobot(StretchAPI):
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
images[name] = torch.from_numpy(images[name])
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
|
||||
"delta_timestamp_s"
|
||||
]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = (
|
||||
time.perf_counter() - before_camread_t
|
||||
)
|
||||
|
||||
# Populate output dictionnaries
|
||||
obs_dict = {}
|
||||
|
||||
@@ -34,7 +34,8 @@ class RobotDeviceNotConnectedError(Exception):
|
||||
"""Exception raised when the robot device is not connected."""
|
||||
|
||||
def __init__(
|
||||
self, message="This robot device is not connected. Try calling `robot_device.connect()` first."
|
||||
self,
|
||||
message="This robot device is not connected. Try calling `robot_device.connect()` first.",
|
||||
):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
@@ -17,7 +17,9 @@ import importlib
|
||||
import logging
|
||||
|
||||
|
||||
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
|
||||
def is_package_available(
|
||||
pkg_name: str, return_version: bool = False
|
||||
) -> tuple[bool, str] | bool:
|
||||
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
|
||||
Check if the package spec exists and grab its version to avoid importing a local directory.
|
||||
**Note:** this doesn't work for all packages.
|
||||
|
||||
@@ -22,6 +22,8 @@ def write_video(video_path, stacked_frames, fps):
|
||||
# Filter out DeprecationWarnings raised from pkg_resources
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning
|
||||
"ignore",
|
||||
"pkg_resources is deprecated as an API",
|
||||
category=DeprecationWarning,
|
||||
)
|
||||
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
||||
|
||||
@@ -116,11 +116,11 @@ def seeded_context(seed: int) -> Generator[None, None, None]:
|
||||
set_global_random_state(random_state_dict)
|
||||
|
||||
|
||||
def init_logging():
|
||||
def init_logging(log_file=None):
|
||||
def custom_format(record):
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
fnameline = f"{record.pathname}:{record.lineno}"
|
||||
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
message = f"{record.levelname} [PID: {os.getpid()}] {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
return message
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -134,6 +134,12 @@ def init_logging():
|
||||
console_handler.setFormatter(formatter)
|
||||
logging.getLogger().addHandler(console_handler)
|
||||
|
||||
if log_file is not None:
|
||||
# File handler
|
||||
file_handler = logging.FileHandler(log_file)
|
||||
file_handler.setFormatter(formatter)
|
||||
logging.getLogger().addHandler(file_handler)
|
||||
|
||||
|
||||
def format_big_number(num, precision=0):
|
||||
suffixes = ["", "K", "M", "B", "T", "Q"]
|
||||
@@ -156,11 +162,16 @@ def _relative_path_between(path1: Path, path2: Path) -> Path:
|
||||
except ValueError: # most likely because path1 is not a subpath of path2
|
||||
common_parts = Path(osp.commonpath([path1, path2])).parts
|
||||
return Path(
|
||||
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
|
||||
"/".join(
|
||||
[".."] * (len(path2.parts) - len(common_parts))
|
||||
+ list(path1.parts[len(common_parts) :])
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
|
||||
def init_hydra_config(
|
||||
config_path: str, overrides: list[str] | None = None
|
||||
) -> DictConfig:
|
||||
"""Initialize a Hydra config given only the path to the relevant config file.
|
||||
|
||||
For config resolution, it is assumed that the config file's parent is the Hydra config dir.
|
||||
@@ -169,7 +180,11 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
|
||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
||||
# Hydra needs a path relative to this file.
|
||||
hydra.initialize(
|
||||
str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent)),
|
||||
str(
|
||||
_relative_path_between(
|
||||
Path(config_path).absolute().parent, Path(__file__).absolute().parent
|
||||
)
|
||||
),
|
||||
version_base="1.2",
|
||||
)
|
||||
cfg = hydra.compose(Path(config_path).stem, overrides)
|
||||
@@ -183,10 +198,26 @@ def print_cuda_memory_usage():
|
||||
gc.collect()
|
||||
# Also clear the cache if you want to fully release the memory
|
||||
torch.cuda.empty_cache()
|
||||
print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
|
||||
print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
|
||||
print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
|
||||
print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))
|
||||
print(
|
||||
"Current GPU Memory Allocated: {:.2f} MB".format(
|
||||
torch.cuda.memory_allocated(0) / 1024**2
|
||||
)
|
||||
)
|
||||
print(
|
||||
"Maximum GPU Memory Allocated: {:.2f} MB".format(
|
||||
torch.cuda.max_memory_allocated(0) / 1024**2
|
||||
)
|
||||
)
|
||||
print(
|
||||
"Current GPU Memory Reserved: {:.2f} MB".format(
|
||||
torch.cuda.memory_reserved(0) / 1024**2
|
||||
)
|
||||
)
|
||||
print(
|
||||
"Maximum GPU Memory Reserved: {:.2f} MB".format(
|
||||
torch.cuda.max_memory_reserved(0) / 1024**2
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def capture_timestamp_utc():
|
||||
@@ -221,7 +252,12 @@ def log_say(text, play_sounds, blocking=False):
|
||||
|
||||
|
||||
class TimerManager:
|
||||
def __init__(self, elapsed_time_list: list[float] | None = None, label="Elapsed time", log=True):
|
||||
def __init__(
|
||||
self,
|
||||
elapsed_time_list: list[float] | None = None,
|
||||
label="Elapsed time",
|
||||
log=True,
|
||||
):
|
||||
self.label = label
|
||||
self.elapsed_time_list = elapsed_time_list
|
||||
self.log = log
|
||||
|
||||
18
lerobot/configs/env/maniskill_example.yaml
vendored
18
lerobot/configs/env/maniskill_example.yaml
vendored
@@ -1,20 +1,30 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 20
|
||||
fps: 400
|
||||
|
||||
env:
|
||||
name: maniskill/pushcube
|
||||
task: PushCube-v1
|
||||
image_size: 128
|
||||
image_size: 64
|
||||
control_mode: pd_ee_delta_pose
|
||||
state_dim: 25
|
||||
action_dim: 7
|
||||
fps: ${fps}
|
||||
obs: rgb
|
||||
render_mode: rgb_array
|
||||
render_size: 128
|
||||
render_size: 64
|
||||
device: cuda
|
||||
|
||||
reward_classifier:
|
||||
pretrained_path: null
|
||||
config_path: null
|
||||
config_path: null
|
||||
|
||||
wrapper:
|
||||
joint_masking_action_space: null
|
||||
delta_action: null
|
||||
|
||||
video_record:
|
||||
enabled: false
|
||||
record_dir: maniskill_videos
|
||||
trajectory_name: trajectory
|
||||
fps: ${fps}
|
||||
|
||||
47
lerobot/configs/env/so100_real.yaml
vendored
47
lerobot/configs/env/so100_real.yaml
vendored
@@ -5,27 +5,46 @@ fps: 10
|
||||
env:
|
||||
name: real_world
|
||||
task: null
|
||||
state_dim: 6
|
||||
action_dim: 6
|
||||
state_dim: 15
|
||||
action_dim: 3
|
||||
fps: ${fps}
|
||||
device: mps
|
||||
|
||||
|
||||
wrapper:
|
||||
crop_params_dict:
|
||||
observation.images.front: [102, 43, 358, 523]
|
||||
observation.images.side: [92, 123, 379, 349]
|
||||
# observation.images.front: [109, 37, 361, 557]
|
||||
# observation.images.side: [94, 161, 372, 315]
|
||||
observation.images.front: [171, 207, 116, 251]
|
||||
observation.images.side: [232, 200, 142, 204]
|
||||
resize_size: [128, 128]
|
||||
control_time_s: 20
|
||||
reset_follower_pos: true
|
||||
control_time_s: 10
|
||||
reset_follower_pos: false
|
||||
use_relative_joint_positions: true
|
||||
reset_time_s: 5
|
||||
display_cameras: false
|
||||
delta_action: 0.1
|
||||
joint_masking_action_space: [1, 1, 1, 1, 0, 0] # disable wrist and gripper
|
||||
delta_action: null #0.3
|
||||
joint_masking_action_space: null #[1, 1, 1, 1, 0, 0] # disable wrist and gripper
|
||||
add_joint_velocity_to_observation: true
|
||||
add_ee_pose_to_observation: true
|
||||
|
||||
# If null then the teleoperation will be used to reset the robot
|
||||
# Bounds for pushcube_gamepad_lerobot15 dataset and experiments
|
||||
# fixed_reset_joint_positions: [-19.86, 103.19, 117.33, 42.7, 13.89, 0.297]
|
||||
# ee_action_space_params: # If null then ee_action_space is not used
|
||||
# bounds:
|
||||
# max: [0.291, 0.147, 0.074]
|
||||
# min: [0.139, -0.143, 0.03]
|
||||
|
||||
# Bounds for insertcube_gamepad dataset and experiments
|
||||
fixed_reset_joint_positions: [20.0, 90., 90., 75., -0.7910156, -0.5673759]
|
||||
ee_action_space_params:
|
||||
bounds:
|
||||
max: [0.25295413, 0.07498981, 0.06862044]
|
||||
min: [0.2010096, -0.12, 0.0433196]
|
||||
|
||||
use_gamepad: true
|
||||
x_step_size: 0.03
|
||||
y_step_size: 0.03
|
||||
z_step_size: 0.03
|
||||
|
||||
reward_classifier:
|
||||
pretrained_path: outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
|
||||
config_path: lerobot/configs/policy/hilserl_classifier.yaml
|
||||
|
||||
pretrained_path: null # outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
|
||||
config_path: null # lerobot/configs/policy/hilserl_classifier.yaml
|
||||
|
||||
@@ -3,6 +3,14 @@
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
hydra:
|
||||
run:
|
||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
||||
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||
dir: outputs/train_hilserl_classifier/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${hydra.job.name}
|
||||
job:
|
||||
name: default
|
||||
|
||||
seed: 13
|
||||
dataset_repo_id: aractingi/push_cube_square_light_reward_cropped_resized
|
||||
# aractingi/push_cube_square_reward_1_cropped_resized
|
||||
|
||||
@@ -8,20 +8,23 @@
|
||||
# env.gym.obs_type=environment_state_agent_pos \
|
||||
|
||||
seed: 1
|
||||
# dataset_repo_id: "AdilZtn/Maniskill-Pushcube-demonstration-medium"
|
||||
dataset_repo_id: null
|
||||
|
||||
training:
|
||||
# Offline training dataloader
|
||||
num_workers: 4
|
||||
|
||||
# batch_size: 256
|
||||
batch_size: 512
|
||||
grad_clip_norm: 10.0
|
||||
grad_clip_norm: 40.0
|
||||
lr: 3e-4
|
||||
|
||||
|
||||
storage_device: "cuda"
|
||||
|
||||
eval_freq: 2500
|
||||
log_freq: 10
|
||||
save_freq: 2000000
|
||||
save_freq: 1000000
|
||||
|
||||
online_steps: 1000000
|
||||
online_rollout_n_episodes: 10
|
||||
@@ -29,18 +32,13 @@ training:
|
||||
online_steps_between_rollouts: 1000
|
||||
online_sampling_ratio: 1.0
|
||||
online_env_seed: 10000
|
||||
online_buffer_capacity: 1000000
|
||||
online_buffer_capacity: 200000
|
||||
offline_buffer_capacity: 100000
|
||||
online_buffer_seed_size: 0
|
||||
online_step_before_learning: 5000
|
||||
online_step_before_learning: 500
|
||||
do_online_rollout_async: false
|
||||
policy_update_freq: 1
|
||||
|
||||
# delta_timestamps:
|
||||
# observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
# observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
# action: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
# next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
|
||||
policy:
|
||||
name: sac
|
||||
|
||||
@@ -52,39 +50,47 @@ policy:
|
||||
n_action_steps: 1
|
||||
|
||||
shared_encoder: true
|
||||
vision_encoder_name: null
|
||||
# vision_encoder_name: "helper2424/resnet10"
|
||||
vision_encoder_name: null
|
||||
# freeze_vision_encoder: true
|
||||
freeze_vision_encoder: false
|
||||
input_shapes:
|
||||
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.state: ["${env.state_dim}"]
|
||||
observation.image: [3, 128, 128]
|
||||
observation.image: [3, 64, 64]
|
||||
output_shapes:
|
||||
action: [7]
|
||||
|
||||
camera_number: 1
|
||||
|
||||
# Normalization / Unnormalization
|
||||
# input_normalization_modes: null
|
||||
input_normalization_modes:
|
||||
observation.state: min_max
|
||||
input_normalization_params:
|
||||
observation.image: mean_std
|
||||
# input_normalization_params: null
|
||||
input_normalization_params:
|
||||
observation.state:
|
||||
min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
|
||||
1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
|
||||
-3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
|
||||
-6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
|
||||
8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
|
||||
|
||||
8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
|
||||
max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
|
||||
0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
|
||||
7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135,
|
||||
0.4001]
|
||||
|
||||
observation.image:
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
output_normalization_params:
|
||||
action:
|
||||
min: [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0]
|
||||
max: [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
|
||||
min: [-0.03, -0.03, -0.03, -0.03, -0.03, -0.03, -0.03]
|
||||
max: [0.03, 0.03, 0.03, 0.03, 0.03, 0.03, 0.03]
|
||||
output_normalization_shapes:
|
||||
action: [7]
|
||||
|
||||
@@ -104,5 +110,9 @@ policy:
|
||||
utd_ratio: 2 # 10
|
||||
|
||||
actor_learner_config:
|
||||
actor_ip: "127.0.0.1"
|
||||
port: 50051
|
||||
learner_host: "127.0.0.1"
|
||||
learner_port: 50051
|
||||
policy_parameters_push_frequency: 4
|
||||
concurrency:
|
||||
actor: 'threads'
|
||||
learner: 'threads'
|
||||
|
||||
@@ -8,8 +8,7 @@
|
||||
# env.gym.obs_type=environment_state_agent_pos \
|
||||
|
||||
seed: 1
|
||||
dataset_repo_id: aractingi/push_cube_overfit_cropped_resized
|
||||
#aractingi/push_cube_square_offline_demo_cropped_resized
|
||||
dataset_repo_id: aractingi/insertcube_simple
|
||||
|
||||
training:
|
||||
# Offline training dataloader
|
||||
@@ -30,7 +29,7 @@ training:
|
||||
online_steps_between_rollouts: 1000
|
||||
online_sampling_ratio: 1.0
|
||||
online_env_seed: 10000
|
||||
online_buffer_capacity: 1000000
|
||||
online_buffer_capacity: 10000
|
||||
online_buffer_seed_size: 0
|
||||
online_step_before_learning: 100 #5000
|
||||
do_online_rollout_async: false
|
||||
@@ -62,10 +61,10 @@ policy:
|
||||
observation.images.side: [3, 128, 128]
|
||||
# observation.image: [3, 128, 128]
|
||||
output_shapes:
|
||||
action: [4] # ["${env.action_dim}"]
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
input_normalization_modes:
|
||||
observation.images.front: mean_std
|
||||
observation.images.side: mean_std
|
||||
observation.state: min_max
|
||||
@@ -77,23 +76,16 @@ policy:
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
observation.state:
|
||||
min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
|
||||
max: [ 7.215820e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]
|
||||
|
||||
# min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375,0.53691274]
|
||||
# max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266, 0.60402685]
|
||||
# min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
|
||||
# max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]
|
||||
# 6- joint positions, 6- joint velocities, 3- ee position
|
||||
max: [ 52.822266, 136.14258, 142.03125, 72.1582, 22.675781, -0.5673759, 100., 100., 100., 100., 100., 100., 0.25295413, 0.07498981, 0.06862044]
|
||||
min: [-2.6367188, 86.572266, 89.82422, 12.392578, -26.015625, -0.5673759, -100., -100., -100., -100., -100., -100., 0.2010096, -0.12, 0.0433196]
|
||||
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
output_normalization_params:
|
||||
# action:
|
||||
# min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
|
||||
# max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
action:
|
||||
min: [-149.23828125, -97.734375, -100.1953125, -73.740234375]
|
||||
max: [149.23828125, 97.734375, 100.1953125, 73.740234375]
|
||||
min: [-0.03, -0.03, -0.01]
|
||||
max: [0.03, 0.03, 0.03]
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
@@ -112,8 +104,9 @@ policy:
|
||||
utd_ratio: 2 # 10
|
||||
|
||||
actor_learner_config:
|
||||
actor_ip: "127.0.0.1"
|
||||
port: 50051
|
||||
learner_host: "127.0.0.1"
|
||||
learner_port: 50051
|
||||
policy_parameters_push_frequency: 15
|
||||
|
||||
# # Loss coefficients.
|
||||
# reward_coeff: 0.5
|
||||
|
||||
@@ -14,9 +14,13 @@ calibration_dir: .cache/calibration/so100
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: null
|
||||
joint_position_relative_bounds:
|
||||
max: [ 7.2158203e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]
|
||||
min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
|
||||
joint_position_relative_bounds: null
|
||||
# max: [100, 100, 100, 100, 100, 100]
|
||||
# min: [-100, -100, -100, -100, -100, -100]
|
||||
# max: [ 7.2158203e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]
|
||||
# min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
|
||||
# max: [ 35.06836 , 103.18359 , 127.61719 , 75.58594 , 0., 0.]
|
||||
# min: [ -8.876953 , 63.808594 , 90.49805 , 49.48242 , 0., 0.]
|
||||
|
||||
leader_arms:
|
||||
main:
|
||||
@@ -47,13 +51,13 @@ follower_arms:
|
||||
cameras:
|
||||
front:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 0
|
||||
camera_index: 1
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
side:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 1
|
||||
camera_index: 0
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
|
||||
@@ -22,13 +22,17 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||
from lerobot.common.robot_devices.motors.feetech import (
|
||||
SCS_SERIES_BAUDRATE_TABLE as SERIES_BAUDRATE_TABLE,
|
||||
)
|
||||
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus as MotorsBusClass
|
||||
from lerobot.common.robot_devices.motors.feetech import (
|
||||
FeetechMotorsBus as MotorsBusClass,
|
||||
)
|
||||
elif brand == "dynamixel":
|
||||
from lerobot.common.robot_devices.motors.dynamixel import MODEL_BAUDRATE_TABLE
|
||||
from lerobot.common.robot_devices.motors.dynamixel import (
|
||||
X_SERIES_BAUDRATE_TABLE as SERIES_BAUDRATE_TABLE,
|
||||
)
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus as MotorsBusClass
|
||||
from lerobot.common.robot_devices.motors.dynamixel import (
|
||||
DynamixelMotorsBus as MotorsBusClass,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Currently we do not support this motor brand: {brand}. We currently support feetech and dynamixel motors."
|
||||
@@ -46,7 +50,9 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||
motor_model = model # Use the motor model passed via argument
|
||||
|
||||
# Initialize the MotorBus with the correct port and motor configurations
|
||||
motor_bus = MotorsBusClass(port=port, motors={motor_name: (motor_index_arbitrary, motor_model)})
|
||||
motor_bus = MotorsBusClass(
|
||||
port=port, motors={motor_name: (motor_index_arbitrary, motor_model)}
|
||||
)
|
||||
|
||||
# Try to connect to the motor bus and handle any connection-specific errors
|
||||
try:
|
||||
@@ -78,20 +84,26 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||
motor_index = present_ids[0]
|
||||
|
||||
if motor_index == -1:
|
||||
raise ValueError("No motors detected. Please ensure you have one motor connected.")
|
||||
raise ValueError(
|
||||
"No motors detected. Please ensure you have one motor connected."
|
||||
)
|
||||
|
||||
print(f"Motor index found at: {motor_index}")
|
||||
|
||||
if brand == "feetech":
|
||||
# Allows ID and BAUDRATE to be written in memory
|
||||
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
|
||||
motor_bus.write_with_motor_ids(
|
||||
motor_bus.motor_models, motor_index, "Lock", 0
|
||||
)
|
||||
|
||||
if baudrate != baudrate_des:
|
||||
print(f"Setting its baudrate to {baudrate_des}")
|
||||
baudrate_idx = list(SERIES_BAUDRATE_TABLE.values()).index(baudrate_des)
|
||||
|
||||
# The write can fail, so we allow retries
|
||||
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx)
|
||||
motor_bus.write_with_motor_ids(
|
||||
motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx
|
||||
)
|
||||
time.sleep(0.5)
|
||||
motor_bus.set_bus_baudrate(baudrate_des)
|
||||
present_baudrate_idx = motor_bus.read_with_motor_ids(
|
||||
@@ -103,9 +115,13 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||
|
||||
print(f"Setting its index to desired index {motor_idx_des}")
|
||||
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
|
||||
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des)
|
||||
motor_bus.write_with_motor_ids(
|
||||
motor_bus.motor_models, motor_index, "ID", motor_idx_des
|
||||
)
|
||||
|
||||
present_idx = motor_bus.read_with_motor_ids(motor_bus.motor_models, motor_idx_des, "ID", num_retry=2)
|
||||
present_idx = motor_bus.read_with_motor_ids(
|
||||
motor_bus.motor_models, motor_idx_des, "ID", num_retry=2
|
||||
)
|
||||
if present_idx != motor_idx_des:
|
||||
raise OSError("Failed to write index.")
|
||||
|
||||
@@ -133,12 +149,29 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--port", type=str, required=True, help="Motors bus port (e.g. dynamixel,feetech)")
|
||||
parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)")
|
||||
parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)")
|
||||
parser.add_argument("--ID", type=int, required=True, help="Desired ID of the current motor (e.g. 1,2,3)")
|
||||
parser.add_argument(
|
||||
"--baudrate", type=int, default=1000000, help="Desired baudrate for the motor (default: 1000000)"
|
||||
"--port",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Motors bus port (e.g. dynamixel,feetech)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ID",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Desired ID of the current motor (e.g. 1,2,3)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--baudrate",
|
||||
type=int,
|
||||
default=1000000,
|
||||
help="Desired baudrate for the motor (default: 1000000)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
@@ -118,7 +118,12 @@ from lerobot.common.robot_devices.control_utils import (
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
|
||||
from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say, none_or_int
|
||||
from lerobot.common.utils.utils import (
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
log_say,
|
||||
none_or_int,
|
||||
)
|
||||
|
||||
########################################################################################
|
||||
# Control modes
|
||||
@@ -173,7 +178,10 @@ def calibrate(robot: Robot, arms: list[str] | None):
|
||||
|
||||
@safe_disconnect
|
||||
def teleoperate(
|
||||
robot: Robot, fps: int | None = None, teleop_time_s: float | None = None, display_cameras: bool = False
|
||||
robot: Robot,
|
||||
fps: int | None = None,
|
||||
teleop_time_s: float | None = None,
|
||||
display_cameras: bool = False,
|
||||
):
|
||||
control_loop(
|
||||
robot,
|
||||
@@ -234,11 +242,15 @@ def record(
|
||||
|
||||
# Load pretrained policy
|
||||
if pretrained_policy_name_or_path is not None:
|
||||
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
||||
policy, policy_fps, device, use_amp = init_policy(
|
||||
pretrained_policy_name_or_path, policy_overrides
|
||||
)
|
||||
|
||||
if fps is None:
|
||||
fps = policy_fps
|
||||
logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")
|
||||
logging.warning(
|
||||
f"No fps provided, so using the fps from policy config ({policy_fps})."
|
||||
)
|
||||
elif fps != policy_fps:
|
||||
logging.warning(
|
||||
f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})."
|
||||
@@ -254,7 +266,9 @@ def record(
|
||||
num_processes=num_image_writer_processes,
|
||||
num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video, extra_features)
|
||||
sanity_check_dataset_robot_compatibility(
|
||||
dataset, robot, fps, video, extra_features
|
||||
)
|
||||
else:
|
||||
# Create empty dataset or load existing saved episodes
|
||||
sanity_check_dataset_name(repo_id, policy)
|
||||
@@ -265,7 +279,8 @@ def record(
|
||||
robot=robot,
|
||||
use_videos=video,
|
||||
image_writer_processes=num_image_writer_processes,
|
||||
image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
image_writer_threads=num_image_writer_threads_per_camera
|
||||
* len(robot.cameras),
|
||||
features=extra_features,
|
||||
)
|
||||
|
||||
@@ -282,7 +297,9 @@ def record(
|
||||
# 3. place the cameras windows on screen
|
||||
enable_teleoperation = policy is None
|
||||
log_say("Warmup record", play_sounds)
|
||||
warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, fps)
|
||||
warmup_record(
|
||||
robot, events, enable_teleoperation, warmup_time_s, display_cameras, fps
|
||||
)
|
||||
|
||||
if has_method(robot, "teleop_safety_stop"):
|
||||
robot.teleop_safety_stop()
|
||||
@@ -365,7 +382,9 @@ def replay(
|
||||
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
|
||||
# TODO(rcadene): Add option to record logs
|
||||
|
||||
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
|
||||
dataset = LeRobotDataset(
|
||||
repo_id, root=root, episodes=[episode], local_files_only=local_files_only
|
||||
)
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
@@ -416,7 +435,10 @@ if __name__ == "__main__":
|
||||
|
||||
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
|
||||
parser_teleop.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
"--fps",
|
||||
type=none_or_int,
|
||||
default=None,
|
||||
help="Frames per second (set to None to disable)",
|
||||
)
|
||||
parser_teleop.add_argument(
|
||||
"--display-cameras",
|
||||
@@ -428,7 +450,10 @@ if __name__ == "__main__":
|
||||
parser_record = subparsers.add_parser("record", parents=[base_parser])
|
||||
task_args = parser_record.add_mutually_exclusive_group(required=True)
|
||||
parser_record.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
"--fps",
|
||||
type=none_or_int,
|
||||
default=None,
|
||||
help="Frames per second (set to None to disable)",
|
||||
)
|
||||
task_args.add_argument(
|
||||
"--single-task",
|
||||
@@ -477,7 +502,9 @@ if __name__ == "__main__":
|
||||
default=60,
|
||||
help="Number of seconds for resetting the environment after each episode.",
|
||||
)
|
||||
parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
|
||||
parser_record.add_argument(
|
||||
"--num-episodes", type=int, default=50, help="Number of episodes to record."
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--run-compute-stats",
|
||||
type=int,
|
||||
@@ -559,7 +586,10 @@ if __name__ == "__main__":
|
||||
|
||||
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
|
||||
parser_replay.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
"--fps",
|
||||
type=none_or_int,
|
||||
default=None,
|
||||
help="Frames per second (set to None to disable)",
|
||||
)
|
||||
parser_replay.add_argument(
|
||||
"--root",
|
||||
@@ -585,7 +615,9 @@ if __name__ == "__main__":
|
||||
default=0,
|
||||
help="Enables the replay of delta actions instead of absolute actions.",
|
||||
)
|
||||
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
|
||||
parser_replay.add_argument(
|
||||
"--episode", type=int, default=0, help="Index of the episode to replay."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
@@ -135,7 +135,11 @@ def init_sim_calibration(robot, cfg):
|
||||
axis_directions = np.array(cfg.get("axis_directions", [1]))
|
||||
offsets = np.array(cfg.get("offsets", [0])) * np.pi
|
||||
|
||||
return {"start_pos": start_pos, "axis_directions": axis_directions, "offsets": offsets}
|
||||
return {
|
||||
"start_pos": start_pos,
|
||||
"axis_directions": axis_directions,
|
||||
"offsets": offsets,
|
||||
}
|
||||
|
||||
|
||||
def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets):
|
||||
@@ -156,7 +160,10 @@ def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None):
|
||||
leader_pos = robot.leader_arms.main.read("Present_Position")
|
||||
action = process_action_fn(leader_pos)
|
||||
env.step(np.expand_dims(action, 0))
|
||||
if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s:
|
||||
if (
|
||||
teleop_time_s is not None
|
||||
and time.perf_counter() - start_teleop_t > teleop_time_s
|
||||
):
|
||||
print("Teleoperation processes finished.")
|
||||
break
|
||||
|
||||
@@ -188,19 +195,27 @@ def record(
|
||||
# Load pretrained policy
|
||||
|
||||
extra_features = (
|
||||
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
|
||||
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}}
|
||||
if assign_rewards
|
||||
else None
|
||||
)
|
||||
|
||||
policy = None
|
||||
if pretrained_policy_name_or_path is not None:
|
||||
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
||||
policy, policy_fps, device, use_amp = init_policy(
|
||||
pretrained_policy_name_or_path, policy_overrides
|
||||
)
|
||||
|
||||
if fps is None:
|
||||
fps = policy_fps
|
||||
logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")
|
||||
logging.warning(
|
||||
f"No fps provided, so using the fps from policy config ({policy_fps})."
|
||||
)
|
||||
|
||||
if policy is None and process_action_from_leader is None:
|
||||
raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")
|
||||
raise ValueError(
|
||||
"Either policy or process_action_fn has to be set to enable control in sim."
|
||||
)
|
||||
|
||||
# initialize listener before sim env
|
||||
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
|
||||
@@ -233,7 +248,11 @@ def record(
|
||||
shape = env.observation_space[key].shape
|
||||
if not key.startswith("observation.image."):
|
||||
key = "observation.image." + key
|
||||
features[key] = {"dtype": "video", "names": ["channel", "height", "width"], "shape": shape}
|
||||
features[key] = {
|
||||
"dtype": "video",
|
||||
"names": ["channel", "height", "width"],
|
||||
"shape": shape,
|
||||
}
|
||||
|
||||
for key, obs_key in state_keys_dict.items():
|
||||
features[key] = {
|
||||
@@ -242,7 +261,11 @@ def record(
|
||||
"shape": env.observation_space[obs_key].shape,
|
||||
}
|
||||
|
||||
features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
|
||||
features["action"] = {
|
||||
"dtype": "float32",
|
||||
"shape": env.action_space.shape,
|
||||
"names": None,
|
||||
}
|
||||
features = {**features, **extra_features}
|
||||
|
||||
# Create empty dataset or load existing saved episodes
|
||||
@@ -343,7 +366,9 @@ def record(
|
||||
if events["stop_recording"] or recorded_episodes >= num_episodes:
|
||||
break
|
||||
else:
|
||||
logging.info("Waiting for a few seconds before starting next episode recording...")
|
||||
logging.info(
|
||||
"Waiting for a few seconds before starting next episode recording..."
|
||||
)
|
||||
busy_wait(3)
|
||||
|
||||
log_say("Stop recording", play_sounds, blocking=True)
|
||||
@@ -361,7 +386,12 @@ def record(
|
||||
|
||||
|
||||
def replay(
|
||||
env, root: Path, repo_id: str, episode: int, fps: int | None = None, local_files_only: bool = True
|
||||
env,
|
||||
root: Path,
|
||||
repo_id: str,
|
||||
episode: int,
|
||||
fps: int | None = None,
|
||||
local_files_only: bool = True,
|
||||
):
|
||||
env = env()
|
||||
|
||||
@@ -408,7 +438,10 @@ if __name__ == "__main__":
|
||||
|
||||
parser_record = subparsers.add_parser("record", parents=[base_parser])
|
||||
parser_record.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
"--fps",
|
||||
type=none_or_int,
|
||||
default=None,
|
||||
help="Frames per second (set to None to disable)",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--root",
|
||||
@@ -434,7 +467,9 @@ if __name__ == "__main__":
|
||||
required=True,
|
||||
help="A description of the task preformed during recording that can be used as a language instruction.",
|
||||
)
|
||||
parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
|
||||
parser_record.add_argument(
|
||||
"--num-episodes", type=int, default=50, help="Number of episodes to record."
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--run-compute-stats",
|
||||
type=int,
|
||||
@@ -495,7 +530,10 @@ if __name__ == "__main__":
|
||||
|
||||
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
|
||||
parser_replay.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
"--fps",
|
||||
type=none_or_int,
|
||||
default=None,
|
||||
help="Frames per second (set to None to disable)",
|
||||
)
|
||||
parser_replay.add_argument(
|
||||
"--root",
|
||||
@@ -509,7 +547,9 @@ if __name__ == "__main__":
|
||||
default="lerobot/test",
|
||||
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
|
||||
)
|
||||
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episodes to replay.")
|
||||
parser_replay.add_argument(
|
||||
"--episode", type=int, default=0, help="Index of the episodes to replay."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
@@ -59,7 +59,11 @@ np_version = np.__version__ if HAS_NP else "N/A"
|
||||
|
||||
torch_version = torch.__version__ if HAS_TORCH else "N/A"
|
||||
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
|
||||
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
|
||||
cuda_version = (
|
||||
torch._C._cuda_getCompiledVersion()
|
||||
if HAS_TORCH and torch.version.cuda is not None
|
||||
else "N/A"
|
||||
)
|
||||
|
||||
|
||||
# TODO(aliberts): refactor into an actual command `lerobot env`
|
||||
@@ -77,7 +81,9 @@ def display_sys_info() -> dict:
|
||||
"Using GPU in script?": "<fill in>",
|
||||
# "Using distributed or parallel set-up in script?": "<fill in>",
|
||||
}
|
||||
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
|
||||
print(
|
||||
"\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n"
|
||||
)
|
||||
print(format_dict(info))
|
||||
return info
|
||||
|
||||
|
||||
@@ -149,7 +149,9 @@ def rollout(
|
||||
if return_observations:
|
||||
all_observations.append(deepcopy(observation))
|
||||
|
||||
observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
|
||||
observation = {
|
||||
key: observation[key].to(device, non_blocking=True) for key in observation
|
||||
}
|
||||
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
@@ -166,7 +168,10 @@ def rollout(
|
||||
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
|
||||
# available of none of the envs finished.
|
||||
if "final_info" in info:
|
||||
successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
|
||||
successes = [
|
||||
info["is_success"] if info is not None else False
|
||||
for info in info["final_info"]
|
||||
]
|
||||
else:
|
||||
successes = [False] * env.num_envs
|
||||
|
||||
@@ -180,9 +185,13 @@ def rollout(
|
||||
|
||||
step += 1
|
||||
running_success_rate = (
|
||||
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
|
||||
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any")
|
||||
.numpy()
|
||||
.mean()
|
||||
)
|
||||
progbar.set_postfix(
|
||||
{"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}
|
||||
)
|
||||
progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
|
||||
progbar.update()
|
||||
|
||||
# Track the final observation.
|
||||
@@ -200,7 +209,9 @@ def rollout(
|
||||
if return_observations:
|
||||
stacked_observations = {}
|
||||
for key in all_observations[0]:
|
||||
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
|
||||
stacked_observations[key] = torch.stack(
|
||||
[obs[key] for obs in all_observations], dim=1
|
||||
)
|
||||
ret["observation"] = stacked_observations
|
||||
|
||||
return ret
|
||||
@@ -255,7 +266,9 @@ def eval_policy(
|
||||
return
|
||||
n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
|
||||
if isinstance(env, gym.vector.SyncVectorEnv):
|
||||
ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023
|
||||
ep_frames.append(
|
||||
np.stack([env.envs[i].render() for i in range(n_to_render_now)])
|
||||
) # noqa: B023
|
||||
elif isinstance(env, gym.vector.AsyncVectorEnv):
|
||||
# Here we must render all frames and discard any we don't need.
|
||||
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
|
||||
@@ -267,7 +280,9 @@ def eval_policy(
|
||||
episode_data: dict | None = None
|
||||
|
||||
# we dont want progress bar when we use slurm, since it clutters the logs
|
||||
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
|
||||
progbar = trange(
|
||||
n_batches, desc="Stepping through eval batches", disable=inside_slurm()
|
||||
)
|
||||
for batch_ix in progbar:
|
||||
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
|
||||
# step.
|
||||
@@ -278,7 +293,8 @@ def eval_policy(
|
||||
seeds = None
|
||||
else:
|
||||
seeds = range(
|
||||
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
|
||||
start_seed + (batch_ix * env.num_envs),
|
||||
start_seed + ((batch_ix + 1) * env.num_envs),
|
||||
)
|
||||
rollout_data = rollout(
|
||||
env,
|
||||
@@ -296,13 +312,22 @@ def eval_policy(
|
||||
|
||||
# Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
|
||||
# (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
|
||||
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
|
||||
mask = (
|
||||
torch.arange(n_steps)
|
||||
<= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)
|
||||
).int()
|
||||
# Extend metrics.
|
||||
batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum")
|
||||
batch_sum_rewards = einops.reduce(
|
||||
(rollout_data["reward"] * mask), "b n -> b", "sum"
|
||||
)
|
||||
sum_rewards.extend(batch_sum_rewards.tolist())
|
||||
batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max")
|
||||
batch_max_rewards = einops.reduce(
|
||||
(rollout_data["reward"] * mask), "b n -> b", "max"
|
||||
)
|
||||
max_rewards.extend(batch_max_rewards.tolist())
|
||||
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
|
||||
batch_successes = einops.reduce(
|
||||
(rollout_data["success"] * mask), "b n -> b", "any"
|
||||
)
|
||||
all_successes.extend(batch_successes.tolist())
|
||||
if seeds:
|
||||
all_seeds.extend(seeds)
|
||||
@@ -315,17 +340,27 @@ def eval_policy(
|
||||
rollout_data,
|
||||
done_indices,
|
||||
start_episode_index=batch_ix * env.num_envs,
|
||||
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
|
||||
start_data_index=(
|
||||
0
|
||||
if episode_data is None
|
||||
else (episode_data["index"][-1].item() + 1)
|
||||
),
|
||||
fps=env.unwrapped.metadata["render_fps"],
|
||||
)
|
||||
if episode_data is None:
|
||||
episode_data = this_episode_data
|
||||
else:
|
||||
# Some sanity checks to make sure we are correctly compiling the data.
|
||||
assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0]
|
||||
assert (
|
||||
episode_data["episode_index"][-1] + 1
|
||||
== this_episode_data["episode_index"][0]
|
||||
)
|
||||
assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
|
||||
# Concatenate the episode data.
|
||||
episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data}
|
||||
episode_data = {
|
||||
k: torch.cat([episode_data[k], this_episode_data[k]])
|
||||
for k in episode_data
|
||||
}
|
||||
|
||||
# Maybe render video for visualization.
|
||||
if max_episodes_rendered > 0 and len(ep_frames) > 0:
|
||||
@@ -343,7 +378,9 @@ def eval_policy(
|
||||
target=write_video,
|
||||
args=(
|
||||
str(video_path),
|
||||
stacked_frames[: done_index + 1], # + 1 to capture the last observation
|
||||
stacked_frames[
|
||||
: done_index + 1
|
||||
], # + 1 to capture the last observation
|
||||
env.unwrapped.metadata["render_fps"],
|
||||
),
|
||||
)
|
||||
@@ -352,7 +389,9 @@ def eval_policy(
|
||||
n_episodes_rendered += 1
|
||||
|
||||
progbar.set_postfix(
|
||||
{"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
|
||||
{
|
||||
"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"
|
||||
}
|
||||
)
|
||||
|
||||
# Wait till all video rendering threads are done.
|
||||
@@ -398,7 +437,11 @@ def eval_policy(
|
||||
|
||||
|
||||
def _compile_episode_data(
|
||||
rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float
|
||||
rollout_data: dict,
|
||||
done_indices: Tensor,
|
||||
start_episode_index: int,
|
||||
start_data_index: int,
|
||||
fps: float,
|
||||
) -> dict:
|
||||
"""Convenience function for `eval_policy(return_episode_data=True)`
|
||||
|
||||
@@ -416,12 +459,16 @@ def _compile_episode_data(
|
||||
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
|
||||
ep_dict = {
|
||||
"action": rollout_data["action"][ep_ix, : num_frames - 1],
|
||||
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
|
||||
"episode_index": torch.tensor(
|
||||
[start_episode_index + ep_ix] * (num_frames - 1)
|
||||
),
|
||||
"frame_index": torch.arange(0, num_frames - 1, 1),
|
||||
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,
|
||||
"next.done": rollout_data["done"][ep_ix, : num_frames - 1],
|
||||
"next.success": rollout_data["success"][ep_ix, : num_frames - 1],
|
||||
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
|
||||
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(
|
||||
torch.float32
|
||||
),
|
||||
}
|
||||
|
||||
# For the last observation frame, all other keys will just be copy padded.
|
||||
@@ -437,7 +484,9 @@ def _compile_episode_data(
|
||||
for key in ep_dicts[0]:
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
|
||||
data_dict["index"] = torch.arange(
|
||||
start_data_index, start_data_index + total_frames, 1
|
||||
)
|
||||
|
||||
return data_dict
|
||||
|
||||
@@ -450,7 +499,9 @@ def main(
|
||||
):
|
||||
assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
|
||||
if pretrained_policy_path is not None:
|
||||
hydra_cfg = init_hydra_config(str(pretrained_policy_path / "config.yaml"), config_overrides)
|
||||
hydra_cfg = init_hydra_config(
|
||||
str(pretrained_policy_path / "config.yaml"), config_overrides
|
||||
)
|
||||
else:
|
||||
hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
|
||||
|
||||
@@ -481,15 +532,23 @@ def main(
|
||||
|
||||
logging.info("Making policy.")
|
||||
if hydra_cfg_path is None:
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
|
||||
policy = make_policy(
|
||||
hydra_cfg=hydra_cfg,
|
||||
pretrained_policy_name_or_path=str(pretrained_policy_path),
|
||||
)
|
||||
else:
|
||||
# Note: We need the dataset stats to pass to the policy's normalization modules.
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).meta.stats)
|
||||
policy = make_policy(
|
||||
hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).meta.stats
|
||||
)
|
||||
|
||||
assert isinstance(policy, nn.Module)
|
||||
policy.eval()
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(),
|
||||
):
|
||||
info = eval_policy(
|
||||
env,
|
||||
policy,
|
||||
@@ -511,16 +570,14 @@ def main(
|
||||
|
||||
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
|
||||
try:
|
||||
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
|
||||
pretrained_policy_path = Path(
|
||||
snapshot_download(pretrained_policy_name_or_path, revision=revision)
|
||||
)
|
||||
except (HFValidationError, RepositoryNotFoundError) as e:
|
||||
if isinstance(e, HFValidationError):
|
||||
error_message = (
|
||||
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
|
||||
)
|
||||
error_message = "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
|
||||
else:
|
||||
error_message = (
|
||||
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
|
||||
)
|
||||
error_message = "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
|
||||
|
||||
logging.warning(f"{error_message} Treating it as a local directory.")
|
||||
pretrained_policy_path = Path(pretrained_policy_name_or_path)
|
||||
@@ -555,7 +612,9 @@ if __name__ == "__main__":
|
||||
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
|
||||
parser.add_argument(
|
||||
"--revision", help="Optionally provide the Hugging Face Hub revision ID."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out-dir",
|
||||
help=(
|
||||
@@ -571,7 +630,11 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.pretrained_policy_name_or_path is None:
|
||||
main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
|
||||
main(
|
||||
hydra_cfg_path=args.config,
|
||||
out_dir=args.out_dir,
|
||||
config_overrides=args.overrides,
|
||||
)
|
||||
else:
|
||||
pretrained_policy_path = get_pretrained_policy_path(
|
||||
args.pretrained_policy_name_or_path, revision=args.revision
|
||||
|
||||
@@ -46,7 +46,11 @@ import torch
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
busy_wait,
|
||||
is_headless,
|
||||
reset_follower_position,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
|
||||
from lerobot.common.utils.utils import (
|
||||
init_hydra_config,
|
||||
@@ -60,13 +64,19 @@ def get_classifier(pretrained_path, config_path):
|
||||
return
|
||||
|
||||
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
|
||||
ClassifierConfig,
|
||||
)
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
||||
Classifier,
|
||||
)
|
||||
|
||||
cfg = init_hydra_config(config_path)
|
||||
|
||||
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
|
||||
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
|
||||
classifier_config.num_cameras = len(
|
||||
cfg.training.image_keys
|
||||
) # TODO automate these paths
|
||||
model = Classifier(classifier_config)
|
||||
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
|
||||
model = model.to("mps")
|
||||
@@ -151,11 +161,17 @@ def rollout(
|
||||
images = []
|
||||
for key in image_keys:
|
||||
if display_cameras:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.imshow(
|
||||
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
|
||||
)
|
||||
cv2.waitKey(1)
|
||||
images.append(observation[key].to("mps"))
|
||||
|
||||
reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
|
||||
reward = (
|
||||
reward_classifier.predict_reward(images)
|
||||
if reward_classifier is not None
|
||||
else 0.0
|
||||
)
|
||||
all_rewards.append(reward)
|
||||
|
||||
# print("REWARD : ", reward)
|
||||
@@ -219,11 +235,19 @@ def eval_policy(
|
||||
|
||||
start_eval = time.perf_counter()
|
||||
progbar = trange(n_episodes, desc="Evaluating policy on real robot")
|
||||
reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file)
|
||||
reward_classifier = get_classifier(
|
||||
reward_classifier_pretrained_path, reward_classifier_config_file
|
||||
)
|
||||
|
||||
for _ in progbar:
|
||||
rollout_data = rollout(
|
||||
robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras
|
||||
robot,
|
||||
policy,
|
||||
reward_classifier,
|
||||
fps,
|
||||
control_time_s,
|
||||
use_amp,
|
||||
display_cameras,
|
||||
)
|
||||
|
||||
rollouts.append(rollout_data)
|
||||
@@ -289,7 +313,9 @@ def init_keyboard_listener():
|
||||
print("Right arrow key pressed. Exiting loop...")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||
print(
|
||||
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
|
||||
)
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.space:
|
||||
@@ -301,7 +327,10 @@ def init_keyboard_listener():
|
||||
"Place the leader in similar pose to the follower and press space again."
|
||||
)
|
||||
events["pause_policy"] = True
|
||||
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
|
||||
log_say(
|
||||
"Human intervention stage. Get ready to take over.",
|
||||
play_sounds=True,
|
||||
)
|
||||
else:
|
||||
events["human_intervention_step"] = True
|
||||
print("Space key pressed. Human intervention starting.")
|
||||
@@ -351,7 +380,9 @@ if __name__ == "__main__":
|
||||
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
|
||||
parser.add_argument(
|
||||
"--revision", help="Optionally provide the Hugging Face Hub revision ID."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out-dir",
|
||||
help=(
|
||||
@@ -360,7 +391,8 @@ if __name__ == "__main__":
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
|
||||
"--display-cameras",
|
||||
help=("Whether to display the camera feed while the rollout is happening"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-classifier-pretrained-path",
|
||||
|
||||
@@ -32,9 +32,13 @@ def find_port():
|
||||
print(f"The port of this MotorsBus is '{port}'")
|
||||
print("Reconnect the USB cable.")
|
||||
elif len(ports_diff) == 0:
|
||||
raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).")
|
||||
raise OSError(
|
||||
f"Could not detect the port. No difference was found ({ports_diff})."
|
||||
)
|
||||
else:
|
||||
raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).")
|
||||
raise OSError(
|
||||
f"Could not detect the port. More than one port was found ({ports_diff})."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -56,24 +56,42 @@ from safetensors.torch import save_file
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
||||
from lerobot.common.datasets.utils import create_branch, create_lerobot_dataset_card, flatten_dict
|
||||
from lerobot.common.datasets.utils import (
|
||||
create_branch,
|
||||
create_lerobot_dataset_card,
|
||||
flatten_dict,
|
||||
)
|
||||
|
||||
|
||||
def get_from_raw_to_lerobot_format_fn(raw_format: str):
|
||||
if raw_format == "pusht_zarr":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
|
||||
from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import (
|
||||
from_raw_to_lerobot_format,
|
||||
)
|
||||
elif raw_format == "umi_zarr":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
||||
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import (
|
||||
from_raw_to_lerobot_format,
|
||||
)
|
||||
elif raw_format == "aloha_hdf5":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import (
|
||||
from_raw_to_lerobot_format,
|
||||
)
|
||||
elif raw_format in ["rlds", "openx"]:
|
||||
from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import from_raw_to_lerobot_format
|
||||
from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import (
|
||||
from_raw_to_lerobot_format,
|
||||
)
|
||||
elif raw_format == "dora_parquet":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
|
||||
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import (
|
||||
from_raw_to_lerobot_format,
|
||||
)
|
||||
elif raw_format == "xarm_pkl":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
|
||||
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import (
|
||||
from_raw_to_lerobot_format,
|
||||
)
|
||||
elif raw_format == "cam_png":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.cam_png_format import from_raw_to_lerobot_format
|
||||
from lerobot.common.datasets.push_dataset_to_hub.cam_png_format import (
|
||||
from_raw_to_lerobot_format,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
|
||||
@@ -83,7 +101,10 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
|
||||
|
||||
|
||||
def save_meta_data(
|
||||
info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
|
||||
info: dict[str, Any],
|
||||
stats: dict,
|
||||
episode_data_index: dict[str, list],
|
||||
meta_data_dir: Path,
|
||||
):
|
||||
meta_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -97,12 +118,16 @@ def save_meta_data(
|
||||
save_file(flatten_dict(stats), stats_path)
|
||||
|
||||
# save episode_data_index
|
||||
episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
|
||||
episode_data_index = {
|
||||
key: torch.tensor(episode_data_index[key]) for key in episode_data_index
|
||||
}
|
||||
ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
|
||||
save_file(episode_data_index, ep_data_idx_path)
|
||||
|
||||
|
||||
def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
|
||||
def push_meta_data_to_hub(
|
||||
repo_id: str, meta_data_dir: str | Path, revision: str | None
|
||||
):
|
||||
"""Expect all meta data files to be all stored in a single "meta_data" directory.
|
||||
On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root.
|
||||
"""
|
||||
@@ -187,7 +212,9 @@ def push_dataset_to_hub(
|
||||
if force_override:
|
||||
shutil.rmtree(local_dir)
|
||||
elif not resume:
|
||||
raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")
|
||||
raise ValueError(
|
||||
f"`local_dir` already exists ({local_dir}). Use `--force-override 1`."
|
||||
)
|
||||
|
||||
meta_data_dir = local_dir / "meta_data"
|
||||
videos_dir = local_dir / "videos"
|
||||
@@ -223,7 +250,9 @@ def push_dataset_to_hub(
|
||||
stats = compute_stats(lerobot_dataset, batch_size, num_workers)
|
||||
|
||||
if local_dir:
|
||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||
hf_dataset = hf_dataset.with_format(
|
||||
None
|
||||
) # to remove transforms that cant be saved
|
||||
hf_dataset.save_to_disk(str(local_dir / "train"))
|
||||
|
||||
if push_to_hub or local_dir:
|
||||
|
||||
@@ -13,29 +13,25 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
import logging
|
||||
import pickle
|
||||
import queue
|
||||
import time
|
||||
from concurrent import futures
|
||||
from statistics import mean, quantiles
|
||||
from functools import lru_cache
|
||||
from lerobot.scripts.server.utils import setup_process_handlers
|
||||
|
||||
# from lerobot.scripts.eval import eval_policy
|
||||
from threading import Thread
|
||||
|
||||
import grpc
|
||||
import hydra
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
from torch import nn
|
||||
import time
|
||||
|
||||
# TODO: Remove the import of maniskill
|
||||
# from lerobot.common.envs.factory import make_maniskill_env
|
||||
# from lerobot.common.envs.utils import preprocess_maniskill_observation
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.robot_devices.control_utils import busy_wait
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.utils.utils import (
|
||||
@@ -44,132 +40,255 @@ from lerobot.common.utils.utils import (
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
|
||||
from lerobot.scripts.server.buffer import Transition, move_state_dict_to_device, move_transition_to_device
|
||||
from lerobot.scripts.server.buffer import (
|
||||
Transition,
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
python_object_to_bytes,
|
||||
transitions_to_bytes,
|
||||
bytes_to_state_dict,
|
||||
)
|
||||
from lerobot.scripts.server.network_utils import (
|
||||
receive_bytes_in_chunks,
|
||||
send_bytes_in_chunks,
|
||||
)
|
||||
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
|
||||
from lerobot.scripts.server import learner_service
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
from torch.multiprocessing import Queue, Event
|
||||
from queue import Empty
|
||||
|
||||
parameters_queue = queue.Queue(maxsize=1)
|
||||
message_queue = queue.Queue(maxsize=1_000_000)
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
from lerobot.scripts.server.utils import get_last_item_from_queue
|
||||
|
||||
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||
|
||||
|
||||
class ActorInformation:
|
||||
def receive_policy(
|
||||
cfg: DictConfig,
|
||||
parameters_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
):
|
||||
logging.info("[ACTOR] Start receiving parameters from the Learner")
|
||||
|
||||
if not use_threads(cfg):
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
setup_process_handlers(False)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.actor_learner_config.learner_host,
|
||||
port=cfg.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
iterator = learner_client.StreamParameters(hilserl_pb2.Empty())
|
||||
receive_bytes_in_chunks(
|
||||
iterator,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
log_prefix="[ACTOR] parameters",
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Received policy loop stopped")
|
||||
|
||||
|
||||
def transitions_stream(
|
||||
shutdown_event: Event, transitions_queue: Queue
|
||||
) -> hilserl_pb2.Empty:
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = transitions_queue.get(block=True, timeout=5)
|
||||
except Empty:
|
||||
logging.debug("[ACTOR] Transition queue is empty")
|
||||
continue
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
message, hilserl_pb2.Transition, log_prefix="[ACTOR] Send transitions"
|
||||
)
|
||||
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
|
||||
def interactions_stream(
|
||||
shutdown_event: any, # Event,
|
||||
interactions_queue: Queue,
|
||||
) -> hilserl_pb2.Empty:
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = interactions_queue.get(block=True, timeout=5)
|
||||
except Empty:
|
||||
logging.debug("[ACTOR] Interaction queue is empty")
|
||||
continue
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
message,
|
||||
hilserl_pb2.InteractionMessage,
|
||||
log_prefix="[ACTOR] Send interactions",
|
||||
)
|
||||
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
|
||||
def send_transitions(
|
||||
cfg: DictConfig,
|
||||
transitions_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> hilserl_pb2.Empty:
|
||||
"""
|
||||
This helper class is used to differentiate between two types of messages that are placed in the same queue during streaming:
|
||||
Sends transitions to the learner.
|
||||
|
||||
- **Transition Data:** Contains experience tuples (observation, action, reward, next observation) collected during interaction.
|
||||
- **Interaction Messages:** Encapsulates statistics related to the interaction process.
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
Attributes:
|
||||
transition (Optional): Transition data to be sent to the learner.
|
||||
interaction_message (Optional): Iteraction message providing additional statistics for logging.
|
||||
- **Transition Data:**
|
||||
- A batch of transitions (observation, action, reward, next observation) is collected.
|
||||
- Transitions are moved to the CPU and serialized using PyTorch.
|
||||
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
|
||||
"""
|
||||
|
||||
def __init__(self, transition=None, interaction_message=None):
|
||||
self.transition = transition
|
||||
self.interaction_message = interaction_message
|
||||
if not use_threads(cfg):
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
setup_process_handlers(False)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.actor_learner_config.learner_host,
|
||||
port=cfg.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
learner_client.SendTransitions(
|
||||
transitions_stream(shutdown_event, transitions_queue)
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
logging.info("[ACTOR] Finished streaming transitions")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Transitions process stopped")
|
||||
|
||||
|
||||
class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
|
||||
def send_interactions(
|
||||
cfg: DictConfig,
|
||||
interactions_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> hilserl_pb2.Empty:
|
||||
"""
|
||||
gRPC service for actor-learner communication in reinforcement learning.
|
||||
Sends interactions to the learner.
|
||||
|
||||
This service is responsible for:
|
||||
1. Streaming batches of transition data and statistical metrics from the actor to the learner.
|
||||
2. Receiving updated network parameters from the learner.
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
- **Interaction Messages:**
|
||||
- Contains useful statistics about episodic rewards and policy timings.
|
||||
- The message is serialized using `pickle` and sent to the learner.
|
||||
"""
|
||||
|
||||
def StreamTransition(self, request, context): # noqa: N802
|
||||
"""
|
||||
Streams data from the actor to the learner.
|
||||
if not use_threads(cfg):
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
setup_process_handlers(False)
|
||||
|
||||
This function continuously retrieves messages from the queue and processes them based on their type:
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.actor_learner_config.learner_host,
|
||||
port=cfg.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
- **Transition Data:**
|
||||
- A batch of transitions (observation, action, reward, next observation) is collected.
|
||||
- Transitions are moved to the CPU and serialized using PyTorch.
|
||||
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
|
||||
try:
|
||||
learner_client.SendInteractions(
|
||||
interactions_stream(shutdown_event, interactions_queue)
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
- **Interaction Messages:**
|
||||
- Contains useful statistics about episodic rewards and policy timings.
|
||||
- The message is serialized using `pickle` and sent to the learner.
|
||||
logging.info("[ACTOR] Finished streaming interactions")
|
||||
|
||||
Yields:
|
||||
hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
|
||||
"""
|
||||
while True:
|
||||
message = message_queue.get(block=True)
|
||||
|
||||
if message.transition is not None:
|
||||
transition_to_send_to_learner: list[Transition] = [
|
||||
move_transition_to_device(transition=T, device="cpu") for T in message.transition
|
||||
]
|
||||
# Check for NaNs in transitions before sending to learner
|
||||
for transition in transition_to_send_to_learner:
|
||||
for key, value in transition["state"].items():
|
||||
if torch.isnan(value).any():
|
||||
logging.warning(f"Found NaN values in transition {key}")
|
||||
buf = io.BytesIO()
|
||||
torch.save(transition_to_send_to_learner, buf)
|
||||
transition_bytes = buf.getvalue()
|
||||
|
||||
transition_message = hilserl_pb2.Transition(transition_bytes=transition_bytes)
|
||||
|
||||
response = hilserl_pb2.ActorInformation(transition=transition_message)
|
||||
|
||||
elif message.interaction_message is not None:
|
||||
content = hilserl_pb2.InteractionMessage(
|
||||
interaction_message_bytes=pickle.dumps(message.interaction_message)
|
||||
)
|
||||
response = hilserl_pb2.ActorInformation(interaction_message=content)
|
||||
|
||||
yield response
|
||||
|
||||
def SendParameters(self, request, context): # noqa: N802
|
||||
"""
|
||||
Receives updated parameters from the learner and updates the actor.
|
||||
|
||||
The learner calls this method to send new model parameters. The received parameters are deserialized
|
||||
and placed in a queue to be consumed by the actor.
|
||||
|
||||
Args:
|
||||
request (hilserl_pb2.ParameterUpdate): The request containing serialized network parameters.
|
||||
context (grpc.ServicerContext): The gRPC context.
|
||||
|
||||
Returns:
|
||||
hilserl_pb2.Empty: An empty response to acknowledge receipt.
|
||||
"""
|
||||
buffer = io.BytesIO(request.parameter_bytes)
|
||||
params = torch.load(buffer)
|
||||
parameters_queue.put(params)
|
||||
return hilserl_pb2.Empty()
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Interactions process stopped")
|
||||
|
||||
|
||||
def serve_actor_service(port=50052):
|
||||
@lru_cache(maxsize=1)
|
||||
def learner_service_client(
|
||||
host="127.0.0.1", port=50051
|
||||
) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
|
||||
import json
|
||||
|
||||
"""
|
||||
Runs a gRPC server to start streaming the data from the actor to the learner.
|
||||
Throught this server the learner can push parameters to the Actor as well.
|
||||
Returns a client for the learner service.
|
||||
|
||||
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
|
||||
So we need to create only one client and reuse it.
|
||||
"""
|
||||
server = grpc.server(
|
||||
futures.ThreadPoolExecutor(max_workers=20),
|
||||
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
|
||||
|
||||
service_config = {
|
||||
"methodConfig": [
|
||||
{
|
||||
"name": [{}], # Applies to ALL methods in ALL services
|
||||
"retryPolicy": {
|
||||
"maxAttempts": 5, # Max retries (total attempts = 5)
|
||||
"initialBackoff": "0.1s", # First retry after 0.1s
|
||||
"maxBackoff": "2s", # Max wait time between retries
|
||||
"backoffMultiplier": 2, # Exponential backoff factor
|
||||
"retryableStatusCodes": [
|
||||
"UNAVAILABLE",
|
||||
"DEADLINE_EXCEEDED",
|
||||
], # Retries on network failures
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
service_config_json = json.dumps(service_config)
|
||||
|
||||
channel = grpc.insecure_channel(
|
||||
f"{host}:{port}",
|
||||
options=[
|
||||
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||
("grpc.enable_retries", 1),
|
||||
("grpc.service_config", service_config_json),
|
||||
],
|
||||
)
|
||||
hilserl_pb2_grpc.add_ActorServiceServicer_to_server(ActorServiceServicer(), server)
|
||||
server.add_insecure_port(f"[::]:{port}")
|
||||
server.start()
|
||||
logging.info(f"[ACTOR] gRPC server listening on port {port}")
|
||||
server.wait_for_termination()
|
||||
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
|
||||
logging.info("[ACTOR] Learner service client created")
|
||||
return stub, channel
|
||||
|
||||
|
||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device):
|
||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
|
||||
if not parameters_queue.empty():
|
||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
state_dict = parameters_queue.get()
|
||||
bytes_state_dict = get_last_item_from_queue(parameters_queue)
|
||||
state_dict = bytes_to_state_dict(bytes_state_dict)
|
||||
state_dict = move_state_dict_to_device(state_dict, device=device)
|
||||
policy.load_state_dict(state_dict)
|
||||
|
||||
|
||||
def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module):
|
||||
def act_with_policy(
|
||||
cfg: DictConfig,
|
||||
robot: Robot,
|
||||
reward_classifier: nn.Module,
|
||||
shutdown_event: any, # Event,
|
||||
parameters_queue: Queue,
|
||||
transitions_queue: Queue,
|
||||
interactions_queue: Queue,
|
||||
):
|
||||
"""
|
||||
Executes policy interaction within the environment.
|
||||
|
||||
@@ -182,7 +301,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
|
||||
online_env = make_robot_env(
|
||||
robot=robot, reward_classifier=reward_classifier, cfg=cfg
|
||||
)
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
@@ -192,17 +313,6 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
||||
|
||||
logging.info("make_policy")
|
||||
|
||||
# HACK: This is an ugly hack to pass the normalization parameters to the policy
|
||||
# Because the action space is dynamic so we override the output normalization parameters
|
||||
# it's ugly, we know ... and we will fix it
|
||||
min_action_space: list = online_env.action_space.spaces[0].low.tolist()
|
||||
max_action_space: list = online_env.action_space.spaces[0].high.tolist()
|
||||
output_normalization_params: dict[dict[str, list]] = {
|
||||
"action": {"min": min_action_space, "max": max_action_space}
|
||||
}
|
||||
cfg.policy.output_normalization_params = output_normalization_params
|
||||
cfg.policy.output_shapes["action"] = online_env.action_space.spaces[0].shape
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy intance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
@@ -225,19 +335,33 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
||||
list_transition_to_send_to_learner = []
|
||||
list_policy_time = []
|
||||
episode_intervention = False
|
||||
# Add counters for intervention rate calculation
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
|
||||
for interaction_step in range(cfg.training.online_steps):
|
||||
start_time = time.perf_counter()
|
||||
if shutdown_event.is_set():
|
||||
logging.info("[ACTOR] Shutting down act_with_policy")
|
||||
return
|
||||
|
||||
if interaction_step >= cfg.training.online_step_before_learning:
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with TimerManager(
|
||||
elapsed_time_list=list_policy_time, label="Policy inference time", log=False
|
||||
elapsed_time_list=list_policy_time,
|
||||
label="Policy inference time",
|
||||
log=False,
|
||||
) as timer: # noqa: F841
|
||||
action = policy.select_action(batch=obs)
|
||||
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
log_policy_frequency_issue(
|
||||
policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step
|
||||
)
|
||||
|
||||
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
|
||||
next_obs, reward, done, truncated, info = online_env.step(
|
||||
action.squeeze(dim=0).cpu().numpy()
|
||||
)
|
||||
else:
|
||||
# TODO (azouitine): Make a custom space for torch tensor
|
||||
action = online_env.action_space.sample()
|
||||
@@ -245,10 +369,14 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
||||
|
||||
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
|
||||
action = (
|
||||
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
|
||||
torch.from_numpy(action[0])
|
||||
.to(device, non_blocking=device.type == "cuda")
|
||||
.unsqueeze(dim=0)
|
||||
)
|
||||
|
||||
sum_reward_episode += float(reward)
|
||||
# Increment total steps counter for intervention rate
|
||||
episode_total_steps += 1
|
||||
|
||||
# NOTE: We overide the action if the intervention is True, because the action applied is the intervention action
|
||||
if "is_intervention" in info and info["is_intervention"]:
|
||||
@@ -257,11 +385,15 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
||||
# but sometimes for example we want to deactivate the gripper
|
||||
action = info["action_intervention"]
|
||||
episode_intervention = True
|
||||
# Increment intervention steps counter
|
||||
episode_intervention_steps += 1
|
||||
|
||||
# Check for NaN values in observations
|
||||
for key, tensor in obs.items():
|
||||
if torch.isnan(tensor).any():
|
||||
logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
|
||||
logging.error(
|
||||
f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}"
|
||||
)
|
||||
|
||||
list_transition_to_send_to_learner.append(
|
||||
Transition(
|
||||
@@ -270,10 +402,10 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
||||
reward=reward,
|
||||
next_state=next_obs,
|
||||
done=done,
|
||||
truncated=truncated, # TODO: (azouitine) Handle truncation properly
|
||||
complementary_info=info, # TODO Handle information for the transition, is_demonstraction: bool
|
||||
)
|
||||
)
|
||||
|
||||
# assign obs to the next obs and continue the rollout
|
||||
obs = next_obs
|
||||
|
||||
@@ -281,36 +413,54 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
||||
# Because we are using a single environment we can index at zero
|
||||
if done or truncated:
|
||||
# TODO: Handle logging for episode information
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
logging.info(
|
||||
f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}"
|
||||
)
|
||||
|
||||
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
|
||||
update_policy_parameters(
|
||||
policy=policy.actor, parameters_queue=parameters_queue, device=device
|
||||
)
|
||||
|
||||
if len(list_transition_to_send_to_learner) > 0:
|
||||
send_transitions_in_chunks(
|
||||
transitions=list_transition_to_send_to_learner, message_queue=message_queue, chunk_size=4
|
||||
push_transitions_to_transport_queue(
|
||||
transitions=list_transition_to_send_to_learner,
|
||||
transitions_queue=transitions_queue,
|
||||
)
|
||||
list_transition_to_send_to_learner = []
|
||||
|
||||
stats = get_frequency_stats(list_policy_time)
|
||||
list_policy_time.clear()
|
||||
|
||||
# Calculate intervention rate
|
||||
intervention_rate = 0.0
|
||||
if episode_total_steps > 0:
|
||||
intervention_rate = episode_intervention_steps / episode_total_steps
|
||||
|
||||
# Send episodic reward to the learner
|
||||
message_queue.put(
|
||||
ActorInformation(
|
||||
interaction_message={
|
||||
interactions_queue.put(
|
||||
python_object_to_bytes(
|
||||
{
|
||||
"Episodic reward": sum_reward_episode,
|
||||
"Interaction step": interaction_step,
|
||||
"Episode intervention": int(episode_intervention),
|
||||
"Intervention rate": intervention_rate,
|
||||
**stats,
|
||||
}
|
||||
)
|
||||
)
|
||||
sum_reward_episode = 0.0
|
||||
episode_intervention = False
|
||||
# Reset intervention counters
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
obs, info = online_env.reset()
|
||||
|
||||
if cfg.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
busy_wait(1 / cfg.fps - dt_time)
|
||||
|
||||
def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int = 100):
|
||||
|
||||
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
|
||||
"""Send transitions to learner in smaller chunks to avoid network issues.
|
||||
|
||||
Args:
|
||||
@@ -318,10 +468,16 @@ def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int
|
||||
message_queue: Queue to send messages to learner
|
||||
chunk_size: Size of each chunk to send
|
||||
"""
|
||||
for i in range(0, len(transitions), chunk_size):
|
||||
chunk = transitions[i : i + chunk_size]
|
||||
logging.debug(f"[ACTOR] Sending chunk of {len(chunk)} transitions to Learner.")
|
||||
message_queue.put(ActorInformation(transition=chunk))
|
||||
transition_to_send_to_learner = []
|
||||
for transition in transitions:
|
||||
tr = move_transition_to_device(transition=transition, device="cpu")
|
||||
for key, value in tr["state"].items():
|
||||
if torch.isnan(value).any():
|
||||
logging.warning(f"Found NaN values in transition {key}")
|
||||
|
||||
transition_to_send_to_learner.append(tr)
|
||||
|
||||
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
|
||||
|
||||
|
||||
def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
|
||||
@@ -332,22 +488,111 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
|
||||
quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
|
||||
logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
|
||||
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
|
||||
stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90}
|
||||
stats = {
|
||||
"Policy frequency [Hz]": policy_fps,
|
||||
"Policy frequency 90th-p [Hz]": quantiles_90,
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
|
||||
def log_policy_frequency_issue(
|
||||
policy_fps: float, cfg: DictConfig, interaction_step: int
|
||||
):
|
||||
if policy_fps < cfg.fps:
|
||||
logging.warning(
|
||||
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
|
||||
)
|
||||
|
||||
|
||||
def establish_learner_connection(
|
||||
stub,
|
||||
shutdown_event: any, # Event,
|
||||
attempts=30,
|
||||
):
|
||||
for _ in range(attempts):
|
||||
if shutdown_event.is_set():
|
||||
logging.info("[ACTOR] Shutting down establish_learner_connection")
|
||||
return False
|
||||
|
||||
# Force a connection attempt and check state
|
||||
try:
|
||||
logging.info("[ACTOR] Send ready message to Learner")
|
||||
if stub.Ready(hilserl_pb2.Empty()) == hilserl_pb2.Empty():
|
||||
return True
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
|
||||
time.sleep(2)
|
||||
return False
|
||||
|
||||
|
||||
def use_threads(cfg: DictConfig) -> bool:
|
||||
return cfg.actor_learner_config.concurrency.actor == "threads"
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
|
||||
def actor_cli(cfg: dict):
|
||||
if not use_threads(cfg):
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
|
||||
init_logging(log_file="actor.log")
|
||||
robot = make_robot(cfg=cfg.robot)
|
||||
|
||||
server_thread = Thread(target=serve_actor_service, args=(cfg.actor_learner_config.port,), daemon=True)
|
||||
shutdown_event = setup_process_handlers(use_threads(cfg))
|
||||
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.actor_learner_config.learner_host,
|
||||
port=cfg.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
logging.info("[ACTOR] Establishing connection with Learner")
|
||||
if not establish_learner_connection(learner_client, shutdown_event):
|
||||
logging.error("[ACTOR] Failed to establish connection with Learner")
|
||||
return
|
||||
|
||||
if not use_threads(cfg):
|
||||
# If we use multithreading, we can reuse the channel
|
||||
grpc_channel.close()
|
||||
grpc_channel = None
|
||||
|
||||
logging.info("[ACTOR] Connection with Learner established")
|
||||
|
||||
parameters_queue = Queue()
|
||||
transitions_queue = Queue()
|
||||
interactions_queue = Queue()
|
||||
|
||||
concurrency_entity = None
|
||||
if use_threads(cfg):
|
||||
from threading import Thread
|
||||
|
||||
concurrency_entity = Thread
|
||||
else:
|
||||
from multiprocessing import Process
|
||||
|
||||
concurrency_entity = Process
|
||||
|
||||
receive_policy_process = concurrency_entity(
|
||||
target=receive_policy,
|
||||
args=(cfg, parameters_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
transitions_process = concurrency_entity(
|
||||
target=send_transitions,
|
||||
args=(cfg, transitions_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
interactions_process = concurrency_entity(
|
||||
target=send_interactions,
|
||||
args=(cfg, interactions_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
transitions_process.start()
|
||||
interactions_process.start()
|
||||
receive_policy_process.start()
|
||||
|
||||
# HACK: FOR MANISKILL we do not have a reward classifier
|
||||
# TODO: Remove this once we merge into main
|
||||
@@ -360,15 +605,36 @@ def actor_cli(cfg: dict):
|
||||
pretrained_path=cfg.env.reward_classifier.pretrained_path,
|
||||
config_path=cfg.env.reward_classifier.config_path,
|
||||
)
|
||||
policy_thread = Thread(
|
||||
target=act_with_policy,
|
||||
daemon=True,
|
||||
args=(cfg, robot, reward_classifier),
|
||||
|
||||
act_with_policy(
|
||||
cfg,
|
||||
robot,
|
||||
reward_classifier,
|
||||
shutdown_event,
|
||||
parameters_queue,
|
||||
transitions_queue,
|
||||
interactions_queue,
|
||||
)
|
||||
server_thread.start()
|
||||
policy_thread.start()
|
||||
policy_thread.join()
|
||||
server_thread.join()
|
||||
logging.info("[ACTOR] Policy process joined")
|
||||
|
||||
logging.info("[ACTOR] Closing queues")
|
||||
transitions_queue.close()
|
||||
interactions_queue.close()
|
||||
parameters_queue.close()
|
||||
|
||||
transitions_process.join()
|
||||
logging.info("[ACTOR] Transitions process joined")
|
||||
interactions_process.join()
|
||||
logging.info("[ACTOR] Interactions process joined")
|
||||
receive_policy_process.join()
|
||||
logging.info("[ACTOR] Receive policy process joined")
|
||||
|
||||
logging.info("[ACTOR] join queues")
|
||||
transitions_queue.cancel_join_thread()
|
||||
interactions_queue.cancel_join_thread()
|
||||
parameters_queue.cancel_join_thread()
|
||||
|
||||
logging.info("[ACTOR] queues closed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -225,7 +225,9 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Crop rectangular ROIs from a LeRobot dataset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
@@ -247,7 +249,9 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
local_files_only = args.root is not None
|
||||
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=args.repo_id, root=args.root, local_files_only=local_files_only
|
||||
)
|
||||
|
||||
images = get_image_from_lerobot_dataset(dataset)
|
||||
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
|
||||
@@ -256,14 +260,9 @@ if __name__ == "__main__":
|
||||
if args.crop_params_path is None:
|
||||
rois = select_square_roi_for_images(images)
|
||||
else:
|
||||
with open(args.crop_params_path, "r") as f:
|
||||
with open(args.crop_params_path) as f:
|
||||
rois = json.load(f)
|
||||
|
||||
# rois = {
|
||||
# "observation.images.front": [102, 43, 358, 523],
|
||||
# "observation.images.side": [92, 123, 379, 349],
|
||||
# }
|
||||
|
||||
# Print the selected rectangular ROIs
|
||||
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
|
||||
for key, roi in rois.items():
|
||||
|
||||
797
lerobot/scripts/server/end_effector_control_utils.py
Normal file
797
lerobot/scripts/server/end_effector_control_utils.py
Normal file
@@ -0,0 +1,797 @@
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||
import logging
|
||||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
import argparse
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class InputController:
|
||||
"""Base class for input controllers that generate motion deltas."""
|
||||
|
||||
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01):
|
||||
"""
|
||||
Initialize the controller.
|
||||
|
||||
Args:
|
||||
x_step_size: Base movement step size in meters
|
||||
y_step_size: Base movement step size in meters
|
||||
z_step_size: Base movement step size in meters
|
||||
"""
|
||||
self.x_step_size = x_step_size
|
||||
self.y_step_size = y_step_size
|
||||
self.z_step_size = z_step_size
|
||||
self.running = True
|
||||
self.episode_end_status = None # None, "success", or "failure"
|
||||
|
||||
def start(self):
|
||||
"""Start the controller and initialize resources."""
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
"""Stop the controller and release resources."""
|
||||
pass
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas (dx, dy, dz) in meters."""
|
||||
return 0.0, 0.0, 0.0
|
||||
|
||||
def should_quit(self):
|
||||
"""Return True if the user has requested to quit."""
|
||||
return not self.running
|
||||
|
||||
def update(self):
|
||||
"""Update controller state - call this once per frame."""
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
"""Support for use in 'with' statements."""
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Ensure resources are released when exiting 'with' block."""
|
||||
self.stop()
|
||||
|
||||
def get_episode_end_status(self):
|
||||
"""
|
||||
Get the current episode end status.
|
||||
|
||||
Returns:
|
||||
None if episode should continue, "success" or "failure" otherwise
|
||||
"""
|
||||
status = self.episode_end_status
|
||||
self.episode_end_status = None # Reset after reading
|
||||
return status
|
||||
|
||||
|
||||
class KeyboardController(InputController):
|
||||
"""Generate motion deltas from keyboard input."""
|
||||
|
||||
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01):
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.key_states = {
|
||||
"forward_x": False,
|
||||
"backward_x": False,
|
||||
"forward_y": False,
|
||||
"backward_y": False,
|
||||
"forward_z": False,
|
||||
"backward_z": False,
|
||||
"quit": False,
|
||||
"success": False,
|
||||
"failure": False,
|
||||
}
|
||||
self.listener = None
|
||||
|
||||
def start(self):
|
||||
"""Start the keyboard listener."""
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if key == keyboard.Key.up:
|
||||
self.key_states["forward_x"] = True
|
||||
elif key == keyboard.Key.down:
|
||||
self.key_states["backward_x"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
self.key_states["forward_y"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
self.key_states["backward_y"] = True
|
||||
elif key == keyboard.Key.shift:
|
||||
self.key_states["backward_z"] = True
|
||||
elif key == keyboard.Key.shift_r:
|
||||
self.key_states["forward_z"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
self.key_states["quit"] = True
|
||||
self.running = False
|
||||
return False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = True
|
||||
self.episode_end_status = "success"
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = True
|
||||
self.episode_end_status = "failure"
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def on_release(key):
|
||||
try:
|
||||
if key == keyboard.Key.up:
|
||||
self.key_states["forward_x"] = False
|
||||
elif key == keyboard.Key.down:
|
||||
self.key_states["backward_x"] = False
|
||||
elif key == keyboard.Key.left:
|
||||
self.key_states["forward_y"] = False
|
||||
elif key == keyboard.Key.right:
|
||||
self.key_states["backward_y"] = False
|
||||
elif key == keyboard.Key.shift:
|
||||
self.key_states["backward_z"] = False
|
||||
elif key == keyboard.Key.shift_r:
|
||||
self.key_states["forward_z"] = False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = False
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = False
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
self.listener = keyboard.Listener(on_press=on_press, on_release=on_release)
|
||||
self.listener.start()
|
||||
|
||||
print("Keyboard controls:")
|
||||
print(" Arrow keys: Move in X-Y plane")
|
||||
print(" Shift and Shift_R: Move in Z axis")
|
||||
print(" Enter: End episode with SUCCESS")
|
||||
print(" Backspace: End episode with FAILURE")
|
||||
print(" ESC: Exit")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the keyboard listener."""
|
||||
if self.listener and self.listener.is_alive():
|
||||
self.listener.stop()
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from keyboard state."""
|
||||
delta_x = delta_y = delta_z = 0.0
|
||||
|
||||
if self.key_states["forward_x"]:
|
||||
delta_x += self.x_step_size
|
||||
if self.key_states["backward_x"]:
|
||||
delta_x -= self.x_step_size
|
||||
if self.key_states["forward_y"]:
|
||||
delta_y += self.y_step_size
|
||||
if self.key_states["backward_y"]:
|
||||
delta_y -= self.y_step_size
|
||||
if self.key_states["forward_z"]:
|
||||
delta_z += self.z_step_size
|
||||
if self.key_states["backward_z"]:
|
||||
delta_z -= self.z_step_size
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
def should_quit(self):
|
||||
"""Return True if ESC was pressed."""
|
||||
return self.key_states["quit"]
|
||||
|
||||
def should_save(self):
|
||||
"""Return True if Enter was pressed (save episode)."""
|
||||
return self.key_states["success"] or self.key_states["failure"]
|
||||
|
||||
|
||||
class GamepadController(InputController):
|
||||
"""Generate motion deltas from gamepad input."""
|
||||
|
||||
def __init__(
|
||||
self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01, deadzone=0.1
|
||||
):
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.joystick = None
|
||||
self.intervention_flag = False
|
||||
|
||||
def start(self):
|
||||
"""Initialize pygame and the gamepad."""
|
||||
import pygame
|
||||
|
||||
pygame.init()
|
||||
pygame.joystick.init()
|
||||
|
||||
if pygame.joystick.get_count() == 0:
|
||||
logging.error(
|
||||
"No gamepad detected. Please connect a gamepad and try again."
|
||||
)
|
||||
self.running = False
|
||||
return
|
||||
|
||||
self.joystick = pygame.joystick.Joystick(0)
|
||||
self.joystick.init()
|
||||
logging.info(f"Initialized gamepad: {self.joystick.get_name()}")
|
||||
|
||||
print("Gamepad controls:")
|
||||
print(" Left analog stick: Move in X-Y plane")
|
||||
print(" Right analog stick (vertical): Move in Z axis")
|
||||
print(" B/Circle button: Exit")
|
||||
print(" Y/Triangle button: End episode with SUCCESS")
|
||||
print(" A/Cross button: End episode with FAILURE")
|
||||
print(" X/Square button: Rerecord episode")
|
||||
|
||||
def stop(self):
|
||||
"""Clean up pygame resources."""
|
||||
import pygame
|
||||
|
||||
if pygame.joystick.get_init():
|
||||
if self.joystick:
|
||||
self.joystick.quit()
|
||||
pygame.joystick.quit()
|
||||
pygame.quit()
|
||||
|
||||
def update(self):
|
||||
"""Process pygame events to get fresh gamepad readings."""
|
||||
import pygame
|
||||
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.JOYBUTTONDOWN:
|
||||
if event.button == 3:
|
||||
self.episode_end_status = "success"
|
||||
# A button (1) for failure
|
||||
elif event.button == 1:
|
||||
self.episode_end_status = "failure"
|
||||
# X button (0) for rerecord
|
||||
elif event.button == 0:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
|
||||
# Reset episode status on button release
|
||||
elif event.type == pygame.JOYBUTTONUP:
|
||||
if event.button in [0, 2, 3]:
|
||||
self.episode_end_status = None
|
||||
|
||||
# Check for RB button (typically button 5) for intervention flag
|
||||
if self.joystick.get_button(5):
|
||||
self.intervention_flag = True
|
||||
else:
|
||||
self.intervention_flag = False
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from gamepad state."""
|
||||
import pygame
|
||||
|
||||
try:
|
||||
# Read joystick axes
|
||||
# Left stick X and Y (typically axes 0 and 1)
|
||||
x_input = self.joystick.get_axis(0) # Left/Right
|
||||
y_input = self.joystick.get_axis(1) # Up/Down (often inverted)
|
||||
|
||||
# Right stick Y (typically axis 3 or 4)
|
||||
z_input = self.joystick.get_axis(3) # Up/Down for Z
|
||||
|
||||
# Apply deadzone to avoid drift
|
||||
x_input = 0 if abs(x_input) < self.deadzone else x_input
|
||||
y_input = 0 if abs(y_input) < self.deadzone else y_input
|
||||
z_input = 0 if abs(z_input) < self.deadzone else z_input
|
||||
|
||||
# Calculate deltas (note: may need to invert axes depending on controller)
|
||||
delta_x = -y_input * self.y_step_size # Forward/backward
|
||||
delta_y = -x_input * self.x_step_size # Left/right
|
||||
delta_z = -z_input * self.z_step_size # Up/down
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
except pygame.error:
|
||||
logging.error("Error reading gamepad. Is it still connected?")
|
||||
return 0.0, 0.0, 0.0
|
||||
|
||||
def should_intervene(self):
|
||||
"""Return True if intervention flag was set."""
|
||||
return self.intervention_flag
|
||||
|
||||
|
||||
class GamepadControllerHID(InputController):
|
||||
"""Generate motion deltas from gamepad input using HIDAPI."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x_step_size=0.01,
|
||||
y_step_size=0.01,
|
||||
z_step_size=0.01,
|
||||
deadzone=0.1,
|
||||
vendor_id=0x046D,
|
||||
product_id=0xC219,
|
||||
):
|
||||
"""
|
||||
Initialize the HID gamepad controller.
|
||||
|
||||
Args:
|
||||
step_size: Base movement step size in meters
|
||||
z_scale: Scaling factor for Z-axis movement
|
||||
deadzone: Joystick deadzone to prevent drift
|
||||
vendor_id: USB vendor ID of the gamepad (default: Logitech)
|
||||
product_id: USB product ID of the gamepad (default: RumblePad 2)
|
||||
"""
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.vendor_id = vendor_id
|
||||
self.product_id = product_id
|
||||
self.device = None
|
||||
self.device_info = None
|
||||
|
||||
# Movement values (normalized from -1.0 to 1.0)
|
||||
self.left_x = 0.0
|
||||
self.left_y = 0.0
|
||||
self.right_x = 0.0
|
||||
self.right_y = 0.0
|
||||
|
||||
# Button states
|
||||
self.buttons = {}
|
||||
self.quit_requested = False
|
||||
self.save_requested = False
|
||||
self.intervention_flag = False
|
||||
|
||||
def find_device(self):
|
||||
"""Look for the gamepad device by vendor and product ID."""
|
||||
import hid
|
||||
|
||||
devices = hid.enumerate()
|
||||
for device in devices:
|
||||
if (
|
||||
device["vendor_id"] == self.vendor_id
|
||||
and device["product_id"] == self.product_id
|
||||
):
|
||||
logging.info(
|
||||
f"Found gamepad: {device.get('product_string', 'Unknown')}"
|
||||
)
|
||||
return device
|
||||
|
||||
logging.error(
|
||||
f"No gamepad with vendor ID 0x{self.vendor_id:04X} and "
|
||||
f"product ID 0x{self.product_id:04X} found"
|
||||
)
|
||||
return None
|
||||
|
||||
def start(self):
|
||||
"""Connect to the gamepad using HIDAPI."""
|
||||
import hid
|
||||
|
||||
self.device_info = self.find_device()
|
||||
if not self.device_info:
|
||||
self.running = False
|
||||
return
|
||||
|
||||
try:
|
||||
logging.info(f"Connecting to gamepad at path: {self.device_info['path']}")
|
||||
self.device = hid.device()
|
||||
self.device.open_path(self.device_info["path"])
|
||||
self.device.set_nonblocking(1)
|
||||
|
||||
manufacturer = self.device.get_manufacturer_string()
|
||||
product = self.device.get_product_string()
|
||||
logging.info(f"Connected to {manufacturer} {product}")
|
||||
|
||||
logging.info("Gamepad controls (HID mode):")
|
||||
logging.info(" Left analog stick: Move in X-Y plane")
|
||||
logging.info(" Right analog stick: Move in Z axis (vertical)")
|
||||
logging.info(" Button 1/B/Circle: Exit")
|
||||
logging.info(" Button 2/A/Cross: End episode with SUCCESS")
|
||||
logging.info(" Button 3/X/Square: End episode with FAILURE")
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"Error opening gamepad: {e}")
|
||||
logging.error(
|
||||
"You might need to run this with sudo/admin privileges on some systems"
|
||||
)
|
||||
self.running = False
|
||||
|
||||
def stop(self):
|
||||
"""Close the HID device connection."""
|
||||
if self.device:
|
||||
self.device.close()
|
||||
self.device = None
|
||||
|
||||
def update(self):
|
||||
"""
|
||||
Read and process the latest gamepad data.
|
||||
Due to an issue with the HIDAPI, we need to read the read the device several times in order to get a stable reading
|
||||
"""
|
||||
for _ in range(10):
|
||||
self._update()
|
||||
|
||||
def _update(self):
|
||||
"""Read and process the latest gamepad data."""
|
||||
if not self.device or not self.running:
|
||||
return
|
||||
|
||||
try:
|
||||
# Read data from the gamepad
|
||||
data = self.device.read(64)
|
||||
if data:
|
||||
# Interpret gamepad data - this will vary by controller model
|
||||
# These offsets are for the Logitech RumblePad 2
|
||||
if len(data) >= 8:
|
||||
# Normalize joystick values from 0-255 to -1.0-1.0
|
||||
self.left_x = (data[1] - 128) / 128.0
|
||||
self.left_y = (data[2] - 128) / 128.0
|
||||
self.right_x = (data[3] - 128) / 128.0
|
||||
self.right_y = (data[4] - 128) / 128.0
|
||||
|
||||
# Apply deadzone
|
||||
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
|
||||
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
|
||||
self.right_x = (
|
||||
0 if abs(self.right_x) < self.deadzone else self.right_x
|
||||
)
|
||||
self.right_y = (
|
||||
0 if abs(self.right_y) < self.deadzone else self.right_y
|
||||
)
|
||||
|
||||
# Parse button states (byte 5 in the Logitech RumblePad 2)
|
||||
buttons = data[5]
|
||||
|
||||
# Check if RB is pressed then the intervention flag should be set
|
||||
self.intervention_flag = data[6] == 2
|
||||
|
||||
# Check if Y/Triangle button (bit 7) is pressed for saving
|
||||
# Check if X/Square button (bit 5) is pressed for failure
|
||||
# Check if A/Cross button (bit 4) is pressed for rerecording
|
||||
if buttons & 1 << 7:
|
||||
self.episode_end_status = "success"
|
||||
elif buttons & 1 << 5:
|
||||
self.episode_end_status = "failure"
|
||||
elif buttons & 1 << 4:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"Error reading from gamepad: {e}")
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from gamepad state."""
|
||||
# Calculate deltas - invert as needed based on controller orientation
|
||||
delta_x = -self.left_y * self.x_step_size # Forward/backward
|
||||
delta_y = -self.left_x * self.y_step_size # Left/right
|
||||
delta_z = -self.right_y * self.z_step_size # Up/down
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
def should_quit(self):
|
||||
"""Return True if quit button was pressed."""
|
||||
return self.quit_requested
|
||||
|
||||
def should_save(self):
|
||||
"""Return True if save button was pressed."""
|
||||
return self.save_requested
|
||||
|
||||
def should_intervene(self):
|
||||
"""Return True if intervention flag was set."""
|
||||
return self.intervention_flag
|
||||
|
||||
|
||||
def test_forward_kinematics(robot, fps=10):
|
||||
logging.info("Testing Forward Kinematics")
|
||||
timestep = time.perf_counter()
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
robot.teleop_step()
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
|
||||
logging.info(f"EE Position: {ee_pos[:3,3]}")
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def test_inverse_kinematics(robot, fps=10):
|
||||
logging.info("Testing Inverse Kinematics")
|
||||
timestep = time.perf_counter()
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
|
||||
desired_ee_pos = ee_pos
|
||||
target_joint_state = RobotKinematics.ik(
|
||||
joint_positions, desired_ee_pos, position_only=True
|
||||
)
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
logging.info(f"Target Joint State: {target_joint_state}")
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_inverse_kinematics_with_leader(robot, fps=10):
|
||||
logging.info("Testing Inverse Kinematics")
|
||||
fk_func = RobotKinematics.fk_gripper_tip
|
||||
timestep = time.perf_counter()
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
ee_pos = fk_func(joint_positions)
|
||||
|
||||
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||
leader_ee = fk_func(leader_joint_positions)
|
||||
|
||||
desired_ee_pos = leader_ee
|
||||
target_joint_state = RobotKinematics.ik(
|
||||
joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
|
||||
)
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
logging.info(f"Leader EE: {leader_ee[:3,3]}, Follower EE: {ee_pos[:3,3]}")
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
|
||||
logging.info("Testing Delta End-Effector Control")
|
||||
timestep = time.perf_counter()
|
||||
|
||||
# Initial position capture
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
|
||||
fk_func = RobotKinematics.fk_gripper_tip
|
||||
|
||||
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||
initial_leader_ee = fk_func(leader_joint_positions)
|
||||
|
||||
desired_ee_pos = np.diag(np.ones(4))
|
||||
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
|
||||
# Get leader state for teleoperation
|
||||
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||
leader_ee = fk_func(leader_joint_positions)
|
||||
|
||||
# Get current state
|
||||
# obs = robot.capture_observation()
|
||||
# joint_positions = obs["observation.state"].cpu().numpy()
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
current_ee_pos = fk_func(joint_positions)
|
||||
|
||||
# Calculate delta between leader and follower end-effectors
|
||||
# Scaling factor can be adjusted for sensitivity
|
||||
scaling_factor = 1.0
|
||||
ee_delta = (leader_ee - initial_leader_ee) * scaling_factor
|
||||
|
||||
# Apply delta to current position
|
||||
desired_ee_pos[0, 3] = current_ee_pos[0, 3] + ee_delta[0, 3]
|
||||
desired_ee_pos[1, 3] = current_ee_pos[1, 3] + ee_delta[1, 3]
|
||||
desired_ee_pos[2, 3] = current_ee_pos[2, 3] + ee_delta[2, 3]
|
||||
|
||||
if np.any(np.abs(ee_delta[:3, 3]) > 0.01):
|
||||
# Compute joint targets via inverse kinematics
|
||||
target_joint_state = RobotKinematics.ik(
|
||||
joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
|
||||
)
|
||||
|
||||
initial_leader_ee = leader_ee.copy()
|
||||
|
||||
# Send command to robot
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
|
||||
# Logging
|
||||
logging.info(
|
||||
f"Current EE: {current_ee_pos[:3,3]}, Desired EE: {desired_ee_pos[:3,3]}"
|
||||
)
|
||||
logging.info(f"Delta EE: {ee_delta[:3,3]}")
|
||||
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_delta_inverse_kinematics(
|
||||
robot, controller, fps=10, bounds=None, fk_func=None
|
||||
):
|
||||
"""
|
||||
Control a robot using delta end-effector movements from any input controller.
|
||||
|
||||
Args:
|
||||
robot: Robot instance to control
|
||||
controller: InputController instance (keyboard, gamepad, etc.)
|
||||
fps: Control frequency in Hz
|
||||
bounds: Optional position limits
|
||||
fk_func: Forward kinematics function to use
|
||||
"""
|
||||
if fk_func is None:
|
||||
fk_func = RobotKinematics.fk_gripper_tip
|
||||
|
||||
logging.info(
|
||||
f"Testing Delta End-Effector Control with {controller.__class__.__name__}"
|
||||
)
|
||||
|
||||
# Initial position capture
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
current_ee_pos = fk_func(joint_positions)
|
||||
|
||||
# Initialize desired position with current position
|
||||
desired_ee_pos = np.eye(4) # Identity matrix
|
||||
|
||||
timestep = time.perf_counter()
|
||||
with controller:
|
||||
while not controller.should_quit() and time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
|
||||
# Process input events
|
||||
controller.update()
|
||||
|
||||
# Get currrent robot state
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
current_ee_pos = fk_func(joint_positions)
|
||||
|
||||
# Get movement deltas from the controller
|
||||
delta_x, delta_y, delta_z = controller.get_deltas()
|
||||
|
||||
# Update desired position
|
||||
desired_ee_pos[0, 3] = current_ee_pos[0, 3] + delta_x
|
||||
desired_ee_pos[1, 3] = current_ee_pos[1, 3] + delta_y
|
||||
desired_ee_pos[2, 3] = current_ee_pos[2, 3] + delta_z
|
||||
|
||||
# Apply bounds if provided
|
||||
if bounds is not None:
|
||||
desired_ee_pos[:3, 3] = np.clip(
|
||||
desired_ee_pos[:3, 3], bounds["min"], bounds["max"]
|
||||
)
|
||||
|
||||
# Only send commands if there's actual movement
|
||||
if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]):
|
||||
# Compute joint targets via inverse kinematics
|
||||
target_joint_state = RobotKinematics.ik(
|
||||
joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
|
||||
)
|
||||
|
||||
# Send command to robot
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_gym_env(env, controller, fps: int = 30):
|
||||
"""
|
||||
Control a robot through a gym environment using keyboard inputs.
|
||||
|
||||
Args:
|
||||
env: A gym environment created with make_robot_env
|
||||
fps: Target control frequency
|
||||
"""
|
||||
|
||||
logging.info("Testing Keyboard Control of Gym Environment")
|
||||
print("Keyboard controls:")
|
||||
print(" Arrow keys: Move in X-Y plane")
|
||||
print(" Shift and Shift_R: Move in Z axis")
|
||||
print(" ESC: Exit")
|
||||
|
||||
# Reset the environment to get initial observation
|
||||
obs, info = env.reset()
|
||||
|
||||
try:
|
||||
with controller:
|
||||
while not controller.should_quit():
|
||||
loop_start_time = time.perf_counter()
|
||||
|
||||
# Process input events
|
||||
controller.update()
|
||||
|
||||
# Get movement deltas from the controller
|
||||
delta_x, delta_y, delta_z = controller.get_deltas()
|
||||
|
||||
# Create the action vector
|
||||
action = np.array([delta_x, delta_y, delta_z])
|
||||
|
||||
# Skip if no movement
|
||||
if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]):
|
||||
# Step the environment - pass action as a tensor with intervention flag
|
||||
action_tensor = torch.from_numpy(action.astype(np.float32))
|
||||
obs, reward, terminated, truncated, info = env.step(
|
||||
(action_tensor, False)
|
||||
)
|
||||
|
||||
# Log information
|
||||
logging.info(
|
||||
f"Action: [{delta_x:.4f}, {delta_y:.4f}, {delta_z:.4f}]"
|
||||
)
|
||||
logging.info(f"Reward: {reward}")
|
||||
|
||||
# Reset if episode ended
|
||||
if terminated or truncated:
|
||||
logging.info("Episode ended, resetting environment")
|
||||
obs, info = env.reset()
|
||||
|
||||
# Maintain target frame rate
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
finally:
|
||||
# Close the environment
|
||||
env.close()
|
||||
|
||||
|
||||
def make_robot_from_config(config_path, overrides=None):
|
||||
"""Helper function to create a robot from a config file."""
|
||||
if overrides is None:
|
||||
overrides = []
|
||||
robot_cfg = init_hydra_config(config_path, overrides)
|
||||
return make_robot(robot_cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Test end-effector control")
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
default="keyboard",
|
||||
choices=[
|
||||
"keyboard",
|
||||
"gamepad",
|
||||
"keyboard_gym",
|
||||
"gamepad_gym",
|
||||
"leader",
|
||||
"leader_abs",
|
||||
],
|
||||
help="Control mode to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default="Robot manipulation task",
|
||||
help="Description of the task being performed",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
default=True,
|
||||
type=bool,
|
||||
help="Push the dataset to Hugging Face Hub",
|
||||
)
|
||||
# Add the rest of your existing arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
robot = make_robot_from_config("lerobot/configs/robot/so100.yaml", [])
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
# Example bounds
|
||||
bounds = {
|
||||
"max": np.array([0.32170487, 0.201285, 0.10273342]),
|
||||
"min": np.array([0.16631757, -0.08237468, 0.03364977]),
|
||||
}
|
||||
|
||||
try:
|
||||
# Determine controller type based on mode prefix
|
||||
controller = None
|
||||
if args.mode.startswith("keyboard"):
|
||||
controller = KeyboardController(
|
||||
x_step_size=0.01, y_step_size=0.01, z_step_size=0.05
|
||||
)
|
||||
elif args.mode.startswith("gamepad"):
|
||||
controller = GamepadController(
|
||||
x_step_size=0.02, y_step_size=0.02, z_step_size=0.05
|
||||
)
|
||||
|
||||
# Handle mode categories
|
||||
if args.mode in ["keyboard", "gamepad"]:
|
||||
# Direct robot control modes
|
||||
teleoperate_delta_inverse_kinematics(
|
||||
robot, controller, bounds=bounds, fps=10
|
||||
)
|
||||
|
||||
elif args.mode in ["keyboard_gym", "gamepad_gym"]:
|
||||
# Gym environment control modes
|
||||
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
||||
|
||||
cfg = init_hydra_config("lerobot/configs/env/so100_real.yaml", [])
|
||||
cfg.env.wrapper.ee_action_space_params.use_gamepad = False
|
||||
env = make_robot_env(robot, None, cfg)
|
||||
teleoperate_gym_env(env, controller)
|
||||
|
||||
elif args.mode == "leader":
|
||||
# Leader-follower modes don't use controllers
|
||||
teleoperate_delta_inverse_kinematics_with_leader(robot)
|
||||
|
||||
elif args.mode == "leader_abs":
|
||||
teleoperate_inverse_kinematics_with_leader(robot)
|
||||
|
||||
finally:
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
@@ -7,35 +7,37 @@ import numpy as np
|
||||
from lerobot.common.robot_devices.control_utils import is_headless
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||
|
||||
|
||||
def find_joint_bounds(
|
||||
robot,
|
||||
control_time_s=20,
|
||||
control_time_s=30,
|
||||
display_cameras=False,
|
||||
):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
control_time_s = float("inf")
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
pos_list = []
|
||||
while timestamp < control_time_s:
|
||||
while True:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
|
||||
# Wait for 5 seconds to stabilize the robot initial position
|
||||
if time.perf_counter() - start_episode_t < 5:
|
||||
continue
|
||||
|
||||
pos_list.append(robot.follower_arms["main"].read("Present_Position"))
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.imshow(
|
||||
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
|
||||
)
|
||||
cv2.waitKey(1)
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
if timestamp > 60:
|
||||
if time.perf_counter() - start_episode_t > control_time_s:
|
||||
max = np.max(np.stack(pos_list), 0)
|
||||
min = np.min(np.stack(pos_list), 0)
|
||||
print(f"Max angle position per joint {max}")
|
||||
@@ -43,6 +45,43 @@ def find_joint_bounds(
|
||||
break
|
||||
|
||||
|
||||
def find_ee_bounds(
|
||||
robot,
|
||||
control_time_s=30,
|
||||
display_cameras=False,
|
||||
):
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
start_episode_t = time.perf_counter()
|
||||
ee_list = []
|
||||
while True:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
|
||||
# Wait for 5 seconds to stabilize the robot initial position
|
||||
if time.perf_counter() - start_episode_t < 5:
|
||||
continue
|
||||
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
print(f"Joint positions: {joint_positions}")
|
||||
ee_list.append(RobotKinematics.fk_gripper_tip(joint_positions)[:3, 3])
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(
|
||||
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
|
||||
)
|
||||
cv2.waitKey(1)
|
||||
|
||||
if time.perf_counter() - start_episode_t > control_time_s:
|
||||
max = np.max(np.stack(ee_list), 0)
|
||||
min = np.min(np.stack(ee_list), 0)
|
||||
print(f"Max ee position {max}")
|
||||
print(f"Min ee position {min}")
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@@ -57,9 +96,26 @@ if __name__ == "__main__":
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
default="joint",
|
||||
choices=["joint", "ee"],
|
||||
help="Mode to run the script in. Can be 'joint' or 'ee'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--control-time-s",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Time step to use for control.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
|
||||
|
||||
robot = make_robot(robot_cfg)
|
||||
find_joint_bounds(robot, control_time_s=args.control_time_s)
|
||||
if args.mode == "joint":
|
||||
find_joint_bounds(robot, args.control_time_s)
|
||||
elif args.mode == "ee":
|
||||
find_ee_bounds(robot, args.control_time_s)
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -22,37 +22,34 @@ package hil_serl;
|
||||
// The Learner implements this service.
|
||||
service LearnerService {
|
||||
// Actor -> Learner to store transitions
|
||||
rpc SendTransition(Transition) returns (Empty);
|
||||
rpc SendInteractionMessage(InteractionMessage) returns (Empty);
|
||||
rpc SendInteractionMessage(InteractionMessage) returns (Empty);
|
||||
rpc StreamParameters(Empty) returns (stream Parameters);
|
||||
rpc SendTransitions(stream Transition) returns (Empty);
|
||||
rpc SendInteractions(stream InteractionMessage) returns (Empty);
|
||||
rpc Ready(Empty) returns (Empty);
|
||||
}
|
||||
|
||||
// ActorService: the Learner calls this to push parameters.
|
||||
// The Actor implements this service.
|
||||
service ActorService {
|
||||
// Learner -> Actor to send new parameters
|
||||
rpc StreamTransition(Empty) returns (stream ActorInformation) {};
|
||||
rpc SendParameters(Parameters) returns (Empty);
|
||||
}
|
||||
|
||||
|
||||
message ActorInformation {
|
||||
oneof data {
|
||||
Transition transition = 1;
|
||||
InteractionMessage interaction_message = 2;
|
||||
}
|
||||
enum TransferState {
|
||||
TRANSFER_UNKNOWN = 0;
|
||||
TRANSFER_BEGIN = 1;
|
||||
TRANSFER_MIDDLE = 2;
|
||||
TRANSFER_END = 3;
|
||||
}
|
||||
|
||||
// Messages
|
||||
message Transition {
|
||||
bytes transition_bytes = 1;
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message Parameters {
|
||||
bytes parameter_bytes = 1;
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message InteractionMessage {
|
||||
bytes interaction_message_bytes = 1;
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message Empty {}
|
||||
message Empty {}
|
||||
|
||||
@@ -24,25 +24,23 @@ _sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"\x83\x01\n\x10\x41\x63torInformation\x12*\n\ntransition\x18\x01 \x01(\x0b\x32\x14.hil_serl.TransitionH\x00\x12;\n\x13interaction_message\x18\x02 \x01(\x0b\x32\x1c.hil_serl.InteractionMessageH\x00\x42\x06\n\x04\x64\x61ta\"&\n\nTransition\x12\x18\n\x10transition_bytes\x18\x01 \x01(\x0c\"%\n\nParameters\x12\x17\n\x0fparameter_bytes\x18\x01 \x01(\x0c\"7\n\x12InteractionMessage\x12!\n\x19interaction_message_bytes\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty2\x92\x01\n\x0eLearnerService\x12\x37\n\x0eSendTransition\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty2\x8c\x01\n\x0c\x41\x63torService\x12\x43\n\x10StreamTransition\x12\x0f.hil_serl.Empty\x1a\x1a.hil_serl.ActorInformation\"\x00\x30\x01\x12\x37\n\x0eSendParameters\x12\x14.hil_serl.Parameters\x1a\x0f.hil_serl.Emptyb\x06proto3')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"K\n\nTransition\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"K\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x12InteractionMessage\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xc2\x02\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12:\n\x0fSendTransitions\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty(\x01\x12\x43\n\x10SendInteractions\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty(\x01\x12)\n\x05Ready\x12\x0f.hil_serl.Empty\x1a\x0f.hil_serl.Emptyb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_ACTORINFORMATION']._serialized_start=28
|
||||
_globals['_ACTORINFORMATION']._serialized_end=159
|
||||
_globals['_TRANSITION']._serialized_start=161
|
||||
_globals['_TRANSITION']._serialized_end=199
|
||||
_globals['_PARAMETERS']._serialized_start=201
|
||||
_globals['_PARAMETERS']._serialized_end=238
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_start=240
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_end=295
|
||||
_globals['_EMPTY']._serialized_start=297
|
||||
_globals['_EMPTY']._serialized_end=304
|
||||
_globals['_LEARNERSERVICE']._serialized_start=307
|
||||
_globals['_LEARNERSERVICE']._serialized_end=453
|
||||
_globals['_ACTORSERVICE']._serialized_start=456
|
||||
_globals['_ACTORSERVICE']._serialized_end=596
|
||||
_globals['_TRANSFERSTATE']._serialized_start=275
|
||||
_globals['_TRANSFERSTATE']._serialized_end=371
|
||||
_globals['_TRANSITION']._serialized_start=27
|
||||
_globals['_TRANSITION']._serialized_end=102
|
||||
_globals['_PARAMETERS']._serialized_start=104
|
||||
_globals['_PARAMETERS']._serialized_end=179
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_start=181
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_end=264
|
||||
_globals['_EMPTY']._serialized_start=266
|
||||
_globals['_EMPTY']._serialized_end=273
|
||||
_globals['_LEARNERSERVICE']._serialized_start=374
|
||||
_globals['_LEARNERSERVICE']._serialized_end=696
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
||||
@@ -36,16 +36,31 @@ class LearnerServiceStub(object):
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.SendTransition = channel.unary_unary(
|
||||
'/hil_serl.LearnerService/SendTransition',
|
||||
request_serializer=hilserl__pb2.Transition.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.SendInteractionMessage = channel.unary_unary(
|
||||
'/hil_serl.LearnerService/SendInteractionMessage',
|
||||
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.StreamParameters = channel.unary_stream(
|
||||
'/hil_serl.LearnerService/StreamParameters',
|
||||
request_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Parameters.FromString,
|
||||
_registered_method=True)
|
||||
self.SendTransitions = channel.stream_unary(
|
||||
'/hil_serl.LearnerService/SendTransitions',
|
||||
request_serializer=hilserl__pb2.Transition.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.SendInteractions = channel.stream_unary(
|
||||
'/hil_serl.LearnerService/SendInteractions',
|
||||
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.Ready = channel.unary_unary(
|
||||
'/hil_serl.LearnerService/Ready',
|
||||
request_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class LearnerServiceServicer(object):
|
||||
@@ -53,14 +68,32 @@ class LearnerServiceServicer(object):
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
def SendTransition(self, request, context):
|
||||
def SendInteractionMessage(self, request, context):
|
||||
"""Actor -> Learner to store transitions
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendInteractionMessage(self, request, context):
|
||||
def StreamParameters(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendTransitions(self, request_iterator, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendInteractions(self, request_iterator, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Ready(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
@@ -69,16 +102,31 @@ class LearnerServiceServicer(object):
|
||||
|
||||
def add_LearnerServiceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'SendTransition': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendTransition,
|
||||
request_deserializer=hilserl__pb2.Transition.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'SendInteractionMessage': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendInteractionMessage,
|
||||
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'StreamParameters': grpc.unary_stream_rpc_method_handler(
|
||||
servicer.StreamParameters,
|
||||
request_deserializer=hilserl__pb2.Empty.FromString,
|
||||
response_serializer=hilserl__pb2.Parameters.SerializeToString,
|
||||
),
|
||||
'SendTransitions': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.SendTransitions,
|
||||
request_deserializer=hilserl__pb2.Transition.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'SendInteractions': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.SendInteractions,
|
||||
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'Ready': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Ready,
|
||||
request_deserializer=hilserl__pb2.Empty.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'hil_serl.LearnerService', rpc_method_handlers)
|
||||
@@ -92,33 +140,6 @@ class LearnerService(object):
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def SendTransition(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/hil_serl.LearnerService/SendTransition',
|
||||
hilserl__pb2.Transition.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendInteractionMessage(request,
|
||||
target,
|
||||
@@ -146,76 +167,8 @@ class LearnerService(object):
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class ActorServiceStub(object):
|
||||
"""ActorService: the Learner calls this to push parameters.
|
||||
The Actor implements this service.
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.StreamTransition = channel.unary_stream(
|
||||
'/hil_serl.ActorService/StreamTransition',
|
||||
request_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.ActorInformation.FromString,
|
||||
_registered_method=True)
|
||||
self.SendParameters = channel.unary_unary(
|
||||
'/hil_serl.ActorService/SendParameters',
|
||||
request_serializer=hilserl__pb2.Parameters.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class ActorServiceServicer(object):
|
||||
"""ActorService: the Learner calls this to push parameters.
|
||||
The Actor implements this service.
|
||||
"""
|
||||
|
||||
def StreamTransition(self, request, context):
|
||||
"""Learner -> Actor to send new parameters
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendParameters(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_ActorServiceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'StreamTransition': grpc.unary_stream_rpc_method_handler(
|
||||
servicer.StreamTransition,
|
||||
request_deserializer=hilserl__pb2.Empty.FromString,
|
||||
response_serializer=hilserl__pb2.ActorInformation.SerializeToString,
|
||||
),
|
||||
'SendParameters': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendParameters,
|
||||
request_deserializer=hilserl__pb2.Parameters.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'hil_serl.ActorService', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('hil_serl.ActorService', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class ActorService(object):
|
||||
"""ActorService: the Learner calls this to push parameters.
|
||||
The Actor implements this service.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def StreamTransition(request,
|
||||
def StreamParameters(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
@@ -228,9 +181,9 @@ class ActorService(object):
|
||||
return grpc.experimental.unary_stream(
|
||||
request,
|
||||
target,
|
||||
'/hil_serl.ActorService/StreamTransition',
|
||||
'/hil_serl.LearnerService/StreamParameters',
|
||||
hilserl__pb2.Empty.SerializeToString,
|
||||
hilserl__pb2.ActorInformation.FromString,
|
||||
hilserl__pb2.Parameters.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
@@ -242,7 +195,61 @@ class ActorService(object):
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendParameters(request,
|
||||
def SendTransitions(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.stream_unary(
|
||||
request_iterator,
|
||||
target,
|
||||
'/hil_serl.LearnerService/SendTransitions',
|
||||
hilserl__pb2.Transition.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendInteractions(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.stream_unary(
|
||||
request_iterator,
|
||||
target,
|
||||
'/hil_serl.LearnerService/SendInteractions',
|
||||
hilserl__pb2.InteractionMessage.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def Ready(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
@@ -255,8 +262,8 @@ class ActorService(object):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/hil_serl.ActorService/SendParameters',
|
||||
hilserl__pb2.Parameters.SerializeToString,
|
||||
'/hil_serl.LearnerService/Ready',
|
||||
hilserl__pb2.Empty.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
|
||||
543
lerobot/scripts/server/kinematics.py
Normal file
543
lerobot/scripts/server/kinematics.py
Normal file
@@ -0,0 +1,543 @@
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
|
||||
def skew_symmetric(w):
|
||||
"""Creates the skew-symmetric matrix from a 3D vector."""
|
||||
return np.array([[0, -w[2], w[1]], [w[2], 0, -w[0]], [-w[1], w[0], 0]])
|
||||
|
||||
|
||||
def rodrigues_rotation(w, theta):
|
||||
"""Computes the rotation matrix using Rodrigues' formula."""
|
||||
w_hat = skew_symmetric(w)
|
||||
return np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
|
||||
|
||||
|
||||
def screw_axis_to_transform(S, theta):
|
||||
"""Converts a screw axis to a 4x4 transformation matrix."""
|
||||
S_w = S[:3]
|
||||
S_v = S[3:]
|
||||
if np.allclose(S_w, 0) and np.linalg.norm(S_v) == 1: # Pure translation
|
||||
T = np.eye(4)
|
||||
T[:3, 3] = S_v * theta
|
||||
elif np.linalg.norm(S_w) == 1: # Rotation and translation
|
||||
w_hat = skew_symmetric(S_w)
|
||||
R = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
|
||||
t = (
|
||||
np.eye(3) * theta
|
||||
+ (1 - np.cos(theta)) * w_hat
|
||||
+ (theta - np.sin(theta)) * w_hat @ w_hat
|
||||
) @ S_v
|
||||
T = np.eye(4)
|
||||
T[:3, :3] = R
|
||||
T[:3, 3] = t
|
||||
else:
|
||||
raise ValueError("Invalid screw axis parameters")
|
||||
return T
|
||||
|
||||
|
||||
def pose_difference_se3(pose1, pose2):
|
||||
"""
|
||||
Calculates the SE(3) difference between two 4x4 homogeneous transformation matrices.
|
||||
|
||||
pose1 - pose2
|
||||
|
||||
Args:
|
||||
pose1: A 4x4 numpy array representing the first pose.
|
||||
pose2: A 4x4 numpy array representing the second pose.
|
||||
|
||||
Returns:
|
||||
A tuple (translation_diff, rotation_diff) where:
|
||||
- translation_diff is a 3x1 numpy array representing the translational difference.
|
||||
- rotation_diff is a 3x1 numpy array representing the rotational difference in axis-angle representation.
|
||||
"""
|
||||
|
||||
# Extract rotation matrices from poses
|
||||
R1 = pose1[:3, :3]
|
||||
R2 = pose2[:3, :3]
|
||||
|
||||
# Calculate translational difference
|
||||
translation_diff = pose1[:3, 3] - pose2[:3, 3]
|
||||
|
||||
# Calculate rotational difference using scipy's Rotation library
|
||||
R_diff = Rotation.from_matrix(R1 @ R2.T)
|
||||
rotation_diff = R_diff.as_rotvec() # Convert to axis-angle representation
|
||||
|
||||
return np.concatenate([translation_diff, rotation_diff])
|
||||
|
||||
|
||||
def se3_error(target_pose, current_pose):
|
||||
pos_error = target_pose[:3, 3] - current_pose[:3, 3]
|
||||
R_target = target_pose[:3, :3]
|
||||
R_current = current_pose[:3, :3]
|
||||
R_error = R_target @ R_current.T
|
||||
rot_error = Rotation.from_matrix(R_error).as_rotvec()
|
||||
return np.concatenate([pos_error, rot_error])
|
||||
|
||||
|
||||
class RobotKinematics:
|
||||
"""Robot kinematics class supporting multiple robot models."""
|
||||
|
||||
# Robot measurements dictionary
|
||||
ROBOT_MEASUREMENTS = {
|
||||
"koch": {
|
||||
"gripper": [0.239, -0.001, 0.024],
|
||||
"wrist": [0.209, 0, 0.024],
|
||||
"forearm": [0.108, 0, 0.02],
|
||||
"humerus": [0, 0, 0.036],
|
||||
"shoulder": [0, 0, 0],
|
||||
"base": [0, 0, 0.02],
|
||||
},
|
||||
"so100": {
|
||||
"gripper": [0.320, 0, 0.050],
|
||||
"wrist": [0.278, 0, 0.050],
|
||||
"forearm": [0.143, 0, 0.044],
|
||||
"humerus": [0.031, 0, 0.072],
|
||||
"shoulder": [0, 0, 0],
|
||||
"base": [0, 0, 0.02],
|
||||
},
|
||||
"moss": {
|
||||
"gripper": [0.246, 0.013, 0.111],
|
||||
"wrist": [0.245, 0.002, 0.064],
|
||||
"forearm": [0.122, 0, 0.064],
|
||||
"humerus": [0.001, 0.001, 0.063],
|
||||
"shoulder": [0, 0, 0],
|
||||
"base": [0, 0, 0.02],
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, robot_type="so100"):
|
||||
"""Initialize kinematics for the specified robot type.
|
||||
|
||||
Args:
|
||||
robot_type: String specifying the robot model ("koch", "so100", or "moss")
|
||||
"""
|
||||
if robot_type not in self.ROBOT_MEASUREMENTS:
|
||||
raise ValueError(
|
||||
f"Unknown robot type: {robot_type}. Available types: {list(self.ROBOT_MEASUREMENTS.keys())}"
|
||||
)
|
||||
|
||||
self.robot_type = robot_type
|
||||
self.measurements = self.ROBOT_MEASUREMENTS[robot_type]
|
||||
|
||||
# Initialize all transformation matrices and screw axes
|
||||
self._setup_transforms()
|
||||
|
||||
def _create_translation_matrix(self, x=0, y=0, z=0):
|
||||
"""Create a 4x4 translation matrix."""
|
||||
return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]])
|
||||
|
||||
def _setup_transforms(self):
|
||||
"""Setup all transformation matrices and screw axes for the robot."""
|
||||
# Set up rotation matrices (constant across robot types)
|
||||
|
||||
# Gripper orientation
|
||||
self.gripper_X0 = np.array(
|
||||
[
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 1, 0],
|
||||
[0, -1, 0, 0],
|
||||
[0, 0, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
# Wrist orientation
|
||||
self.wrist_X0 = np.array(
|
||||
[
|
||||
[0, -1, 0, 0],
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
# Base orientation
|
||||
self.base_X0 = np.array(
|
||||
[
|
||||
[0, 0, 1, 0],
|
||||
[1, 0, 0, 0],
|
||||
[0, 1, 0, 0],
|
||||
[0, 0, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
# Gripper
|
||||
# Screw axis of gripper frame wrt base frame
|
||||
self.S_BG = np.array(
|
||||
[
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
self.measurements["gripper"][2],
|
||||
-self.measurements["gripper"][1],
|
||||
]
|
||||
)
|
||||
|
||||
# Gripper origin to centroid transform
|
||||
self.X_GoGc = self._create_translation_matrix(x=0.07)
|
||||
|
||||
# Gripper origin to tip transform
|
||||
self.X_GoGt = self._create_translation_matrix(x=0.12)
|
||||
|
||||
# 0-position gripper frame pose wrt base
|
||||
self.X_BoGo = self._create_translation_matrix(
|
||||
x=self.measurements["gripper"][0],
|
||||
y=self.measurements["gripper"][1],
|
||||
z=self.measurements["gripper"][2],
|
||||
)
|
||||
|
||||
# Wrist
|
||||
# Screw axis of wrist frame wrt base frame
|
||||
self.S_BR = np.array(
|
||||
[0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]]
|
||||
)
|
||||
|
||||
# 0-position origin to centroid transform
|
||||
self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002)
|
||||
|
||||
# 0-position wrist frame pose wrt base
|
||||
self.X_BR = self._create_translation_matrix(
|
||||
x=self.measurements["wrist"][0],
|
||||
y=self.measurements["wrist"][1],
|
||||
z=self.measurements["wrist"][2],
|
||||
)
|
||||
|
||||
# Forearm
|
||||
# Screw axis of forearm frame wrt base frame
|
||||
self.S_BF = np.array(
|
||||
[
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
-self.measurements["forearm"][2],
|
||||
0,
|
||||
self.measurements["forearm"][0],
|
||||
]
|
||||
)
|
||||
|
||||
# Forearm origin + centroid transform
|
||||
self.X_FoFc = self._create_translation_matrix(x=0.036)
|
||||
|
||||
# 0-position forearm frame pose wrt base
|
||||
self.X_BF = self._create_translation_matrix(
|
||||
x=self.measurements["forearm"][0],
|
||||
y=self.measurements["forearm"][1],
|
||||
z=self.measurements["forearm"][2],
|
||||
)
|
||||
|
||||
# Humerus
|
||||
# Screw axis of humerus frame wrt base frame
|
||||
self.S_BH = np.array(
|
||||
[
|
||||
0,
|
||||
-1,
|
||||
0,
|
||||
self.measurements["humerus"][2],
|
||||
0,
|
||||
-self.measurements["humerus"][0],
|
||||
]
|
||||
)
|
||||
|
||||
# Humerus origin to centroid transform
|
||||
self.X_HoHc = self._create_translation_matrix(x=0.0475)
|
||||
|
||||
# 0-position humerus frame pose wrt base
|
||||
self.X_BH = self._create_translation_matrix(
|
||||
x=self.measurements["humerus"][0],
|
||||
y=self.measurements["humerus"][1],
|
||||
z=self.measurements["humerus"][2],
|
||||
)
|
||||
|
||||
# Shoulder
|
||||
# Screw axis of shoulder frame wrt Base frame
|
||||
self.S_BS = np.array([0, 0, -1, 0, 0, 0])
|
||||
|
||||
# Shoulder origin to centroid transform
|
||||
self.X_SoSc = self._create_translation_matrix(x=-0.017, z=0.0235)
|
||||
|
||||
# 0-position shoulder frame pose wrt base
|
||||
self.X_BS = self._create_translation_matrix(
|
||||
x=self.measurements["shoulder"][0],
|
||||
y=self.measurements["shoulder"][1],
|
||||
z=self.measurements["shoulder"][2],
|
||||
)
|
||||
|
||||
# Base
|
||||
# Base origin to centroid transform
|
||||
self.X_BoBc = self._create_translation_matrix(y=0.015)
|
||||
|
||||
# World to base transform
|
||||
self.X_WoBo = self._create_translation_matrix(
|
||||
x=self.measurements["base"][0],
|
||||
y=self.measurements["base"][1],
|
||||
z=self.measurements["base"][2],
|
||||
)
|
||||
|
||||
# Pre-compute gripper post-multiplication matrix
|
||||
self._fk_gripper_post = self.X_GoGc @ self.X_BoGo @ self.gripper_X0
|
||||
|
||||
def fk_base(self):
|
||||
"""Forward kinematics for the base frame."""
|
||||
return self.X_WoBo @ self.X_BoBc @ self.base_X0
|
||||
|
||||
def fk_shoulder(self, robot_pos_deg):
|
||||
"""Forward kinematics for the shoulder frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||
@ self.X_SoSc
|
||||
@ self.X_BS
|
||||
)
|
||||
|
||||
def fk_humerus(self, robot_pos_deg):
|
||||
"""Forward kinematics for the humerus frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||
@ self.X_HoHc
|
||||
@ self.X_BH
|
||||
)
|
||||
|
||||
def fk_forearm(self, robot_pos_deg):
|
||||
"""Forward kinematics for the forearm frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
|
||||
@ self.X_FoFc
|
||||
@ self.X_BF
|
||||
)
|
||||
|
||||
def fk_wrist(self, robot_pos_deg):
|
||||
"""Forward kinematics for the wrist frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
|
||||
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
|
||||
@ self.X_RoRc
|
||||
@ self.X_BR
|
||||
@ self.wrist_X0
|
||||
)
|
||||
|
||||
def fk_gripper(self, robot_pos_deg):
|
||||
"""Forward kinematics for the gripper frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
|
||||
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
|
||||
@ screw_axis_to_transform(self.S_BG, robot_pos_rad[4])
|
||||
@ self._fk_gripper_post
|
||||
)
|
||||
|
||||
def fk_gripper_tip(self, robot_pos_deg):
|
||||
"""Forward kinematics for the gripper tip frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
|
||||
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
|
||||
@ screw_axis_to_transform(self.S_BG, robot_pos_rad[4])
|
||||
@ self.X_GoGt
|
||||
@ self.X_BoGo
|
||||
@ self.gripper_X0
|
||||
)
|
||||
|
||||
def compute_jacobian(self, robot_pos_deg, fk_func=None):
|
||||
"""Finite differences to compute the Jacobian.
|
||||
J(i, j) represents how the ith component of the end-effector's velocity changes wrt a small change
|
||||
in the jth joint's velocity.
|
||||
|
||||
Args:
|
||||
robot_pos_deg: Current joint positions in degrees
|
||||
fk_func: Forward kinematics function to use (defaults to fk_gripper)
|
||||
"""
|
||||
if fk_func is None:
|
||||
fk_func = self.fk_gripper
|
||||
|
||||
eps = 1e-8
|
||||
jac = np.zeros(shape=(6, 5))
|
||||
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
|
||||
for el_ix in range(len(robot_pos_deg[:-1])):
|
||||
delta *= 0
|
||||
delta[el_ix] = eps / 2
|
||||
Sdot = (
|
||||
pose_difference_se3(
|
||||
fk_func(robot_pos_deg[:-1] + delta),
|
||||
fk_func(robot_pos_deg[:-1] - delta),
|
||||
)
|
||||
/ eps
|
||||
)
|
||||
jac[:, el_ix] = Sdot
|
||||
return jac
|
||||
|
||||
def compute_positional_jacobian(self, robot_pos_deg, fk_func=None):
|
||||
"""Finite differences to compute the positional Jacobian.
|
||||
J(i, j) represents how the ith component of the end-effector's position changes wrt a small change
|
||||
in the jth joint's velocity.
|
||||
|
||||
Args:
|
||||
robot_pos_deg: Current joint positions in degrees
|
||||
fk_func: Forward kinematics function to use (defaults to fk_gripper)
|
||||
"""
|
||||
if fk_func is None:
|
||||
fk_func = self.fk_gripper
|
||||
|
||||
eps = 1e-8
|
||||
jac = np.zeros(shape=(3, 5))
|
||||
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
|
||||
for el_ix in range(len(robot_pos_deg[:-1])):
|
||||
delta *= 0
|
||||
delta[el_ix] = eps / 2
|
||||
Sdot = (
|
||||
fk_func(robot_pos_deg[:-1] + delta)[:3, 3]
|
||||
- fk_func(robot_pos_deg[:-1] - delta)[:3, 3]
|
||||
) / eps
|
||||
jac[:, el_ix] = Sdot
|
||||
return jac
|
||||
|
||||
def ik(
|
||||
self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None
|
||||
):
|
||||
"""Inverse kinematics using gradient descent.
|
||||
|
||||
Args:
|
||||
current_joint_state: Initial joint positions in degrees
|
||||
desired_ee_pose: Target end-effector pose as a 4x4 transformation matrix
|
||||
position_only: If True, only match end-effector position, not orientation
|
||||
fk_func: Forward kinematics function to use (defaults to fk_gripper)
|
||||
|
||||
Returns:
|
||||
Joint positions in degrees that achieve the desired end-effector pose
|
||||
"""
|
||||
if fk_func is None:
|
||||
fk_func = self.fk_gripper
|
||||
|
||||
# Do gradient descent.
|
||||
max_iterations = 5
|
||||
learning_rate = 1
|
||||
for _ in range(max_iterations):
|
||||
current_ee_pose = fk_func(current_joint_state)
|
||||
if not position_only:
|
||||
error = se3_error(desired_ee_pose, current_ee_pose)
|
||||
jac = self.compute_jacobian(current_joint_state, fk_func)
|
||||
else:
|
||||
error = desired_ee_pose[:3, 3] - current_ee_pose[:3, 3]
|
||||
jac = self.compute_positional_jacobian(current_joint_state, fk_func)
|
||||
delta_angles = np.linalg.pinv(jac) @ error
|
||||
current_joint_state[:-1] += learning_rate * delta_angles
|
||||
|
||||
if np.linalg.norm(error) < 5e-3:
|
||||
return current_joint_state
|
||||
return current_joint_state
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
|
||||
def run_test(robot_type):
|
||||
"""Run test suite for a specific robot type."""
|
||||
print(f"\n--- Testing {robot_type.upper()} Robot ---")
|
||||
|
||||
# Initialize kinematics for this robot
|
||||
robot = RobotKinematics(robot_type)
|
||||
|
||||
# Test 1: Forward kinematics consistency
|
||||
print("Test 1: Forward kinematics consistency")
|
||||
test_angles = np.array(
|
||||
[30, 45, -30, 20, 10, 0]
|
||||
) # Example joint angles in degrees
|
||||
|
||||
# Calculate FK for different joints
|
||||
shoulder_pose = robot.fk_shoulder(test_angles)
|
||||
humerus_pose = robot.fk_humerus(test_angles)
|
||||
forearm_pose = robot.fk_forearm(test_angles)
|
||||
wrist_pose = robot.fk_wrist(test_angles)
|
||||
gripper_pose = robot.fk_gripper(test_angles)
|
||||
gripper_tip_pose = robot.fk_gripper_tip(test_angles)
|
||||
|
||||
# Check that poses form a consistent kinematic chain (positions should be progressively further from origin)
|
||||
distances = [
|
||||
np.linalg.norm(shoulder_pose[:3, 3]),
|
||||
np.linalg.norm(humerus_pose[:3, 3]),
|
||||
np.linalg.norm(forearm_pose[:3, 3]),
|
||||
np.linalg.norm(wrist_pose[:3, 3]),
|
||||
np.linalg.norm(gripper_pose[:3, 3]),
|
||||
np.linalg.norm(gripper_tip_pose[:3, 3]),
|
||||
]
|
||||
|
||||
# Check if distances generally increase along the chain
|
||||
is_consistent = all(
|
||||
distances[i] <= distances[i + 1] for i in range(len(distances) - 1)
|
||||
)
|
||||
print(f" Pose distances from origin: {[round(d, 3) for d in distances]}")
|
||||
print(
|
||||
f" Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}"
|
||||
)
|
||||
|
||||
# Test 2: Jacobian computation
|
||||
print("Test 2: Jacobian computation")
|
||||
jacobian = robot.compute_jacobian(test_angles)
|
||||
positional_jacobian = robot.compute_positional_jacobian(test_angles)
|
||||
|
||||
# Check shapes
|
||||
jacobian_shape_ok = jacobian.shape == (6, 5)
|
||||
pos_jacobian_shape_ok = positional_jacobian.shape == (3, 5)
|
||||
|
||||
print(f" Jacobian shape: {'PASSED' if jacobian_shape_ok else 'FAILED'}")
|
||||
print(
|
||||
f" Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}"
|
||||
)
|
||||
|
||||
# Test 3: Inverse kinematics
|
||||
print("Test 3: Inverse kinematics (position only)")
|
||||
|
||||
# Generate target pose from known joint angles
|
||||
original_angles = np.array([10, 20, 30, -10, 5, 0])
|
||||
target_pose = robot.fk_gripper(original_angles)
|
||||
|
||||
# Start IK from a different position
|
||||
initial_guess = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
|
||||
|
||||
# Measure IK performance
|
||||
start_time = time.time()
|
||||
computed_angles = robot.ik(initial_guess.copy(), target_pose)
|
||||
ik_time = time.time() - start_time
|
||||
|
||||
# Compute resulting pose from IK solution
|
||||
result_pose = robot.fk_gripper(computed_angles)
|
||||
|
||||
# Calculate position error
|
||||
pos_error = np.linalg.norm(target_pose[:3, 3] - result_pose[:3, 3])
|
||||
passed = pos_error < 0.01 # Accept errors less than 1cm
|
||||
|
||||
print(f" IK computation time: {ik_time:.4f} seconds")
|
||||
print(f" Position error: {pos_error:.4f}")
|
||||
print(f" IK position accuracy: {'PASSED' if passed else 'FAILED'}")
|
||||
|
||||
return is_consistent and jacobian_shape_ok and pos_jacobian_shape_ok and passed
|
||||
|
||||
# Run tests for all robot types
|
||||
results = {}
|
||||
for robot_type in ["koch", "so100", "moss"]:
|
||||
results[robot_type] = run_test(robot_type)
|
||||
|
||||
# Print overall summary
|
||||
print("\n=== Test Summary ===")
|
||||
all_passed = all(results.values())
|
||||
for robot_type, passed in results.items():
|
||||
print(f"{robot_type.upper()}: {'PASSED' if passed else 'FAILED'}")
|
||||
print(f"\nOverall: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}")
|
||||
@@ -14,19 +14,22 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
import logging
|
||||
import pickle
|
||||
import queue
|
||||
import shutil
|
||||
import time
|
||||
from pprint import pformat
|
||||
from threading import Lock, Thread
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
# from torch.multiprocessing import Event, Queue, Process
|
||||
# from threading import Event, Thread
|
||||
# from torch.multiprocessing import Queue, Event
|
||||
from torch.multiprocessing import Queue
|
||||
|
||||
from lerobot.scripts.server.utils import setup_process_handlers
|
||||
|
||||
import grpc
|
||||
|
||||
# Import generated stubs
|
||||
import hilserl_pb2 # type: ignore
|
||||
import hilserl_pb2_grpc # type: ignore
|
||||
import hydra
|
||||
import torch
|
||||
@@ -52,17 +55,18 @@ from lerobot.common.utils.utils import (
|
||||
set_global_random_state,
|
||||
set_global_seed,
|
||||
)
|
||||
|
||||
from lerobot.scripts.server.buffer import (
|
||||
ReplayBuffer,
|
||||
concatenate_batch_transitions,
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
move_state_dict_to_device,
|
||||
bytes_to_transitions,
|
||||
state_to_bytes,
|
||||
bytes_to_python_object,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
transition_queue = queue.Queue()
|
||||
interaction_message_queue = queue.Queue()
|
||||
from lerobot.scripts.server import learner_service
|
||||
|
||||
|
||||
def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
|
||||
@@ -77,9 +81,13 @@ def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
|
||||
# if resume == True
|
||||
checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
|
||||
if not checkpoint_dir.exists():
|
||||
raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
|
||||
raise RuntimeError(
|
||||
f"No model checkpoint found in {checkpoint_dir} for resume=True"
|
||||
)
|
||||
|
||||
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
|
||||
checkpoint_cfg_path = str(
|
||||
Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
|
||||
)
|
||||
logging.info(
|
||||
colored(
|
||||
"Resume=True detected, resuming previous run",
|
||||
@@ -112,7 +120,9 @@ def load_training_state(
|
||||
if not cfg.resume:
|
||||
return None, None
|
||||
|
||||
training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
|
||||
training_state = torch.load(
|
||||
logger.last_checkpoint_dir / logger.training_state_file_name, weights_only=False
|
||||
)
|
||||
|
||||
if isinstance(training_state["optimizer"], dict):
|
||||
assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
|
||||
@@ -126,7 +136,9 @@ def load_training_state(
|
||||
|
||||
|
||||
def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_learnable_params = sum(
|
||||
p.numel() for p in policy.parameters() if p.requires_grad
|
||||
)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
log_output_dir(out_dir)
|
||||
@@ -136,177 +148,266 @@ def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
|
||||
def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> ReplayBuffer:
|
||||
def initialize_replay_buffer(
|
||||
cfg: DictConfig, logger: Logger, device: str, storage_device: str
|
||||
) -> ReplayBuffer:
|
||||
if not cfg.resume:
|
||||
return ReplayBuffer(
|
||||
capacity=cfg.training.online_buffer_capacity,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
storage_device=device
|
||||
storage_device=storage_device,
|
||||
optimize_memory=True,
|
||||
)
|
||||
|
||||
logging.info("Resume training load the online dataset")
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
|
||||
repo_id=cfg.dataset_repo_id,
|
||||
local_files_only=True,
|
||||
root=logger.log_dir / "dataset",
|
||||
)
|
||||
return ReplayBuffer.from_lerobot_dataset(
|
||||
lerobot_dataset=dataset,
|
||||
capacity=cfg.training.online_buffer_capacity,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
optimize_memory=True,
|
||||
)
|
||||
|
||||
|
||||
def initialize_offline_replay_buffer(
|
||||
cfg: DictConfig,
|
||||
logger: Logger,
|
||||
device: str,
|
||||
storage_device: str,
|
||||
active_action_dims: list[int] | None = None,
|
||||
) -> ReplayBuffer:
|
||||
if not cfg.resume:
|
||||
logging.info("make_dataset offline buffer")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
if cfg.resume:
|
||||
logging.info("load offline dataset")
|
||||
offline_dataset = LeRobotDataset(
|
||||
repo_id=cfg.dataset_repo_id,
|
||||
local_files_only=True,
|
||||
root=logger.log_dir / "dataset_offline",
|
||||
)
|
||||
|
||||
logging.info("Convert to a offline replay buffer")
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
offline_dataset,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
action_mask=active_action_dims,
|
||||
action_delta=cfg.env.wrapper.delta_action,
|
||||
storage_device=storage_device,
|
||||
optimize_memory=True,
|
||||
capacity=cfg.training.offline_buffer_capacity,
|
||||
)
|
||||
return offline_replay_buffer
|
||||
|
||||
|
||||
def get_observation_features(
|
||||
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
||||
if (
|
||||
policy.config.vision_encoder_name is None
|
||||
or not policy.config.freeze_vision_encoder
|
||||
):
|
||||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = (
|
||||
policy.actor.encoder(observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
next_observation_features = (
|
||||
policy.actor.encoder(next_observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
|
||||
return observation_features, next_observation_features
|
||||
|
||||
|
||||
def use_threads(cfg: DictConfig) -> bool:
|
||||
return cfg.actor_learner_config.concurrency.learner == "threads"
|
||||
|
||||
|
||||
def start_learner_threads(
|
||||
cfg: DictConfig,
|
||||
device: str,
|
||||
replay_buffer: ReplayBuffer,
|
||||
offline_replay_buffer: ReplayBuffer,
|
||||
batch_size: int,
|
||||
optimizers: dict,
|
||||
policy: SACPolicy,
|
||||
policy_lock: Lock,
|
||||
logger: Logger,
|
||||
resume_optimization_step: int | None = None,
|
||||
resume_interaction_step: int | None = None,
|
||||
out_dir: str,
|
||||
shutdown_event: any, # Event,
|
||||
) -> None:
|
||||
actor_ip = cfg.actor_learner_config.actor_ip
|
||||
port = cfg.actor_learner_config.port
|
||||
# Create multiprocessing queues
|
||||
transition_queue = Queue()
|
||||
interaction_message_queue = Queue()
|
||||
parameters_queue = Queue()
|
||||
|
||||
server_thread = Thread(
|
||||
target=stream_transitions_from_actor,
|
||||
args=(
|
||||
actor_ip,
|
||||
port,
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
concurrency_entity = None
|
||||
|
||||
transition_thread = Thread(
|
||||
target=add_actor_information_and_train,
|
||||
daemon=True,
|
||||
if use_threads(cfg):
|
||||
from threading import Thread
|
||||
|
||||
concurrency_entity = Thread
|
||||
else:
|
||||
from torch.multiprocessing import Process
|
||||
|
||||
concurrency_entity = Process
|
||||
|
||||
communication_process = concurrency_entity(
|
||||
target=start_learner_server,
|
||||
args=(
|
||||
parameters_queue,
|
||||
transition_queue,
|
||||
interaction_message_queue,
|
||||
shutdown_event,
|
||||
cfg,
|
||||
device,
|
||||
replay_buffer,
|
||||
offline_replay_buffer,
|
||||
batch_size,
|
||||
optimizers,
|
||||
policy,
|
||||
policy_lock,
|
||||
logger,
|
||||
resume_optimization_step,
|
||||
resume_interaction_step,
|
||||
),
|
||||
)
|
||||
|
||||
param_push_thread = Thread(
|
||||
target=learner_push_parameters,
|
||||
args=(policy, policy_lock, actor_ip, port, 15),
|
||||
daemon=True,
|
||||
)
|
||||
communication_process.start()
|
||||
|
||||
server_thread.start()
|
||||
transition_thread.start()
|
||||
param_push_thread.start()
|
||||
param_push_thread.join()
|
||||
transition_thread.join()
|
||||
server_thread.join()
|
||||
add_actor_information_and_train(
|
||||
cfg,
|
||||
logger,
|
||||
out_dir,
|
||||
shutdown_event,
|
||||
transition_queue,
|
||||
interaction_message_queue,
|
||||
parameters_queue,
|
||||
)
|
||||
logging.info("[LEARNER] Training process stopped")
|
||||
|
||||
logging.info("[LEARNER] Closing queues")
|
||||
transition_queue.close()
|
||||
interaction_message_queue.close()
|
||||
parameters_queue.close()
|
||||
|
||||
communication_process.join()
|
||||
logging.info("[LEARNER] Communication process joined")
|
||||
|
||||
logging.info("[LEARNER] join queues")
|
||||
transition_queue.cancel_join_thread()
|
||||
interaction_message_queue.cancel_join_thread()
|
||||
parameters_queue.cancel_join_thread()
|
||||
|
||||
logging.info("[LEARNER] queues closed")
|
||||
|
||||
|
||||
def stream_transitions_from_actor(host="127.0.0.1", port=50051):
|
||||
def start_learner_server(
|
||||
parameters_queue: Queue,
|
||||
transition_queue: Queue,
|
||||
interaction_message_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
cfg: DictConfig,
|
||||
):
|
||||
if not use_threads(cfg):
|
||||
# We need init logging for MP separataly
|
||||
init_logging()
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
# Return back for MP
|
||||
setup_process_handlers(False)
|
||||
|
||||
service = learner_service.LearnerService(
|
||||
shutdown_event,
|
||||
parameters_queue,
|
||||
cfg.actor_learner_config.policy_parameters_push_frequency,
|
||||
transition_queue,
|
||||
interaction_message_queue,
|
||||
)
|
||||
|
||||
server = grpc.server(
|
||||
ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
|
||||
options=[
|
||||
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||
],
|
||||
)
|
||||
|
||||
hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(
|
||||
service,
|
||||
server,
|
||||
)
|
||||
|
||||
host = cfg.actor_learner_config.learner_host
|
||||
port = cfg.actor_learner_config.learner_port
|
||||
|
||||
server.add_insecure_port(f"{host}:{port}")
|
||||
server.start()
|
||||
logging.info("[LEARNER] gRPC server started")
|
||||
|
||||
shutdown_event.wait()
|
||||
logging.info("[LEARNER] Stopping gRPC server...")
|
||||
server.stop(learner_service.STUTDOWN_TIMEOUT)
|
||||
logging.info("[LEARNER] gRPC server stopped")
|
||||
|
||||
|
||||
def check_nan_in_transition(
|
||||
observations: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
next_state: torch.Tensor,
|
||||
raise_error: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Runs a gRPC client that listens for transition and interaction messages from an Actor service.
|
||||
|
||||
This function establishes a gRPC connection with the given `host` and `port`, then continuously
|
||||
streams transition data from the `ActorServiceStub`. The received transition data is deserialized
|
||||
and stored in a queue (`transition_queue`). Similarly, interaction messages are also deserialized
|
||||
and stored in a separate queue (`interaction_message_queue`).
|
||||
Check for NaN values in transition data.
|
||||
|
||||
Args:
|
||||
host (str, optional): The IP address or hostname of the gRPC server. Defaults to `"127.0.0.1"`.
|
||||
port (int, optional): The port number on which the gRPC server is running. Defaults to `50051`.
|
||||
observations: Dictionary of observation tensors
|
||||
actions: Action tensor
|
||||
next_state: Dictionary of next state tensors
|
||||
raise_error: If True, raises ValueError when NaN is detected
|
||||
|
||||
Returns:
|
||||
bool: True if NaN values were detected, False otherwise
|
||||
"""
|
||||
# NOTE: This is waiting for the handshake to be done
|
||||
# In the future we will do it in a canonical way with a proper handshake
|
||||
time.sleep(10)
|
||||
channel = grpc.insecure_channel(
|
||||
f"{host}:{port}",
|
||||
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
|
||||
)
|
||||
stub = hilserl_pb2_grpc.ActorServiceStub(channel)
|
||||
for response in stub.StreamTransition(hilserl_pb2.Empty()):
|
||||
if response.HasField("transition"):
|
||||
buffer = io.BytesIO(response.transition.transition_bytes)
|
||||
transition = torch.load(buffer)
|
||||
transition_queue.put(transition)
|
||||
if response.HasField("interaction_message"):
|
||||
content = pickle.loads(response.interaction_message.interaction_message_bytes)
|
||||
interaction_message_queue.put(content)
|
||||
nan_detected = False
|
||||
|
||||
# Check observations
|
||||
for key, tensor in observations.items():
|
||||
if torch.isnan(tensor).any():
|
||||
logging.error(f"observations[{key}] contains NaN values")
|
||||
nan_detected = True
|
||||
if raise_error:
|
||||
raise ValueError(f"NaN detected in observations[{key}]")
|
||||
|
||||
def learner_push_parameters(
|
||||
policy: nn.Module, policy_lock: Lock, actor_host="127.0.0.1", actor_port=50052, seconds_between_pushes=5
|
||||
):
|
||||
"""
|
||||
As a client, connect to the Actor's gRPC server (ActorService)
|
||||
and periodically push new parameters.
|
||||
"""
|
||||
time.sleep(10)
|
||||
channel = grpc.insecure_channel(
|
||||
f"{actor_host}:{actor_port}",
|
||||
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
|
||||
)
|
||||
actor_stub = hilserl_pb2_grpc.ActorServiceStub(channel)
|
||||
# Check next state
|
||||
for key, tensor in next_state.items():
|
||||
if torch.isnan(tensor).any():
|
||||
logging.error(f"next_state[{key}] contains NaN values")
|
||||
nan_detected = True
|
||||
if raise_error:
|
||||
raise ValueError(f"NaN detected in next_state[{key}]")
|
||||
|
||||
while True:
|
||||
with policy_lock:
|
||||
params_dict = policy.actor.state_dict()
|
||||
if policy.config.vision_encoder_name is not None:
|
||||
if policy.config.freeze_vision_encoder:
|
||||
params_dict: dict[str, torch.Tensor] = {
|
||||
k: v for k, v in params_dict.items() if not k.startswith("encoder.")
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
|
||||
)
|
||||
|
||||
params_dict = move_state_dict_to_device(params_dict, device="cpu")
|
||||
# Serialize
|
||||
buf = io.BytesIO()
|
||||
torch.save(params_dict, buf)
|
||||
params_bytes = buf.getvalue()
|
||||
|
||||
# Push them to the Actor's "SendParameters" method
|
||||
logging.info("[LEARNER] Publishing parameters to the Actor")
|
||||
response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes)) # noqa: F841
|
||||
time.sleep(seconds_between_pushes)
|
||||
|
||||
|
||||
def check_nan_in_transition(observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor):
|
||||
for k in observations:
|
||||
if torch.isnan(observations[k]).any():
|
||||
logging.error(f"observations[{k}] contains NaN values")
|
||||
for k in next_state:
|
||||
if torch.isnan(next_state[k]).any():
|
||||
logging.error(f"next_state[{k}] contains NaN values")
|
||||
# Check actions
|
||||
if torch.isnan(actions).any():
|
||||
logging.error("actions contains NaN values")
|
||||
nan_detected = True
|
||||
if raise_error:
|
||||
raise ValueError("NaN detected in actions")
|
||||
|
||||
return nan_detected
|
||||
|
||||
|
||||
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
|
||||
logging.debug("[LEARNER] Pushing actor policy to the queue")
|
||||
state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu")
|
||||
state_bytes = state_to_bytes(state_dict)
|
||||
parameters_queue.put(state_bytes)
|
||||
|
||||
|
||||
def add_actor_information_and_train(
|
||||
cfg,
|
||||
device: str,
|
||||
replay_buffer: ReplayBuffer,
|
||||
offline_replay_buffer: ReplayBuffer,
|
||||
batch_size: int,
|
||||
optimizers: dict[str, torch.optim.Optimizer],
|
||||
policy: nn.Module,
|
||||
policy_lock: Lock,
|
||||
logger: Logger,
|
||||
resume_optimization_step: int | None = None,
|
||||
resume_interaction_step: int | None = None,
|
||||
out_dir: str,
|
||||
shutdown_event: any, # Event,
|
||||
transition_queue: Queue,
|
||||
interaction_message_queue: Queue,
|
||||
parameters_queue: Queue,
|
||||
):
|
||||
"""
|
||||
Handles data transfer from the actor to the learner, manages training updates,
|
||||
@@ -329,52 +430,147 @@ def add_actor_information_and_train(
|
||||
Args:
|
||||
cfg: Configuration object containing hyperparameters.
|
||||
device (str): The computing device (`"cpu"` or `"cuda"`).
|
||||
replay_buffer (ReplayBuffer): The primary replay buffer storing online transitions.
|
||||
offline_replay_buffer (ReplayBuffer): An additional buffer for offline transitions.
|
||||
batch_size (int): The number of transitions to sample per training step.
|
||||
optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`).
|
||||
policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters.
|
||||
policy_lock (Lock): A threading lock to ensure safe policy updates.
|
||||
logger (Logger): Logger instance for tracking training progress.
|
||||
resume_optimization_step (int | None): In the case of resume training, start from the last optimization step reached.
|
||||
resume_interaction_step (int | None): In the case of resume training, shift the interaction step with the last saved step in order to not break logging.
|
||||
out_dir (str): The output directory for storing training checkpoints and logs.
|
||||
shutdown_event (Event): Event to signal shutdown.
|
||||
transition_queue (Queue): Queue for receiving transitions from the actor.
|
||||
interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
|
||||
parameters_queue (Queue): Queue for sending policy parameters to the actor.
|
||||
"""
|
||||
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)
|
||||
|
||||
logging.info("Initializing policy")
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy intance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
# TODO: At some point we should just need make sac policy
|
||||
|
||||
policy: SACPolicy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
# Hack: But if we do online traning, we do not need dataset_stats
|
||||
dataset_stats=None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
|
||||
if cfg.resume
|
||||
else None,
|
||||
)
|
||||
|
||||
# Update the policy config with the grad_clip_norm value from training config if it exists
|
||||
clip_grad_norm_value = cfg.training.grad_clip_norm
|
||||
|
||||
# compile policy
|
||||
policy = torch.compile(policy)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
push_actor_policy_to_queue(parameters_queue, policy)
|
||||
|
||||
last_time_policy_pushed = time.time()
|
||||
|
||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
|
||||
resume_optimization_step, resume_interaction_step = load_training_state(
|
||||
cfg, logger, optimizers
|
||||
)
|
||||
|
||||
log_training_info(cfg, out_dir, policy)
|
||||
|
||||
replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
|
||||
batch_size = cfg.training.batch_size
|
||||
offline_replay_buffer = None
|
||||
|
||||
if cfg.dataset_repo_id is not None:
|
||||
active_action_dims = None
|
||||
if cfg.env.wrapper.joint_masking_action_space is not None:
|
||||
active_action_dims = [
|
||||
i
|
||||
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
|
||||
if mask
|
||||
]
|
||||
offline_replay_buffer = initialize_offline_replay_buffer(
|
||||
cfg=cfg,
|
||||
logger=logger,
|
||||
device=device,
|
||||
storage_device=storage_device,
|
||||
active_action_dims=active_action_dims,
|
||||
)
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
|
||||
# NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
|
||||
# in the future. The reason why we did that is the GIL in Python. It's super slow the performance
|
||||
# are divided by 200. So we need to have a single thread that does all the work.
|
||||
time.time()
|
||||
logging.info("Starting learner thread")
|
||||
interaction_message, transition = None, None
|
||||
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
|
||||
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
|
||||
optimization_step = (
|
||||
resume_optimization_step if resume_optimization_step is not None else 0
|
||||
)
|
||||
interaction_step_shift = (
|
||||
resume_interaction_step if resume_interaction_step is not None else 0
|
||||
)
|
||||
|
||||
# Extract variables from cfg
|
||||
online_step_before_learning = cfg.training.online_step_before_learning
|
||||
utd_ratio = cfg.policy.utd_ratio
|
||||
dataset_repo_id = cfg.dataset_repo_id
|
||||
fps = cfg.fps
|
||||
log_freq = cfg.training.log_freq
|
||||
save_freq = cfg.training.save_freq
|
||||
device = cfg.device
|
||||
storage_device = cfg.training.storage_device
|
||||
policy_update_freq = cfg.training.policy_update_freq
|
||||
policy_parameters_push_frequency = (
|
||||
cfg.actor_learner_config.policy_parameters_push_frequency
|
||||
)
|
||||
save_checkpoint = cfg.training.save_checkpoint
|
||||
online_steps = cfg.training.online_steps
|
||||
|
||||
while True:
|
||||
while not transition_queue.empty():
|
||||
if shutdown_event is not None and shutdown_event.is_set():
|
||||
logging.info("[LEARNER] Shutdown signal received. Exiting...")
|
||||
break
|
||||
|
||||
logging.debug("[LEARNER] Waiting for transitions")
|
||||
while not transition_queue.empty() and not shutdown_event.is_set():
|
||||
transition_list = transition_queue.get()
|
||||
transition_list = bytes_to_transitions(transition_list)
|
||||
|
||||
for transition in transition_list:
|
||||
transition = move_transition_to_device(transition, device=device)
|
||||
if check_nan_in_transition(
|
||||
transition["state"], transition["action"], transition["next_state"]
|
||||
):
|
||||
logging.warning("NaN detected in transition, skipping")
|
||||
continue
|
||||
replay_buffer.add(**transition)
|
||||
|
||||
if transition.get("complementary_info", {}).get("is_intervention"):
|
||||
if cfg.dataset_repo_id is not None and transition.get(
|
||||
"complementary_info", {}
|
||||
).get("is_intervention"):
|
||||
offline_replay_buffer.add(**transition)
|
||||
|
||||
while not interaction_message_queue.empty():
|
||||
logging.debug("[LEARNER] Received transitions")
|
||||
logging.debug("[LEARNER] Waiting for interactions")
|
||||
while not interaction_message_queue.empty() and not shutdown_event.is_set():
|
||||
interaction_message = interaction_message_queue.get()
|
||||
interaction_message = bytes_to_python_object(interaction_message)
|
||||
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
|
||||
interaction_message["Interaction step"] += interaction_step_shift
|
||||
logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
|
||||
# logging.info(f"Interaction message: {interaction_message}")
|
||||
logger.log_dict(
|
||||
interaction_message, mode="train", custom_step_key="Interaction step"
|
||||
)
|
||||
|
||||
if len(replay_buffer) < cfg.training.online_step_before_learning:
|
||||
logging.debug("[LEARNER] Received interactions")
|
||||
|
||||
if len(replay_buffer) < online_step_before_learning:
|
||||
continue
|
||||
|
||||
# logging.info(f"Size of replay buffer: {len(replay_buffer)}")
|
||||
# logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
|
||||
|
||||
logging.debug("[LEARNER] Starting optimization loop")
|
||||
time_for_one_optimization_step = time.time()
|
||||
for _ in range(cfg.policy.utd_ratio - 1):
|
||||
for _ in range(utd_ratio - 1):
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
|
||||
if cfg.dataset_repo_id is not None:
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
batch = concatenate_batch_transitions(batch, batch_offline)
|
||||
|
||||
@@ -383,23 +579,35 @@ def add_actor_information_and_train(
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
check_nan_in_transition(
|
||||
observations=observations, actions=actions, next_state=next_observations
|
||||
)
|
||||
|
||||
with policy_lock:
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
optimizers["critic"].step()
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy, observations, next_observations
|
||||
)
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
|
||||
# clip gradients
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.critic_ensemble.parameters(), clip_grad_norm_value
|
||||
)
|
||||
|
||||
optimizers["critic"].step()
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
|
||||
if cfg.dataset_repo_id is not None:
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
@@ -411,51 +619,101 @@ def add_actor_information_and_train(
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
check_nan_in_transition(
|
||||
observations=observations, actions=actions, next_state=next_observations
|
||||
)
|
||||
|
||||
with policy_lock:
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
optimizers["critic"].step()
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy, observations, next_observations
|
||||
)
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
|
||||
# clip gradients
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.critic_ensemble.parameters(), clip_grad_norm_value
|
||||
).item()
|
||||
|
||||
optimizers["critic"].step()
|
||||
|
||||
training_infos = {}
|
||||
training_infos["loss_critic"] = loss_critic.item()
|
||||
training_infos["critic_grad_norm"] = critic_grad_norm
|
||||
|
||||
if optimization_step % cfg.training.policy_update_freq == 0:
|
||||
for _ in range(cfg.training.policy_update_freq):
|
||||
with policy_lock:
|
||||
loss_actor = policy.compute_loss_actor(observations=observations)
|
||||
if optimization_step % policy_update_freq == 0:
|
||||
for _ in range(policy_update_freq):
|
||||
loss_actor = policy.compute_loss_actor(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
optimizers["actor"].step()
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
|
||||
training_infos["loss_actor"] = loss_actor.item()
|
||||
# clip gradients
|
||||
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.actor.parameters_to_optimize, clip_grad_norm_value
|
||||
).item()
|
||||
|
||||
loss_temperature = policy.compute_loss_temperature(observations=observations)
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
optimizers["temperature"].step()
|
||||
optimizers["actor"].step()
|
||||
|
||||
training_infos["loss_temperature"] = loss_temperature.item()
|
||||
training_infos["loss_actor"] = loss_actor.item()
|
||||
training_infos["actor_grad_norm"] = actor_grad_norm
|
||||
|
||||
# Temperature optimization
|
||||
loss_temperature = policy.compute_loss_temperature(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
|
||||
# clip gradients
|
||||
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
[policy.log_alpha], clip_grad_norm_value
|
||||
).item()
|
||||
|
||||
optimizers["temperature"].step()
|
||||
|
||||
training_infos["loss_temperature"] = loss_temperature.item()
|
||||
training_infos["temperature_grad_norm"] = temp_grad_norm
|
||||
training_infos["temperature"] = policy.temperature
|
||||
|
||||
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
|
||||
push_actor_policy_to_queue(parameters_queue, policy)
|
||||
last_time_policy_pushed = time.time()
|
||||
|
||||
policy.update_target_networks()
|
||||
if optimization_step % cfg.training.log_freq == 0:
|
||||
|
||||
if optimization_step % log_freq == 0:
|
||||
training_infos["replay_buffer_size"] = len(replay_buffer)
|
||||
if offline_replay_buffer is not None:
|
||||
training_infos["offline_replay_buffer_size"] = len(
|
||||
offline_replay_buffer
|
||||
)
|
||||
training_infos["Optimization step"] = optimization_step
|
||||
logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
|
||||
logger.log_dict(
|
||||
d=training_infos, mode="train", custom_step_key="Optimization step"
|
||||
)
|
||||
# logging.info(f"Training infos: {training_infos}")
|
||||
|
||||
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
|
||||
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
|
||||
frequency_for_one_optimization_step = 1 / (
|
||||
time_for_one_optimization_step + 1e-9
|
||||
)
|
||||
|
||||
logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
|
||||
logging.info(
|
||||
f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}"
|
||||
)
|
||||
|
||||
logger.log_dict(
|
||||
{
|
||||
@@ -467,19 +725,19 @@ def add_actor_information_and_train(
|
||||
)
|
||||
|
||||
optimization_step += 1
|
||||
if optimization_step % cfg.training.log_freq == 0:
|
||||
if optimization_step % log_freq == 0:
|
||||
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
|
||||
|
||||
if cfg.training.save_checkpoint and (
|
||||
optimization_step % cfg.training.save_freq == 0 or optimization_step == cfg.training.online_steps
|
||||
if save_checkpoint and (
|
||||
optimization_step % save_freq == 0 or optimization_step == online_steps
|
||||
):
|
||||
logging.info(f"Checkpoint policy after step {optimization_step}")
|
||||
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
|
||||
# needed (choose 6 as a minimum for consistency without being overkill).
|
||||
_num_digits = max(6, len(str(cfg.training.online_steps)))
|
||||
_num_digits = max(6, len(str(online_steps)))
|
||||
step_identifier = f"{optimization_step:0{_num_digits}d}"
|
||||
interaction_step = (
|
||||
interaction_message["Interaction step"] if interaction_message is not None else 0
|
||||
interaction_message["Interaction step"]
|
||||
if interaction_message is not None
|
||||
else 0
|
||||
)
|
||||
logger.save_checkpoint(
|
||||
optimization_step,
|
||||
@@ -498,8 +756,21 @@ def add_actor_information_and_train(
|
||||
dataset_dir,
|
||||
)
|
||||
replay_buffer.to_lerobot_dataset(
|
||||
cfg.dataset_repo_id, fps=cfg.fps, root=logger.log_dir / "dataset"
|
||||
dataset_repo_id, fps=fps, root=logger.log_dir / "dataset"
|
||||
)
|
||||
if offline_replay_buffer is not None:
|
||||
dataset_dir = logger.log_dir / "dataset_offline"
|
||||
|
||||
if dataset_dir.exists() and dataset_dir.is_dir():
|
||||
shutil.rmtree(
|
||||
dataset_dir,
|
||||
)
|
||||
|
||||
offline_replay_buffer.to_lerobot_dataset(
|
||||
cfg.dataset_repo_id,
|
||||
fps=cfg.fps,
|
||||
root=logger.log_dir / "dataset_offline",
|
||||
)
|
||||
|
||||
logging.info("Resume training")
|
||||
|
||||
@@ -538,7 +809,9 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
|
||||
optimizer_temperature = torch.optim.Adam(
|
||||
params=[policy.log_alpha], lr=policy.config.critic_lr
|
||||
)
|
||||
lr_scheduler = None
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
@@ -562,76 +835,36 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("make_policy")
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy intance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
# TODO: At some point we should just need make sac policy
|
||||
|
||||
policy_lock = Lock()
|
||||
policy: SACPolicy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
# Hack: But if we do online traning, we do not need dataset_stats
|
||||
dataset_stats=None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
||||
)
|
||||
# compile policy
|
||||
policy = torch.compile(policy)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
|
||||
resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)
|
||||
|
||||
log_training_info(cfg, out_dir, policy)
|
||||
|
||||
replay_buffer = initialize_replay_buffer(cfg, logger, device)
|
||||
batch_size = cfg.training.batch_size
|
||||
offline_replay_buffer = None
|
||||
|
||||
if cfg.dataset_repo_id is not None:
|
||||
logging.info("make_dataset offline buffer")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
logging.info("Convertion to a offline replay buffer")
|
||||
active_action_dims = [i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask]
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
offline_dataset,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
action_mask=active_action_dims,
|
||||
action_delta=cfg.env.wrapper.delta_action,
|
||||
)
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
shutdown_event = setup_process_handlers(use_threads(cfg))
|
||||
|
||||
start_learner_threads(
|
||||
cfg,
|
||||
device,
|
||||
replay_buffer,
|
||||
offline_replay_buffer,
|
||||
batch_size,
|
||||
optimizers,
|
||||
policy,
|
||||
policy_lock,
|
||||
logger,
|
||||
resume_optimization_step,
|
||||
resume_interaction_step,
|
||||
out_dir,
|
||||
shutdown_event,
|
||||
)
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
|
||||
def train_cli(cfg: dict):
|
||||
if not use_threads(cfg):
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
|
||||
train(
|
||||
cfg,
|
||||
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
|
||||
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
|
||||
)
|
||||
|
||||
logging.info("[LEARNER] train_cli finished")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_cli()
|
||||
|
||||
logging.info("[LEARNER] main finished")
|
||||
|
||||
82
lerobot/scripts/server/learner_service.py
Normal file
82
lerobot/scripts/server/learner_service.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import hilserl_pb2 # type: ignore
|
||||
import hilserl_pb2_grpc # type: ignore
|
||||
import logging
|
||||
from multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.scripts.server.network_utils import receive_bytes_in_chunks
|
||||
from lerobot.scripts.server.network_utils import send_bytes_in_chunks
|
||||
|
||||
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
|
||||
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
|
||||
STUTDOWN_TIMEOUT = 10
|
||||
|
||||
|
||||
class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
|
||||
def __init__(
|
||||
self,
|
||||
shutdown_event: Event,
|
||||
parameters_queue: Queue,
|
||||
seconds_between_pushes: float,
|
||||
transition_queue: Queue,
|
||||
interaction_message_queue: Queue,
|
||||
):
|
||||
self.shutdown_event = shutdown_event
|
||||
self.parameters_queue = parameters_queue
|
||||
self.seconds_between_pushes = seconds_between_pushes
|
||||
self.transition_queue = transition_queue
|
||||
self.interaction_message_queue = interaction_message_queue
|
||||
|
||||
def StreamParameters(self, request, context):
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to stream parameters from the Actor")
|
||||
|
||||
while not self.shutdown_event.is_set():
|
||||
logging.info("[LEARNER] Push parameters to the Actor")
|
||||
buffer = self.parameters_queue.get()
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
buffer,
|
||||
hilserl_pb2.Parameters,
|
||||
log_prefix="[LEARNER] Sending parameters",
|
||||
silent=True,
|
||||
)
|
||||
|
||||
logging.info("[LEARNER] Parameters sent")
|
||||
|
||||
self.shutdown_event.wait(self.seconds_between_pushes)
|
||||
|
||||
logging.info("[LEARNER] Stream parameters finished")
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
def SendTransitions(self, request_iterator, _context):
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to receive transitions from the Actor")
|
||||
|
||||
receive_bytes_in_chunks(
|
||||
request_iterator,
|
||||
self.transition_queue,
|
||||
self.shutdown_event,
|
||||
log_prefix="[LEARNER] transitions",
|
||||
)
|
||||
|
||||
logging.debug("[LEARNER] Finished receiving transitions")
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
def SendInteractions(self, request_iterator, _context):
|
||||
# TODO: authorize the request
|
||||
logging.info(
|
||||
"[LEARNER] Received request to receive interactions from the Actor"
|
||||
)
|
||||
|
||||
receive_bytes_in_chunks(
|
||||
request_iterator,
|
||||
self.interaction_message_queue,
|
||||
self.shutdown_event,
|
||||
log_prefix="[LEARNER] interactions",
|
||||
)
|
||||
|
||||
logging.debug("[LEARNER] Finished receiving interactions")
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
def Ready(self, request, context):
|
||||
return hilserl_pb2.Empty()
|
||||
@@ -5,13 +5,13 @@ import torch
|
||||
|
||||
from omegaconf import DictConfig
|
||||
from typing import Any
|
||||
|
||||
"""Make ManiSkill3 gym environment"""
|
||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||
from mani_skill.utils.wrappers.record import RecordEpisode
|
||||
|
||||
|
||||
|
||||
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
|
||||
def preprocess_maniskill_observation(
|
||||
observations: dict[str, np.ndarray],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
@@ -63,7 +63,9 @@ class ManiSkillCompat(gym.Wrapper):
|
||||
new_action_space_shape = env.action_space.shape[-1]
|
||||
new_low = np.squeeze(env.action_space.low, axis=0)
|
||||
new_high = np.squeeze(env.action_space.high, axis=0)
|
||||
self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,))
|
||||
self.action_space = gym.spaces.Box(
|
||||
low=new_low, high=new_high, shape=(new_action_space_shape,)
|
||||
)
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
@@ -82,7 +84,9 @@ class ManiSkillCompat(gym.Wrapper):
|
||||
class ManiSkillActionWrapper(gym.ActionWrapper):
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2)))
|
||||
self.action_space = gym.spaces.Tuple(
|
||||
spaces=(env.action_space, gym.spaces.Discrete(2))
|
||||
)
|
||||
|
||||
def action(self, action):
|
||||
action, telop = action
|
||||
@@ -96,7 +100,9 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
|
||||
action_space_agent: gym.spaces.Box = env.action_space[0]
|
||||
action_space_agent.low = action_space_agent.low * multiply_factor
|
||||
action_space_agent.high = action_space_agent.high * multiply_factor
|
||||
self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2)))
|
||||
self.action_space = gym.spaces.Tuple(
|
||||
spaces=(action_space_agent, gym.spaces.Discrete(2))
|
||||
)
|
||||
|
||||
def step(self, action):
|
||||
if isinstance(action, tuple):
|
||||
@@ -136,13 +142,24 @@ def make_maniskill(
|
||||
num_envs=n_envs,
|
||||
)
|
||||
|
||||
if cfg.env.video_record.enabled:
|
||||
env = RecordEpisode(
|
||||
env,
|
||||
output_dir=cfg.env.video_record.record_dir,
|
||||
save_trajectory=True,
|
||||
trajectory_name=cfg.env.video_record.trajectory_name,
|
||||
save_video=True,
|
||||
video_fps=30,
|
||||
)
|
||||
env = ManiSkillObservationWrapper(env, device=cfg.env.device)
|
||||
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
|
||||
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
|
||||
env._max_episode_steps = env.max_episode_steps = (
|
||||
50 # gym_utils.find_max_episode_steps_value(env)
|
||||
)
|
||||
env.unwrapped.metadata["render_fps"] = 20
|
||||
env = ManiSkillCompat(env)
|
||||
env = ManiSkillActionWrapper(env)
|
||||
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=10.0)
|
||||
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03)
|
||||
|
||||
return env
|
||||
|
||||
@@ -150,10 +167,11 @@ def make_maniskill(
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import hydra
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml")
|
||||
parser.add_argument(
|
||||
"--config", type=str, default="lerobot/configs/env/maniskill_example.yaml"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize config
|
||||
|
||||
102
lerobot/scripts/server/network_utils.py
Normal file
102
lerobot/scripts/server/network_utils.py
Normal file
@@ -0,0 +1,102 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.scripts.server import hilserl_pb2
|
||||
import logging
|
||||
import io
|
||||
from multiprocessing import Queue, Event
|
||||
from typing import Any
|
||||
|
||||
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
|
||||
|
||||
|
||||
def bytes_buffer_size(buffer: io.BytesIO) -> int:
|
||||
buffer.seek(0, io.SEEK_END)
|
||||
result = buffer.tell()
|
||||
buffer.seek(0)
|
||||
return result
|
||||
|
||||
|
||||
def send_bytes_in_chunks(
|
||||
buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True
|
||||
):
|
||||
buffer = io.BytesIO(buffer)
|
||||
size_in_bytes = bytes_buffer_size(buffer)
|
||||
|
||||
sent_bytes = 0
|
||||
|
||||
logging_method = logging.info if not silent else logging.debug
|
||||
|
||||
logging_method(f"{log_prefix} Buffer size {size_in_bytes/1024/1024} MB with")
|
||||
|
||||
while sent_bytes < size_in_bytes:
|
||||
transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
|
||||
|
||||
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
|
||||
transfer_state = hilserl_pb2.TransferState.TRANSFER_END
|
||||
elif sent_bytes == 0:
|
||||
transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
|
||||
|
||||
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
|
||||
chunk = buffer.read(size_to_read)
|
||||
|
||||
yield message_class(transfer_state=transfer_state, data=chunk)
|
||||
sent_bytes += size_to_read
|
||||
logging_method(
|
||||
f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}"
|
||||
)
|
||||
|
||||
logging_method(f"{log_prefix} Published {sent_bytes/1024/1024} MB")
|
||||
|
||||
|
||||
def receive_bytes_in_chunks(
|
||||
iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""
|
||||
):
|
||||
bytes_buffer = io.BytesIO()
|
||||
step = 0
|
||||
|
||||
logging.info(f"{log_prefix} Starting receiver")
|
||||
for item in iterator:
|
||||
logging.debug(f"{log_prefix} Received item")
|
||||
if shutdown_event.is_set():
|
||||
logging.info(f"{log_prefix} Shutting down receiver")
|
||||
return
|
||||
|
||||
if item.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
|
||||
bytes_buffer.seek(0)
|
||||
bytes_buffer.truncate(0)
|
||||
bytes_buffer.write(item.data)
|
||||
logging.debug(f"{log_prefix} Received data at step 0")
|
||||
step = 0
|
||||
continue
|
||||
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE:
|
||||
bytes_buffer.write(item.data)
|
||||
step += 1
|
||||
logging.debug(f"{log_prefix} Received data at step {step}")
|
||||
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
|
||||
bytes_buffer.write(item.data)
|
||||
logging.debug(
|
||||
f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}"
|
||||
)
|
||||
|
||||
queue.put(bytes_buffer.getvalue())
|
||||
|
||||
bytes_buffer.seek(0)
|
||||
bytes_buffer.truncate(0)
|
||||
step = 0
|
||||
|
||||
logging.debug(f"{log_prefix} Queue updated")
|
||||
72
lerobot/scripts/server/utils.py
Normal file
72
lerobot/scripts/server/utils.py
Normal file
@@ -0,0 +1,72 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
from torch.multiprocessing import Queue
|
||||
from queue import Empty
|
||||
|
||||
shutdown_event_counter = 0
|
||||
|
||||
|
||||
def setup_process_handlers(use_threads: bool) -> any:
|
||||
if use_threads:
|
||||
from threading import Event
|
||||
else:
|
||||
from multiprocessing import Event
|
||||
|
||||
shutdown_event = Event()
|
||||
|
||||
# Define signal handler
|
||||
def signal_handler(signum, frame):
|
||||
logging.info("Shutdown signal received. Cleaning up...")
|
||||
shutdown_event.set()
|
||||
global shutdown_event_counter
|
||||
shutdown_event_counter += 1
|
||||
|
||||
if shutdown_event_counter > 1:
|
||||
logging.info("Force shutdown")
|
||||
sys.exit(1)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
||||
signal.signal(signal.SIGTERM, signal_handler) # Termination request (kill)
|
||||
signal.signal(signal.SIGHUP, signal_handler) # Terminal closed/Hangup
|
||||
signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
logging.info("Shutdown signal received. Cleaning up...")
|
||||
shutdown_event.set()
|
||||
|
||||
return shutdown_event
|
||||
|
||||
|
||||
def get_last_item_from_queue(queue: Queue):
|
||||
item = queue.get()
|
||||
counter = 1
|
||||
|
||||
# Drain queue and keep only the most recent parameters
|
||||
try:
|
||||
while True:
|
||||
item = queue.get_nowait()
|
||||
counter += 1
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
logging.debug(f"Drained {counter} items from queue")
|
||||
|
||||
return item
|
||||
@@ -71,7 +71,9 @@ def make_optimizer_and_scheduler(cfg, policy):
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(
|
||||
optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
|
||||
optimizer_params_dicts,
|
||||
lr=cfg.training.lr,
|
||||
weight_decay=cfg.training.weight_decay,
|
||||
)
|
||||
lr_scheduler = None
|
||||
elif cfg.policy.name == "diffusion":
|
||||
@@ -98,14 +100,23 @@ def make_optimizer_and_scheduler(cfg, policy):
|
||||
optimizer = torch.optim.Adam(
|
||||
[
|
||||
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
|
||||
{"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
|
||||
{"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
|
||||
{
|
||||
"params": policy.critic_ensemble.parameters(),
|
||||
"lr": policy.config.critic_lr,
|
||||
},
|
||||
{
|
||||
"params": policy.temperature.parameters(),
|
||||
"lr": policy.config.temperature_lr,
|
||||
},
|
||||
]
|
||||
)
|
||||
lr_scheduler = None
|
||||
|
||||
elif cfg.policy.name == "vqbet":
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import (
|
||||
VQBeTOptimizer,
|
||||
VQBeTScheduler,
|
||||
)
|
||||
|
||||
optimizer = VQBeTOptimizer(policy, cfg)
|
||||
lr_scheduler = VQBeTScheduler(optimizer, cfg)
|
||||
@@ -255,7 +266,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
logging.info(pformat(OmegaConf.to_container(cfg)))
|
||||
|
||||
if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
|
||||
raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
|
||||
raise NotImplementedError(
|
||||
"Online training with LeRobotMultiDataset is not implemented."
|
||||
)
|
||||
|
||||
# If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
|
||||
# to check for any differences between the provided config and the checkpoint's config.
|
||||
@@ -265,7 +278,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
"You have set resume=True, but there is no model checkpoint in "
|
||||
f"{Logger.get_last_checkpoint_dir(out_dir)}"
|
||||
)
|
||||
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
|
||||
checkpoint_cfg_path = str(
|
||||
Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
|
||||
)
|
||||
logging.info(
|
||||
colored(
|
||||
"You have set resume=True, indicating that you wish to resume a run",
|
||||
@@ -278,7 +293,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
# Check for differences between the checkpoint configuration and provided configuration.
|
||||
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
|
||||
resolve_delta_timestamps(cfg)
|
||||
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
|
||||
diff = DeepDiff(
|
||||
OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)
|
||||
)
|
||||
# Ignore the `resume` and parameters.
|
||||
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
|
||||
del diff["values_changed"]["root['resume']"]
|
||||
@@ -325,7 +342,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
# TODO (michel-aractingi): temporary fix to avoid datasets with task_index key that doesn't exist in online environment
|
||||
# i.e., pusht
|
||||
if "task_index" in offline_dataset.hf_dataset[0]:
|
||||
offline_dataset.hf_dataset = offline_dataset.hf_dataset.remove_columns(["task_index"])
|
||||
offline_dataset.hf_dataset = offline_dataset.hf_dataset.remove_columns(
|
||||
["task_index"]
|
||||
)
|
||||
|
||||
if isinstance(offline_dataset, MultiLeRobotDataset):
|
||||
logging.info(
|
||||
@@ -345,7 +364,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
policy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
|
||||
if cfg.resume
|
||||
else None,
|
||||
)
|
||||
assert isinstance(policy, nn.Module)
|
||||
# Create optimizer and scheduler
|
||||
@@ -358,36 +379,58 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
if cfg.resume:
|
||||
step = logger.load_last_training_state(optimizer, lr_scheduler)
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_learnable_params = sum(
|
||||
p.numel() for p in policy.parameters() if p.requires_grad
|
||||
)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
log_output_dir(out_dir)
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
|
||||
logging.info(
|
||||
f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})"
|
||||
)
|
||||
logging.info(f"{cfg.training.online_steps=}")
|
||||
logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
|
||||
logging.info(
|
||||
f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})"
|
||||
)
|
||||
logging.info(f"{offline_dataset.num_episodes=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# Note: this helper will be used in offline and online training loops.
|
||||
def evaluate_and_checkpoint_if_needed(step, is_online):
|
||||
_num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
|
||||
_num_digits = max(
|
||||
6, len(str(cfg.training.offline_steps + cfg.training.online_steps))
|
||||
)
|
||||
step_identifier = f"{step:0{_num_digits}d}"
|
||||
|
||||
if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type)
|
||||
if cfg.use_amp
|
||||
else nullcontext(),
|
||||
):
|
||||
assert eval_env is not None
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
videos_dir=Path(out_dir) / "eval" / f"videos_step_{step_identifier}",
|
||||
videos_dir=Path(out_dir)
|
||||
/ "eval"
|
||||
/ f"videos_step_{step_identifier}",
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_online=is_online)
|
||||
log_eval_info(
|
||||
logger,
|
||||
eval_info["aggregated"],
|
||||
step,
|
||||
cfg,
|
||||
offline_dataset,
|
||||
is_online=is_online,
|
||||
)
|
||||
if cfg.wandb.enable:
|
||||
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
logging.info("Resume training")
|
||||
@@ -456,7 +499,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
train_info["dataloading_s"] = dataloading_s
|
||||
|
||||
if step % cfg.training.log_freq == 0:
|
||||
log_train_info(logger, train_info, step, cfg, offline_dataset, is_online=False)
|
||||
log_train_info(
|
||||
logger, train_info, step, cfg, offline_dataset, is_online=False
|
||||
)
|
||||
|
||||
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
||||
# so we pass in step + 1.
|
||||
@@ -489,8 +534,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
online_dataset = OnlineBuffer(
|
||||
online_buffer_path,
|
||||
data_spec={
|
||||
**{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.input_shapes.items()},
|
||||
**{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.output_shapes.items()},
|
||||
**{
|
||||
k: {"shape": v, "dtype": np.dtype("float32")}
|
||||
for k, v in policy.config.input_shapes.items()
|
||||
},
|
||||
**{
|
||||
k: {"shape": v, "dtype": np.dtype("float32")}
|
||||
for k, v in policy.config.output_shapes.items()
|
||||
},
|
||||
"next.reward": {"shape": (), "dtype": np.dtype("float32")},
|
||||
"next.done": {"shape": (), "dtype": np.dtype("?")},
|
||||
"next.success": {"shape": (), "dtype": np.dtype("?")},
|
||||
@@ -502,7 +553,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
# If we are doing online rollouts asynchronously, deepcopy the policy to use for online rollouts (this
|
||||
# makes it possible to do online rollouts in parallel with training updates).
|
||||
online_rollout_policy = deepcopy(policy) if cfg.training.do_online_rollout_async else policy
|
||||
online_rollout_policy = (
|
||||
deepcopy(policy) if cfg.training.do_online_rollout_async else policy
|
||||
)
|
||||
|
||||
# Create dataloader for online training.
|
||||
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
|
||||
@@ -539,7 +592,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
online_step = 0
|
||||
online_rollout_s = 0 # time take to do online rollout
|
||||
update_online_buffer_s = 0 # time taken to update the online buffer with the online rollout data
|
||||
update_online_buffer_s = (
|
||||
0 # time taken to update the online buffer with the online rollout data
|
||||
)
|
||||
# Time taken waiting for the online buffer to finish being updated. This is relevant when using the async
|
||||
# online rollout option.
|
||||
await_update_online_buffer_s = 0
|
||||
@@ -563,11 +618,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
online_env,
|
||||
online_rollout_policy,
|
||||
n_episodes=cfg.training.online_rollout_n_episodes,
|
||||
max_episodes_rendered=min(10, cfg.training.online_rollout_n_episodes),
|
||||
max_episodes_rendered=min(
|
||||
10, cfg.training.online_rollout_n_episodes
|
||||
),
|
||||
videos_dir=logger.log_dir / "online_rollout_videos",
|
||||
return_episode_data=True,
|
||||
start_seed=(
|
||||
rollout_start_seed := (rollout_start_seed + cfg.training.batch_size) % 1000000
|
||||
rollout_start_seed := (
|
||||
rollout_start_seed + cfg.training.batch_size
|
||||
)
|
||||
% 1000000
|
||||
),
|
||||
)
|
||||
online_rollout_s = time.perf_counter() - start_rollout_time
|
||||
@@ -577,16 +637,21 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
online_dataset.add_data(eval_info["episodes"])
|
||||
|
||||
# Update the concatenated dataset length used during sampling.
|
||||
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
|
||||
concat_dataset.cumulative_sizes = concat_dataset.cumsum(
|
||||
concat_dataset.datasets
|
||||
)
|
||||
|
||||
# Update the sampling weights.
|
||||
sampler.weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
|
||||
offline_drop_n_last_frames=cfg.training.get(
|
||||
"drop_n_last_frames", 0
|
||||
),
|
||||
online_dataset=online_dataset,
|
||||
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
|
||||
# this final observation in the offline datasets, but we might add them in future.
|
||||
online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
|
||||
online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0)
|
||||
+ 1,
|
||||
online_sampling_ratio=cfg.training.online_sampling_ratio,
|
||||
)
|
||||
sampler.num_frames = len(concat_dataset)
|
||||
@@ -639,7 +704,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
train_info["online_buffer_size"] = len(online_dataset)
|
||||
|
||||
if step % cfg.training.log_freq == 0:
|
||||
log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True)
|
||||
log_train_info(
|
||||
logger, train_info, step, cfg, online_dataset, is_online=True
|
||||
)
|
||||
|
||||
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
||||
# so we pass in step + 1.
|
||||
@@ -672,7 +739,9 @@ def train_cli(cfg: dict):
|
||||
)
|
||||
|
||||
|
||||
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
|
||||
def train_notebook(
|
||||
out_dir=None, job_name=None, config_name="default", config_path="../configs"
|
||||
):
|
||||
from hydra import compose, initialize
|
||||
|
||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
import logging
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
import hydra
|
||||
@@ -28,19 +27,22 @@ from termcolor import colored
|
||||
from torch import optim
|
||||
from torch.autograd import profiler
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split
|
||||
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.logger import Logger
|
||||
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
|
||||
ClassifierConfig,
|
||||
)
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.scripts.server.buffer import random_shift
|
||||
@@ -50,19 +52,40 @@ def get_model(cfg, logger): # noqa I001
|
||||
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
|
||||
model = Classifier(classifier_config)
|
||||
if cfg.resume:
|
||||
model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
|
||||
model.load_state_dict(
|
||||
Classifier.from_pretrained(
|
||||
str(logger.last_pretrained_model_dir)
|
||||
).state_dict()
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def create_balanced_sampler(dataset, cfg):
|
||||
# Creates a weighted sampler to handle class imbalance
|
||||
# Get underlying dataset if using Subset
|
||||
original_dataset = (
|
||||
dataset.dataset if isinstance(dataset, torch.utils.data.Subset) else dataset
|
||||
)
|
||||
|
||||
labels = torch.tensor([item[cfg.training.label_key] for item in dataset])
|
||||
# Get indices if using Subset (for slicing)
|
||||
indices = dataset.indices if isinstance(dataset, torch.utils.data.Subset) else None
|
||||
|
||||
# Get labels from Hugging Face dataset
|
||||
if indices is not None:
|
||||
# Get subset of labels using Hugging Face's select()
|
||||
hf_subset = original_dataset.hf_dataset.select(indices)
|
||||
labels = hf_subset[cfg.training.label_key]
|
||||
else:
|
||||
# Get all labels directly
|
||||
labels = original_dataset.hf_dataset[cfg.training.label_key]
|
||||
|
||||
labels = torch.stack(labels)
|
||||
_, counts = torch.unique(labels, return_counts=True)
|
||||
class_weights = 1.0 / counts.float()
|
||||
sample_weights = class_weights[labels]
|
||||
|
||||
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
|
||||
return WeightedRandomSampler(
|
||||
weights=sample_weights, num_samples=len(sample_weights), replacement=True
|
||||
)
|
||||
|
||||
|
||||
def support_amp(device: torch.device, cfg: DictConfig) -> bool:
|
||||
@@ -71,7 +94,9 @@ def support_amp(device: torch.device, cfg: DictConfig) -> bool:
|
||||
return cfg.training.use_amp and device.type in ("cuda", "cpu")
|
||||
|
||||
|
||||
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
|
||||
def train_epoch(
|
||||
model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg
|
||||
):
|
||||
# Single epoch training loop with AMP support and progress tracking
|
||||
model.train()
|
||||
correct = 0
|
||||
@@ -85,7 +110,11 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
|
||||
labels = batch[cfg.training.label_key].float().to(device)
|
||||
|
||||
# Forward pass with optional AMP
|
||||
with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
|
||||
with (
|
||||
torch.autocast(device_type=device.type)
|
||||
if support_amp(device, cfg)
|
||||
else nullcontext()
|
||||
):
|
||||
outputs = model(images)
|
||||
loss = criterion(outputs.logits, labels)
|
||||
|
||||
@@ -130,7 +159,9 @@ def validate(model, val_loader, criterion, device, logger, cfg):
|
||||
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
|
||||
torch.autocast(device_type=device.type)
|
||||
if support_amp(device, cfg)
|
||||
else nullcontext(),
|
||||
):
|
||||
for batch in tqdm(val_loader, desc="Validation"):
|
||||
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
|
||||
@@ -143,7 +174,9 @@ def validate(model, val_loader, criterion, device, logger, cfg):
|
||||
):
|
||||
outputs = model(images)
|
||||
inference_times.append(
|
||||
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
|
||||
next(
|
||||
x for x in prof.key_averages() if x.key == "model_inference"
|
||||
).cpu_time
|
||||
)
|
||||
else:
|
||||
outputs = model(images)
|
||||
@@ -161,16 +194,24 @@ def validate(model, val_loader, criterion, device, logger, cfg):
|
||||
|
||||
# Log sample predictions for visualization
|
||||
if len(samples) < cfg.eval.num_samples_to_log:
|
||||
for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
|
||||
for i in range(
|
||||
min(cfg.eval.num_samples_to_log - len(samples), len(images))
|
||||
):
|
||||
if model.config.num_classes == 2:
|
||||
confidence = round(outputs.probabilities[i].item(), 3)
|
||||
else:
|
||||
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
|
||||
confidence = [
|
||||
round(prob, 3) for prob in outputs.probabilities[i].tolist()
|
||||
]
|
||||
samples.append(
|
||||
{
|
||||
**{
|
||||
f"image_{img_key}": wandb.Image(images[img_idx][i].cpu())
|
||||
for img_idx, img_key in enumerate(cfg.training.image_keys)
|
||||
f"image_{img_key}": wandb.Image(
|
||||
images[img_idx][i].cpu()
|
||||
)
|
||||
for img_idx, img_key in enumerate(
|
||||
cfg.training.image_keys
|
||||
)
|
||||
},
|
||||
"true_label": labels[i].item(),
|
||||
"predicted": predictions[i].item(),
|
||||
@@ -238,15 +279,24 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
|
||||
elif device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
|
||||
with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
|
||||
with (
|
||||
profiler.profile(record_shapes=True) as prof,
|
||||
profiler.record_function("model_inference"),
|
||||
):
|
||||
_ = model(x)
|
||||
|
||||
inference_times.append(
|
||||
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
|
||||
next(
|
||||
x for x in prof.key_averages() if x.key == "model_inference"
|
||||
).cpu_time
|
||||
)
|
||||
|
||||
inference_times = np.array(inference_times)
|
||||
avg, median, std = inference_times.mean(), np.median(inference_times), inference_times.std()
|
||||
avg, median, std = (
|
||||
inference_times.mean(),
|
||||
np.median(inference_times),
|
||||
inference_times.std(),
|
||||
)
|
||||
print(
|
||||
f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device"
|
||||
)
|
||||
@@ -264,21 +314,29 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
|
||||
return avg, median, std
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
|
||||
def train(cfg: DictConfig) -> None:
|
||||
def train(
|
||||
cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None
|
||||
) -> None:
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
if job_name is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
# Main training pipeline with support for resuming training
|
||||
init_logging()
|
||||
logging.info(OmegaConf.to_yaml(cfg))
|
||||
|
||||
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
|
||||
|
||||
# Initialize training environment
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "frozen_resnet10_2"
|
||||
logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)
|
||||
|
||||
# Setup dataset and dataloaders
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset_repo_id, root=cfg.dataset_root, local_files_only=cfg.local_files_only
|
||||
cfg.dataset_repo_id,
|
||||
root=cfg.dataset_root,
|
||||
local_files_only=cfg.local_files_only,
|
||||
)
|
||||
logging.info(f"Dataset size: {len(dataset)}")
|
||||
|
||||
@@ -314,7 +372,9 @@ def train(cfg: DictConfig) -> None:
|
||||
"You have set resume=True, but there is no model checkpoint in "
|
||||
f"{Logger.get_last_checkpoint_dir(out_dir)}"
|
||||
)
|
||||
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
|
||||
checkpoint_cfg_path = str(
|
||||
Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
|
||||
)
|
||||
logging.info(
|
||||
colored(
|
||||
"You have set resume=True, indicating that you wish to resume a run",
|
||||
@@ -327,7 +387,9 @@ def train(cfg: DictConfig) -> None:
|
||||
# Check for differences between the checkpoint configuration and provided configuration.
|
||||
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
|
||||
resolve_delta_timestamps(cfg)
|
||||
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
|
||||
diff = DeepDiff(
|
||||
OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)
|
||||
)
|
||||
# Ignore the `resume` and parameters.
|
||||
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
|
||||
del diff["values_changed"]["root['resume']"]
|
||||
@@ -346,7 +408,11 @@ def train(cfg: DictConfig) -> None:
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
|
||||
# Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
|
||||
criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
|
||||
criterion = (
|
||||
nn.BCEWithLogitsLoss()
|
||||
if model.config.num_classes == 2
|
||||
else nn.CrossEntropyLoss()
|
||||
)
|
||||
grad_scaler = GradScaler(enabled=cfg.training.use_amp)
|
||||
|
||||
# Log model parameters
|
||||
@@ -362,7 +428,17 @@ def train(cfg: DictConfig) -> None:
|
||||
for epoch in range(cfg.training.num_epochs):
|
||||
logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}")
|
||||
|
||||
train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg)
|
||||
train_epoch(
|
||||
model,
|
||||
train_loader,
|
||||
criterion,
|
||||
optimizer,
|
||||
grad_scaler,
|
||||
device,
|
||||
logger,
|
||||
step,
|
||||
cfg,
|
||||
)
|
||||
|
||||
# Periodic validation
|
||||
if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0:
|
||||
@@ -404,5 +480,32 @@ def train(cfg: DictConfig) -> None:
|
||||
logging.info("Training completed")
|
||||
|
||||
|
||||
@hydra.main(
|
||||
version_base="1.2",
|
||||
config_name="hilserl_classifier",
|
||||
config_path="../configs/policy",
|
||||
)
|
||||
def train_cli(cfg: dict):
|
||||
train(
|
||||
cfg,
|
||||
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
|
||||
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
|
||||
)
|
||||
|
||||
|
||||
def train_notebook(
|
||||
out_dir=None,
|
||||
job_name=None,
|
||||
config_name="hilserl_classifier",
|
||||
config_path="../configs/policy",
|
||||
):
|
||||
from hydra import compose, initialize
|
||||
|
||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
||||
initialize(config_path=config_path)
|
||||
cfg = compose(config_name=config_name)
|
||||
train(cfg, out_dir=out_dir, job_name=job_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
train_cli()
|
||||
|
||||
@@ -22,7 +22,6 @@ from typing import Callable, Optional, Sequence, TypedDict
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
@@ -30,20 +29,17 @@ from tqdm import tqdm
|
||||
# TODO: Remove the import of maniskill
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.envs.factory import make_env, make_maniskill_env
|
||||
from lerobot.common.envs.utils import preprocess_maniskill_observation, preprocess_observation
|
||||
from lerobot.common.envs.factory import make_maniskill_env
|
||||
from lerobot.common.envs.utils import preprocess_maniskill_observation
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
|
||||
|
||||
def make_optimizers_and_scheduler(cfg, policy):
|
||||
@@ -56,7 +52,9 @@ def make_optimizers_and_scheduler(cfg, policy):
|
||||
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
|
||||
)
|
||||
# We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
|
||||
optimizer_temperature = torch.optim.Adam(
|
||||
params=[policy.log_alpha], lr=policy.config.critic_lr
|
||||
)
|
||||
lr_scheduler = None
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
@@ -108,7 +106,9 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te
|
||||
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
|
||||
|
||||
# Gather pixels
|
||||
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
|
||||
cropped_hwcn = images_hwcn[
|
||||
torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
|
||||
]
|
||||
# cropped_hwcn => (B, crop_h, crop_w, C)
|
||||
|
||||
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
|
||||
@@ -198,8 +198,12 @@ class ReplayBuffer:
|
||||
"""
|
||||
# We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
|
||||
# a replay buffer than from a lerobot dataset.
|
||||
replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
|
||||
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
|
||||
replay_buffer = cls(
|
||||
capacity=len(lerobot_dataset), device=device, state_keys=state_keys
|
||||
)
|
||||
list_transition = cls._lerobotdataset_to_transitions(
|
||||
dataset=lerobot_dataset, state_keys=state_keys
|
||||
)
|
||||
# Fill the replay buffer with the lerobot dataset transitions
|
||||
for data in list_transition:
|
||||
replay_buffer.add(
|
||||
@@ -244,7 +248,9 @@ class ReplayBuffer:
|
||||
|
||||
# If not provided, you can either raise an error or define a default:
|
||||
if state_keys is None:
|
||||
raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
|
||||
raise ValueError(
|
||||
"You must provide a list of keys in `state_keys` that define your 'state'."
|
||||
)
|
||||
|
||||
transitions: list[Transition] = []
|
||||
num_frames = len(dataset)
|
||||
@@ -298,36 +304,40 @@ class ReplayBuffer:
|
||||
# -- Build batched states --
|
||||
batch_state = {}
|
||||
for key in self.state_keys:
|
||||
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
|
||||
self.device
|
||||
)
|
||||
batch_state[key] = torch.cat(
|
||||
[t["state"][key] for t in list_of_transitions], dim=0
|
||||
).to(self.device)
|
||||
if key.startswith("observation.image") and self.use_drq:
|
||||
batch_state[key] = self.image_augmentation_function(batch_state[key])
|
||||
|
||||
# -- Build batched actions --
|
||||
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
|
||||
|
||||
# -- Build batched rewards --
|
||||
batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(
|
||||
self.device
|
||||
)
|
||||
|
||||
# -- Build batched rewards --
|
||||
batch_rewards = torch.tensor(
|
||||
[t["reward"] for t in list_of_transitions], dtype=torch.float32
|
||||
).to(self.device)
|
||||
|
||||
# -- Build batched next states --
|
||||
batch_next_state = {}
|
||||
for key in self.state_keys:
|
||||
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
|
||||
self.device
|
||||
)
|
||||
batch_next_state[key] = torch.cat(
|
||||
[t["next_state"][key] for t in list_of_transitions], dim=0
|
||||
).to(self.device)
|
||||
if key.startswith("observation.image") and self.use_drq:
|
||||
batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
|
||||
batch_next_state[key] = self.image_augmentation_function(
|
||||
batch_next_state[key]
|
||||
)
|
||||
|
||||
# -- Build batched dones --
|
||||
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
batch_dones = torch.tensor(
|
||||
[t["done"] for t in list_of_transitions], dtype=torch.float32
|
||||
).to(self.device)
|
||||
batch_dones = torch.tensor(
|
||||
[t["done"] for t in list_of_transitions], dtype=torch.float32
|
||||
).to(self.device)
|
||||
|
||||
# Return a BatchTransition typed dict
|
||||
return BatchTransition(
|
||||
@@ -344,7 +354,13 @@ def concatenate_batch_transitions(
|
||||
) -> BatchTransition:
|
||||
"""NOTE: Be careful it change the left_batch_transitions in place"""
|
||||
left_batch_transitions["state"] = {
|
||||
key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
|
||||
key: torch.cat(
|
||||
[
|
||||
left_batch_transitions["state"][key],
|
||||
right_batch_transition["state"][key],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
for key in left_batch_transitions["state"]
|
||||
}
|
||||
left_batch_transitions["action"] = torch.cat(
|
||||
@@ -355,7 +371,11 @@ def concatenate_batch_transitions(
|
||||
)
|
||||
left_batch_transitions["next_state"] = {
|
||||
key: torch.cat(
|
||||
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
|
||||
[
|
||||
left_batch_transitions["next_state"][key],
|
||||
right_batch_transition["next_state"][key],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
for key in left_batch_transitions["next_state"]
|
||||
}
|
||||
@@ -407,7 +427,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
# Hack: But if we do online traning, we do not need dataset_stats
|
||||
dataset_stats=None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
|
||||
if cfg.resume
|
||||
else None,
|
||||
device=device,
|
||||
)
|
||||
assert isinstance(policy, nn.Module)
|
||||
@@ -416,7 +438,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
# TODO: Handle resume
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_learnable_params = sum(
|
||||
p.numel() for p in policy.parameters() if p.requires_grad
|
||||
)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
log_output_dir(out_dir)
|
||||
@@ -433,7 +457,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
|
||||
|
||||
replay_buffer = ReplayBuffer(
|
||||
capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys()
|
||||
capacity=cfg.training.online_buffer_capacity,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
)
|
||||
|
||||
batch_size = cfg.training.batch_size
|
||||
@@ -455,12 +481,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
if interaction_step >= cfg.training.online_step_before_learning:
|
||||
action = policy.select_action(batch=obs)
|
||||
next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
|
||||
next_obs, reward, done, truncated, info = online_env.step(
|
||||
action.cpu().numpy()
|
||||
)
|
||||
else:
|
||||
action = online_env.action_space.sample()
|
||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||
# HACK
|
||||
action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
|
||||
action = torch.tensor(action, dtype=torch.float32).to(
|
||||
device, non_blocking=True
|
||||
)
|
||||
|
||||
# HACK: For maniskill
|
||||
# next_obs = preprocess_observation(next_obs)
|
||||
@@ -470,14 +500,20 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
# Because we are using a single environment
|
||||
# we can safely assume that the episode is done
|
||||
if done[0] or truncated[0]:
|
||||
logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step)
|
||||
logging.info(
|
||||
f"Global step {interaction_step}: Episode reward: {sum_reward_episode}"
|
||||
)
|
||||
logger.log_dict(
|
||||
{"Sum episode reward": sum_reward_episode}, interaction_step
|
||||
)
|
||||
sum_reward_episode = 0
|
||||
# HACK: This is for maniskill
|
||||
logging.info(
|
||||
f"global step {interaction_step}: episode success: {info['success'].float().item()} \n"
|
||||
)
|
||||
logger.log_dict({"Episode success": info["success"].float().item()}, interaction_step)
|
||||
logger.log_dict(
|
||||
{"Episode success": info["success"].float().item()}, interaction_step
|
||||
)
|
||||
|
||||
replay_buffer.add(
|
||||
state=obs,
|
||||
@@ -551,7 +587,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
training_infos["loss_actor"] = loss_actor.item()
|
||||
|
||||
loss_temperature = policy.compute_loss_temperature(observations=observations)
|
||||
loss_temperature = policy.compute_loss_temperature(
|
||||
observations=observations
|
||||
)
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
optimizers["temperature"].step()
|
||||
@@ -573,7 +611,9 @@ def train_cli(cfg: dict):
|
||||
)
|
||||
|
||||
|
||||
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
|
||||
def train_notebook(
|
||||
out_dir=None, job_name=None, config_name="default", config_path="../configs"
|
||||
):
|
||||
from hydra import compose, initialize
|
||||
|
||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
||||
|
||||
@@ -94,8 +94,12 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||
assert chw_float32_torch.dtype == torch.float32
|
||||
assert chw_float32_torch.ndim == 3
|
||||
c, h, w = chw_float32_torch.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
|
||||
hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
|
||||
assert (
|
||||
c < h and c < w
|
||||
), f"expect channel first images, but instead {chw_float32_torch.shape}"
|
||||
hwc_uint8_numpy = (
|
||||
(chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
|
||||
)
|
||||
return hwc_uint8_numpy
|
||||
|
||||
|
||||
|
||||
@@ -81,7 +81,11 @@ def run_server(
|
||||
static_folder: Path,
|
||||
template_folder: Path,
|
||||
):
|
||||
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
|
||||
app = Flask(
|
||||
__name__,
|
||||
static_folder=static_folder.resolve(),
|
||||
template_folder=template_folder.resolve(),
|
||||
)
|
||||
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
|
||||
|
||||
@app.route("/")
|
||||
@@ -138,8 +142,12 @@ def run_server(
|
||||
)
|
||||
)
|
||||
|
||||
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
|
||||
def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
|
||||
@app.route(
|
||||
"/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>"
|
||||
)
|
||||
def show_episode(
|
||||
dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes
|
||||
):
|
||||
repo_id = f"{dataset_namespace}/{dataset_name}"
|
||||
try:
|
||||
if dataset is None:
|
||||
@@ -150,7 +158,9 @@ def run_server(
|
||||
400,
|
||||
)
|
||||
dataset_version = (
|
||||
dataset.meta._version if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
|
||||
dataset.meta._version
|
||||
if isinstance(dataset, LeRobotDataset)
|
||||
else dataset.codebase_version
|
||||
)
|
||||
match = re.search(r"v(\d+)\.", dataset_version)
|
||||
if match:
|
||||
@@ -171,15 +181,21 @@ def run_server(
|
||||
}
|
||||
if isinstance(dataset, LeRobotDataset):
|
||||
video_paths = [
|
||||
dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
|
||||
dataset.meta.get_video_file_path(episode_id, key)
|
||||
for key in dataset.meta.video_keys
|
||||
]
|
||||
videos_info = [
|
||||
{"url": url_for("static", filename=video_path), "filename": video_path.parent.name}
|
||||
{
|
||||
"url": url_for("static", filename=video_path),
|
||||
"filename": video_path.parent.name,
|
||||
}
|
||||
for video_path in video_paths
|
||||
]
|
||||
tasks = dataset.meta.episodes[episode_id]["tasks"]
|
||||
else:
|
||||
video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
|
||||
video_keys = [
|
||||
key for key, ft in dataset.features.items() if ft["dtype"] == "video"
|
||||
]
|
||||
videos_info = [
|
||||
{
|
||||
"url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
|
||||
@@ -198,16 +214,24 @@ def run_server(
|
||||
)
|
||||
response.raise_for_status()
|
||||
# Split into lines and parse each line as JSON
|
||||
tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
|
||||
tasks_jsonl = [
|
||||
json.loads(line) for line in response.text.splitlines() if line.strip()
|
||||
]
|
||||
|
||||
filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
|
||||
filtered_tasks_jsonl = [
|
||||
row for row in tasks_jsonl if row["episode_index"] == episode_id
|
||||
]
|
||||
tasks = filtered_tasks_jsonl[0]["tasks"]
|
||||
|
||||
videos_info[0]["language_instruction"] = tasks
|
||||
|
||||
if episodes is None:
|
||||
episodes = list(
|
||||
range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
|
||||
range(
|
||||
dataset.num_episodes
|
||||
if isinstance(dataset, LeRobotDataset)
|
||||
else dataset.total_episodes
|
||||
)
|
||||
)
|
||||
|
||||
return render_template(
|
||||
@@ -233,7 +257,9 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
|
||||
This file will be loaded by Dygraph javascript to plot data in real time."""
|
||||
columns = []
|
||||
|
||||
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"]
|
||||
selected_columns = [
|
||||
col for col, ft in dataset.features.items() if ft["dtype"] == "float32"
|
||||
]
|
||||
selected_columns.remove("timestamp")
|
||||
|
||||
# init header of csv with state and action names
|
||||
@@ -247,7 +273,10 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
|
||||
)
|
||||
header += [f"{column_name}_{i}" for i in range(dim_state)]
|
||||
|
||||
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
|
||||
if (
|
||||
"names" in dataset.features[column_name]
|
||||
and dataset.features[column_name]["names"]
|
||||
):
|
||||
column_names = dataset.features[column_name]["names"]
|
||||
while not isinstance(column_names, list):
|
||||
column_names = list(column_names.values())[0]
|
||||
@@ -268,8 +297,12 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
|
||||
else:
|
||||
repo_id = dataset.repo_id
|
||||
|
||||
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
|
||||
episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
|
||||
url = (
|
||||
f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
|
||||
+ dataset.data_path.format(
|
||||
episode_chunk=int(episode_index) // dataset.chunks_size,
|
||||
episode_index=episode_index,
|
||||
)
|
||||
)
|
||||
df = pd.read_parquet(url)
|
||||
data = df[selected_columns] # Select specific columns
|
||||
@@ -302,7 +335,9 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
|
||||
]
|
||||
|
||||
|
||||
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
|
||||
def get_episode_language_instruction(
|
||||
dataset: LeRobotDataset, ep_index: int
|
||||
) -> list[str]:
|
||||
# check if the dataset has language instructions
|
||||
if "language_instruction" not in dataset.features:
|
||||
return None
|
||||
@@ -313,11 +348,15 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
|
||||
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
|
||||
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
|
||||
# with the tf.tensor appearing in the string
|
||||
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
|
||||
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix(
|
||||
"', shape=(), dtype=string)"
|
||||
)
|
||||
|
||||
|
||||
def get_dataset_info(repo_id: str) -> IterableNamespace:
|
||||
response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json")
|
||||
response = requests.get(
|
||||
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json"
|
||||
)
|
||||
response.raise_for_status() # Raises an HTTPError for bad responses
|
||||
dataset_info = response.json()
|
||||
dataset_info["repo_id"] = repo_id
|
||||
@@ -346,7 +385,9 @@ def visualize_dataset_html(
|
||||
if force_override:
|
||||
shutil.rmtree(output_dir)
|
||||
else:
|
||||
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
|
||||
logging.info(
|
||||
f"Output directory already exists. Loading from it: '{output_dir}'"
|
||||
)
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@@ -162,8 +162,12 @@ def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5):
|
||||
print("\nOriginal frame saved to:")
|
||||
print(f" {output_dir / 'original_frame.png'}.")
|
||||
|
||||
save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)
|
||||
save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)
|
||||
save_config_all_transforms(
|
||||
cfg.training.image_transforms, original_frame, output_dir, n_examples
|
||||
)
|
||||
save_config_single_transforms(
|
||||
cfg.training.image_transforms, original_frame, output_dir, n_examples
|
||||
)
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<script defer src="https://cdn.jsdelivr.net/npm/alpinejs@3.x.x/dist/cdn.min.js"></script>
|
||||
</head>
|
||||
<body class="h-screen overflow-hidden font-mono text-white" x-data="{
|
||||
<body class="h-screen overflow-hidden font-mono text-white" x-data="{
|
||||
inputValue: '',
|
||||
navigateToDataset() {
|
||||
const trimmedValue = this.inputValue.trim();
|
||||
@@ -40,14 +40,14 @@
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex w-full max-w-lg px-4 mb-4">
|
||||
<input
|
||||
type="text"
|
||||
<input
|
||||
type="text"
|
||||
x-model="inputValue"
|
||||
@keyup.enter="navigateToDataset"
|
||||
placeholder="enter dataset id (ex: lerobot/droid_100)"
|
||||
class="flex-grow px-4 py-2 rounded-l bg-white bg-opacity-20 text-white placeholder-gray-300 focus:outline-none focus:ring-2 focus:ring-blue-300"
|
||||
>
|
||||
<button
|
||||
<button
|
||||
@click="navigateToDataset"
|
||||
class="px-4 py-2 bg-blue-500 text-white rounded-r hover:bg-blue-600 focus:outline-none focus:ring-2 focus:ring-blue-300"
|
||||
>
|
||||
@@ -65,4 +65,4 @@
|
||||
</details>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
</html>
|
||||
|
||||
@@ -107,8 +107,8 @@
|
||||
<span class="truncate">filter videos</span>
|
||||
<div class="transition-transform" :class="{ 'rotate-180': isVideosDropdownOpen }">🔽</div>
|
||||
</div>
|
||||
|
||||
<div x-show="isVideosDropdownOpen"
|
||||
|
||||
<div x-show="isVideosDropdownOpen"
|
||||
class="absolute mt-1 border border-slate-500 rounded shadow-lg z-10">
|
||||
<div>
|
||||
<template x-for="option in videosKeys" :key="option">
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user