Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots

Simon Alibert
2025-03-10 18:39:48 +01:00
135 changed files with 2177 additions and 514 deletions

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script configures a single motor at a time to a given ID and baudrate.

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities to control a robot.
@@ -254,7 +267,7 @@ def record(
)
# Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
if not robot.is_connected:
robot.connect()
@@ -285,8 +298,6 @@ def record(
episode_time_s=cfg.episode_time_s,
display_cameras=cfg.display_cameras,
policy=policy,
device=cfg.device,
use_amp=cfg.use_amp,
fps=cfg.fps,
single_task=cfg.single_task,
)
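Together with the make_policy call above, this hunk shows the core of the refactor in this file: the device and AMP settings are no longer passed around separately because they now live on the policy config, so callers read cfg.policy.device and cfg.policy.use_amp. A minimal sketch of that config layout (the dataclass and field names below are illustrative, not the exact LeRobot definitions):

    from dataclasses import dataclass

    @dataclass
    class PolicyConfig:
        # Device and AMP settings now travel with the policy itself.
        device: str = "cuda"
        use_amp: bool = False

    @dataclass
    class RecordConfig:
        policy: PolicyConfig | None = None
        fps: int = 30
        episode_time_s: float = 60.0
        display_cameras: bool = True
        single_task: str | None = None

With this shape, record() can drop its explicit device and use_amp arguments and simply forward cfg.policy.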

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities to control a robot in simulation.

View File

@@ -454,11 +454,11 @@ def _compile_episode_data(
@parser.wrap()
def eval(cfg: EvalPipelineConfig):
def eval_main(cfg: EvalPipelineConfig):
logging.info(pformat(asdict(cfg)))
# Check device is available
device = get_safe_torch_device(cfg.device, log=True)
device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -470,14 +470,14 @@ def eval(cfg: EvalPipelineConfig):
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
logging.info("Making policy.")
policy = make_policy(
cfg=cfg.policy,
device=device,
env_cfg=cfg.env,
)
policy.eval()
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy(
env,
policy,
@@ -499,4 +499,4 @@ def eval(cfg: EvalPipelineConfig):
if __name__ == "__main__":
init_logging()
eval()
eval_main()
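Two independent changes land in this file: the entry point is renamed from eval to eval_main, which avoids shadowing Python's built-in eval, and the AMP flag is now read from cfg.policy.use_amp. The conditional context on the autocast line works because nullcontext() is a no-op context manager, so the with statement stays valid when AMP is disabled. A self-contained sketch of the pattern (the model and batch here are stand-ins, not the repo's eval_policy):

    from contextlib import nullcontext

    import torch

    def run_inference(model: torch.nn.Module, batch: torch.Tensor, device: torch.device, use_amp: bool) -> torch.Tensor:
        # autocast runs the forward pass in mixed precision when use_amp is True;
        # nullcontext() makes the same `with` line a no-op otherwise.
        with torch.no_grad(), torch.autocast(device_type=device.type) if use_amp else nullcontext():
            return model(batch.to(device))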

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from pathlib import Path

View File

@@ -120,7 +120,7 @@ def train(cfg: TrainPipelineConfig):
set_seed(cfg.seed)
# Check device is available
device = get_safe_torch_device(cfg.device, log=True)
device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -138,13 +138,12 @@ def train(cfg: TrainPipelineConfig):
logging.info("Creating policy")
policy = make_policy(
cfg=cfg.policy,
device=device,
ds_meta=dataset.meta,
)
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(device, enabled=cfg.use_amp)
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
step = 0 # number of policy updates (forward + backward + optim)
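The GradScaler change passes device.type (a string such as "cuda") instead of the torch.device object, matching the torch.amp.GradScaler signature, and the AMP flag again comes from the policy config. For orientation, a scaler-driven update of the kind the training loop performs looks roughly like this (a generic sketch assuming torch >= 2.3, not the repo's update_policy implementation):

    import torch
    from torch.amp import GradScaler

    def amp_train_step(model, batch, optimizer, scaler: GradScaler, device: torch.device,
                       use_amp: bool, grad_clip_norm: float) -> float:
        with torch.autocast(device_type=device.type, enabled=use_amp):
            loss = model(batch)                 # assume the model returns a scalar loss
        scaler.scale(loss).backward()           # scale the loss so fp16 gradients do not underflow
        scaler.unscale_(optimizer)              # unscale before clipping so the norm is meaningful
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
        scaler.step(optimizer)                  # skips the step if inf/nan gradients were found
        scaler.update()
        optimizer.zero_grad()
        return loss.item()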
@@ -218,7 +217,7 @@ def train(cfg: TrainPipelineConfig):
cfg.optimizer.grad_clip_norm,
grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler,
use_amp=cfg.use_amp,
use_amp=cfg.policy.use_amp,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -249,7 +248,10 @@ def train(cfg: TrainPipelineConfig):
if cfg.env and is_eval_step:
step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}")
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
with (
torch.no_grad(),
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
):
eval_info = eval_policy(
eval_env,
policy,

View File

@@ -158,7 +158,7 @@ def run_server(
if major_version < 2:
return "Make sure to convert your LeRobotDataset to v2 & above."
episode_data_csv_str, columns = get_episode_data(dataset, episode_id)
episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
dataset_info = {
"repo_id": f"{dataset_namespace}/{dataset_name}",
"num_samples": dataset.num_frames
@@ -194,7 +194,7 @@ def run_server(
]
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl"
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
)
response.raise_for_status()
# Split into lines and parse each line as JSON
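The episodes metadata is stored as JSON Lines (one JSON object per line), hence the comment above; with the new timeout in place, a minimal parse of the response could look like this sketch:

    import json

    episodes = [json.loads(line) for line in response.text.splitlines() if line.strip()]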
@@ -218,6 +218,7 @@ def run_server(
videos_info=videos_info,
episode_data_csv_str=episode_data_csv_str,
columns=columns,
ignored_columns=ignored_columns,
)
app.run(host=host, port=port)
@@ -233,9 +234,17 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
This file will be loaded by Dygraph javascript to plot data in real time."""
columns = []
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"]
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
selected_columns.remove("timestamp")
ignored_columns = []
for column_name in selected_columns:
shape = dataset.features[column_name]["shape"]
shape_dim = len(shape)
if shape_dim > 1:
selected_columns.remove(column_name)
ignored_columns.append(column_name)
# init header of csv with state and action names
header = ["timestamp"]
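This hunk widens the plotted dtypes to int32 and filters out any feature whose shape has more than one dimension (for example image tensors), since the CSV handed to Dygraphs can only hold flat scalar series; the skipped names are collected in ignored_columns so the UI can report them. Note that removing items from a list while iterating over it can silently skip the element following each removal; an equivalent filter over a copy avoids that (a side sketch, not the repo's code):

    ignored_columns = [col for col in selected_columns if len(dataset.features[col]["shape"]) > 1]
    selected_columns = [col for col in selected_columns if col not in ignored_columns]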
@@ -245,16 +254,17 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
if isinstance(dataset, LeRobotDataset)
else dataset.features[column_name].shape[0]
)
header += [f"{column_name}_{i}" for i in range(dim_state)]
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
column_names = dataset.features[column_name]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
else:
column_names = [f"motor_{i}" for i in range(dim_state)]
column_names = [f"{column_name}_{i}" for i in range(dim_state)]
columns.append({"key": column_name, "value": column_names})
header += column_names
selected_columns.insert(0, "timestamp")
if isinstance(dataset, LeRobotDataset):
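The new naming logic prefers the names recorded in the dataset features and only falls back to positional {column_name}_{i} labels. The while loop handles names stored as a nested mapping rather than a flat list, taking the first value until a list is reached. A tiny illustration with a made-up feature entry:

    feature = {"dtype": "float32", "shape": (2,), "names": {"motors": ["shoulder_pan", "elbow_flex"]}}

    column_names = feature["names"]
    while not isinstance(column_names, list):
        column_names = list(column_names.values())[0]

    print(column_names)  # ['shoulder_pan', 'elbow_flex']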
@@ -290,7 +300,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
csv_writer.writerows(rows)
csv_string = csv_buffer.getvalue()
return csv_string, columns
return csv_string, columns, ignored_columns
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
@@ -317,7 +327,9 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
def get_dataset_info(repo_id: str) -> IterableNamespace:
response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json")
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
)
response.raise_for_status() # Raises an HTTPError for bad responses
dataset_info = response.json()
dataset_info["repo_id"] = repo_id
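Both requests.get calls shown in this diff now pass timeout=5; without a timeout, requests waits indefinitely on a stalled connection, which would hang the visualization server. A slightly more defensive variant of the same call, as a sketch (the helper name and error handling are illustrative):

    import requests

    def fetch_dataset_info(repo_id: str, timeout_s: float = 5.0) -> dict:
        url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json"
        try:
            response = requests.get(url, timeout=timeout_s)  # fail fast instead of hanging forever
            response.raise_for_status()                      # surface 4xx/5xx responses as exceptions
        except requests.Timeout as exc:
            raise RuntimeError(f"Timed out fetching {url}") from exc
        info = response.json()
        info["repo_id"] = repo_id
        return info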